using System;
using System.Buffers;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text.Json;
using LLama.Native;

namespace LLama.Batched;

/// <summary>
/// A single conversation thread that can be prompted (adding tokens from the user) or inferred (extracting a token from the LLM)
/// </summary>
public sealed class Conversation
    : IDisposable
{
    private ulong _requiredEpoch;
    private LLamaPos _end;
    private int _batchSampleIndex;
    private bool _disposed;
    private bool _forked;

    /// <summary>
    /// The executor which this conversation belongs to
    /// </summary>
    public BatchedExecutor Executor { get; }

    /// <summary>
    /// Unique ID for this conversation
    /// </summary>
    public LLamaSeqId ConversationId { get; }

    /// <summary>
    /// Total number of tokens in this conversation, cannot exceed the context length.
    /// </summary>
    public int TokenCount => _end.Value;

    /// <summary>
    /// Indicates if this conversation has been disposed; nothing can be done with a disposed conversation
    /// </summary>
    public bool IsDisposed => _disposed || Executor.IsDisposed;

    /// <summary>
    /// Indicates if this conversation is waiting for inference to be run on the executor. "Prompt" and "Sample" cannot be called when this is true.
    /// </summary>
    public bool RequiresInference => _requiredEpoch > Executor.Epoch;

    /// <summary>
    /// Indicates that this conversation should be sampled.
    /// </summary>
    public bool RequiresSampling => _requiredEpoch == Executor.Epoch;

    #region construction/destruction
    internal Conversation(BatchedExecutor batch, LLamaSeqId id)
    {
        ConversationId = id;
        Executor = batch;
    }

    /// <summary>
    /// Finalizer for Conversation
    /// </summary>
    ~Conversation()
    {
        Dispose();
    }

    /// <summary>
    /// End this conversation, freeing all resources used by it
    /// </summary>
    /// <exception cref="ObjectDisposedException"></exception>
    public void Dispose()
    {
        if (IsDisposed)
            return;
        _disposed = true;

        // Remove this conversation from the KV cache
        Executor.Context.NativeHandle.KvCacheRemove(ConversationId, 0, _end);

        // Prevent finalizer from running
        GC.SuppressFinalize(this);
    }

    private void AssertNotDisposed()
    {
        if (Executor.IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(Conversation));
    }
    #endregion

    /// <summary>
    /// Create a copy of the current conversation
    /// </summary>
    /// <remarks>The copy shares internal state, so consumes very little extra memory.</remarks>
    /// <returns></returns>
    /// <exception cref="ObjectDisposedException"></exception>
    public Conversation Fork()
    {
        AssertNotDisposed();

        // Create a new conversation which references the current position in this one
        var c = new Conversation(Executor, Executor.GetNextSequenceId())
        {
            // Because these values are copied to the forked conversation it means that it will share the exact same output
            // logits next time sampling is done. This is a problem, because the sampling process is allowed to modify those
            // logits, so sampling one conversation may mess up the fork! Setting the "forked" flag on both sequences ensures
            // they both copy the logits before the next sampling run, to fix this issue.
            _requiredEpoch = _requiredEpoch,
            _batchSampleIndex = _batchSampleIndex,
            _forked = true,

            _end = _end,
        };

        // Setting this flag means that logits will be copied next time sampling is called, ensuring that the forked
        // conversation doesn't share logits with this one.
        _forked = true;

        // Assign tokens to the new sequence
        NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, c.ConversationId, 0, _end);

        return c;
    }
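
    // Example (illustrative sketch, not part of the original class): forking lets two continuations share
    // the KV cache for their common prefix. The `executor` and `promptTokens` names below are assumed to
    // come from the surrounding application code; `executor.Create()` stands for however that code obtains
    // a Conversation from its BatchedExecutor.
    //
    //     using var main = executor.Create();
    //     main.Prompt(promptTokens);            // prompt the shared prefix
    //     await executor.Infer();               // run inference once for the whole batch
    //     using var fork = main.Fork();         // cheap copy - shares the KV cache up to this point
    //     // `main` and `fork` can now be prompted and sampled independently.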

    #region sample
    /// <summary>
    /// Get the logits from this conversation, ready for sampling
    /// </summary>
    /// <returns></returns>
    /// <exception cref="ObjectDisposedException"></exception>
    /// <exception cref="CannotSampleRequiresPromptException">Thrown if this conversation was not prompted before the previous call to infer</exception>
    /// <exception cref="CannotSampleRequiresInferenceException">Thrown if Infer() must be called on the executor</exception>
    public Span<float> Sample()
    {
        AssertNotDisposed();

        if (_requiredEpoch < Executor.Epoch)
            throw new CannotSampleRequiresPromptException();
        if (_requiredEpoch > Executor.Epoch)
            throw new CannotSampleRequiresInferenceException();

        var span = Executor.Context.NativeHandle.GetLogitsIth(_batchSampleIndex);

        // If necessary copy the span, to protect it from modification. This is only done when
        // this conversation has been forked in this epoch.
        if (_forked)
            span = span.ToArray();

        return span;
    }
    #endregion
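
    // Example (illustrative sketch): a minimal greedy generation loop over a single conversation, picking the
    // argmax of the raw logits returned by Sample(). `executor`, `conversation` and `promptTokens` are assumed
    // to exist in the caller's code, and the int-to-LLamaToken conversion is assumed; a real application would
    // normally run a proper sampling pipeline over the logits instead.
    //
    //     conversation.Prompt(promptTokens);
    //     for (var n = 0; n < 32; n++)
    //     {
    //         await executor.Infer();
    //
    //         var logits = conversation.Sample();
    //         var best = 0;
    //         for (var i = 1; i < logits.Length; i++)
    //             if (logits[i] > logits[best])
    //                 best = i;
    //
    //         conversation.Prompt((LLamaToken)best);   // feed the chosen token back in
    //     }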

    #region prompt
    private void AssertCanBePrompted()
    {
        AssertNotDisposed();

        if (RequiresInference)
            throw new AlreadyPromptedConversationException();
    }

    /// <summary>
    /// Add tokens to this conversation
    /// </summary>
    /// <param name="input">The text to tokenize and add</param>
    /// <param name="addBos">Whether to add the BOS (beginning of sequence) token when tokenizing</param>
    /// <param name="special">Whether to tokenize special tokens in the input</param>
    /// <returns></returns>
    [Obsolete("Tokenize the text and pass the tokens instead")]
    public void Prompt(string input, bool addBos, bool special)
    {
        AssertCanBePrompted();

        Prompt(Executor.Context.Tokenize(input, addBos, special));
    }
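
    // Example (illustrative sketch): instead of the obsolete string overload above, tokenize the text first and
    // pass the tokens. `conversation` is assumed to be a Conversation owned by `executor`.
    //
    //     var tokens = executor.Context.Tokenize("The quick brown fox", true, false);   // (addBos, special)
    //     conversation.Prompt(tokens);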

    /// <summary>
    /// Add tokens to this conversation
    /// </summary>
    /// <param name="tokens"></param>
    /// <returns></returns>
    /// <exception cref="ObjectDisposedException"></exception>
    /// <exception cref="AlreadyPromptedConversationException"></exception>
    public void Prompt(List<LLamaToken> tokens)
    {
        AssertCanBePrompted();

#if NET6_0_OR_GREATER
        var span = CollectionsMarshal.AsSpan(tokens);
        Prompt(span);
#else
        // Borrow an array and copy tokens into it
        var arr = ArrayPool<LLamaToken>.Shared.Rent(tokens.Count);
        try
        {
            for (var i = 0; i < tokens.Count; i++)
                arr[i] = tokens[i];

            // Only pass the populated prefix - the rented array may be larger than tokens.Count
            Prompt(arr.AsSpan(0, tokens.Count));
        }
        finally
        {
            ArrayPool<LLamaToken>.Shared.Return(arr);
        }
#endif
    }

    /// <summary>
    /// Add tokens to this conversation
    /// </summary>
    /// <param name="tokens"></param>
    /// <returns></returns>
    /// <exception cref="ObjectDisposedException"></exception>
    /// <exception cref="AlreadyPromptedConversationException"></exception>
    public void Prompt(ReadOnlySpan<LLamaToken> tokens)
    {
        AssertCanBePrompted();

        // No point doing anything if there is no actual prompt!
        if (tokens.Length == 0)
            return;

        // Add the prompt to the batch
        for (var i = 0; i < tokens.Length; i++)
            _batchSampleIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);

        // Mark this conversation as needing inference/sampling
        _requiredEpoch = Executor.Epoch + 1;

        // Unset the forked flag. Since this conversation has just been prompted it's no longer
        // sharing anything with any other conversations.
        _forked = false;
    }

    /// <summary>
    /// Add a single token to this conversation
    /// </summary>
    /// <param name="token"></param>
    /// <returns></returns>
    /// <exception cref="ObjectDisposedException"></exception>
    /// <exception cref="AlreadyPromptedConversationException"></exception>
    public void Prompt(LLamaToken token)
    {
        AssertCanBePrompted();

        unsafe
        {
            Span<LLamaToken> span = stackalloc LLamaToken[1] { token };
            Prompt(span);
        }
    }
    #endregion

    #region modify
    /// <summary>
    /// Directly modify the KV cache of this conversation
    /// </summary>
    /// <param name="modifier"></param>
    /// <exception cref="CannotModifyWhileRequiresInferenceException">Thrown if this method is called while <see cref="Conversation.RequiresInference"/> == true</exception>
    public void Modify(ModifyKvCache modifier)
    {
        AssertNotDisposed();

        if (RequiresInference)
            throw new CannotModifyWhileRequiresInferenceException();

        // do whatever the modification is
        _end = modifier.Invoke(_end, new KvAccessor(this));

        // Set the epoch down to zero, this ensures that this conversation
        // cannot be sampled until it is prompted again.
        _requiredEpoch = 0;
    }
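
    // Example (illustrative sketch): using Modify to rewind a conversation by `n` tokens, dropping them from the
    // KV cache. The implicit int-to-LLamaPos conversion is assumed. After a modification the conversation must be
    // prompted again before it can be sampled.
    //
    //     conversation.Modify((end, kv) =>
    //     {
    //         kv.Remove(end.Value - n, n);   // remove the last n token positions
    //         return end.Value - n;          // report the new end position
    //     });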

    /// <summary>
    /// Provides direct access to the KV cache of a <see cref="Conversation"/>.
    /// See <see cref="Modify"/> for how to use this.
    /// </summary>
    public readonly ref struct KvAccessor
    {
        private readonly Conversation _conversation;

        internal KvAccessor(Conversation conversation)
        {
            _conversation = conversation;
        }

        #region remove
        /// <summary>
        /// Removes all tokens that have positions in [start, end)
        /// </summary>
        /// <param name="start">Start position (inclusive)</param>
        /// <param name="end">End position (exclusive)</param>
        public void Remove(LLamaPos start, LLamaPos end)
        {
            _conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end);
        }

        /// <summary>
        /// Removes <paramref name="count"/> tokens starting from the given position
        /// </summary>
        /// <param name="start">Start position (inclusive)</param>
        /// <param name="count">Number of tokens</param>
        public void Remove(LLamaPos start, int count)
        {
            if (count <= 0)
                return;

            var end = start.Value + count;
            _conversation.Executor.Context.NativeHandle.KvCacheRemove(_conversation.ConversationId, start, end);
        }
        #endregion

        #region shift
        /// <summary>
        /// Adds relative position "delta" to all tokens that have positions in [start, end).
        /// If the KV cache is RoPEd, the KV data is updated accordingly.
        /// </summary>
        /// <param name="start">Start position (inclusive)</param>
        /// <param name="end">End position (exclusive)</param>
        /// <param name="delta">Amount to add on to each token position</param>
        public void Add(LLamaPos start, LLamaPos end, int delta)
        {
            _conversation.Executor.Context.NativeHandle.KvCacheSequenceAdd(_conversation.ConversationId, start, end, delta);
        }
        #endregion

        #region divide
        /// <summary>
        /// Integer division of the positions by a factor of `divisor > 1`.
        /// If the KV cache is RoPEd, the KV data is updated accordingly.
        /// </summary>
        /// <param name="start">Start position (inclusive). If less than zero, it is clamped to zero.</param>
        /// <param name="end">End position (exclusive). If less than zero, it is treated as "infinity".</param>
        /// <param name="divisor">Amount to divide each position by.</param>
        public void Divide(LLamaPos start, LLamaPos end, int divisor)
        {
            if (divisor <= 0)
                throw new ArgumentOutOfRangeException(nameof(divisor));

            _conversation.Executor.Context.NativeHandle.KvCacheSequenceDivide(_conversation.ConversationId, start, end, divisor);
        }
        #endregion
    }

    /// <summary>
    /// A function which can temporarily access the KV cache of a <see cref="Conversation"/> to modify it directly
    /// </summary>
    /// <param name="end">The current end token of this conversation</param>
    /// <param name="kv">A <see cref="KvAccessor"/> which allows direct access to modify the KV cache</param>
    /// <returns>The new end token position</returns>
    public delegate LLamaPos ModifyKvCache(LLamaPos end, KvAccessor kv);
    #endregion

    #region save/load
    private void AssertCanLoad()
    {
        AssertNotDisposed();
        if (_end.Value > 0)
            throw new InvalidOperationException("Cannot load into a non-empty conversation");
    }

    private void AssertCanSave()
    {
        AssertNotDisposed();
        if (RequiresInference)
            throw new CannotSaveWhileRequiresInferenceException();
    }


    /// <summary>
    /// Save the complete state of this conversation to a file. If the file already exists it will be overwritten.
    /// </summary>
    /// <param name="filepath"></param>
    /// <exception cref="CannotSaveWhileRequiresInferenceException"></exception>
    public void Save(string filepath)
    {
        AssertCanSave();

        // Prepare extra state to put into file header
        var state = GetState();
        var bytes = JsonSerializer.SerializeToUtf8Bytes(state);

        // Save extra state along with the KV cache
        Executor.Context.SaveState(filepath, ConversationId, bytes);
    }

    /// <summary>
    /// Save the complete state of this conversation in system memory.
    /// </summary>
    /// <returns></returns>
    public State Save()
    {
        AssertCanSave();

        return new PrivateState(
            Executor.Context.GetState(ConversationId),
            GetState()
        );
    }
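
    // Example (illustrative sketch): persisting a conversation to disk and restoring it later. Saving happens
    // here on the conversation; loading a saved file back is done through the owning BatchedExecutor, since the
    // Load methods below are internal (the executor-side `Load` call shown is an assumption about that API).
    //
    //     conversation.Save("conversation_state.bin");
    //     conversation.Dispose();
    //     // ... later, typically with the same model/context configuration ...
    //     using var restored = executor.Load("conversation_state.bin");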

    /// <summary>
    /// Load state from a file.
    /// This should only ever be called by the BatchedExecutor, on a newly created conversation object!
    /// </summary>
    /// <param name="filepath"></param>
    /// <exception cref="InvalidOperationException"></exception>
    internal void Load(string filepath)
    {
        AssertCanLoad();

        // Load the state from file into the KV cache
        Executor.Context.LoadState(filepath, ConversationId, out var header);

        // deserialize the extra state in the file header
        var state = JsonSerializer.Deserialize<SerializableConversationState>(header);
        if (state == null)
        {
            Dispose();
            throw new InvalidOperationException("Failed to deserialize - deserialized header state was null");
        }

        Load(state);
    }

    /// <summary>
    /// Load state from a previously saved state.
    /// This should only ever be called by the BatchedExecutor, on a newly created conversation object!
    /// </summary>
    /// <param name="state"></param>
    internal void Load(State state)
    {
        AssertCanLoad();

        // There is only one class that extends State and it is PrivateState, so this cast is safe.
        var priv = (PrivateState)state;

        // Load the in-memory state into the KV cache
        Executor.Context.LoadState(priv.SequenceState, ConversationId);

        Load(priv.ConversationState);
    }


    private void Load(SerializableConversationState state)
    {
        if (state.Version != 1)
            throw new InvalidOperationException("Failed to deserialize - mismatched version number");

        // Load extra conversation state
        _end = state.TokenCount;
    }

    private SerializableConversationState GetState()
    {
        return new SerializableConversationState(
            Version: 1,
            TokenCount: TokenCount
        );
    }


    private record SerializableConversationState(int Version, int TokenCount);

    private sealed class PrivateState
        : State
    {
        public readonly LLamaContext.SequenceState SequenceState;
        public readonly SerializableConversationState ConversationState;

        public override ulong Size => SequenceState.Size;

        public PrivateState(LLamaContext.SequenceState sequenceState, SerializableConversationState conversationState)
        {
            SequenceState = sequenceState;
            ConversationState = conversationState;
        }

        /// <inheritdoc />
        public override void Dispose()
        {
            if (IsDisposed)
                throw new ObjectDisposedException(nameof(State));
            IsDisposed = true;

            SequenceState.Dispose();
        }
    }

    /// <summary>
    /// In memory saved state of a <see cref="Conversation"/>
    /// </summary>
    public abstract class State
        : IDisposable
    {
        /// <summary>
        /// Indicates if this state has been disposed
        /// </summary>
        public bool IsDisposed { get; protected set; }

        /// <summary>
        /// Get the size in bytes of this state object
        /// </summary>
        public abstract ulong Size { get; }

        /// <inheritdoc />
        public abstract void Dispose();

        /// <summary>
        /// Internal constructor to prevent anyone outside of LLamaSharp from extending this class
        /// </summary>
        internal State()
        {
        }
    }
    #endregion
}