BatchedExecutor Save/Load (#681)
* Added the ability to save and load individual conversations in a batched executor.
  - New example
  - Added `BatchedExecutor.Load(filepath)` method
  - Added `Conversation.Save(filepath)` method
  - Added new (currently internal) `SaveState`/`LoadState` methods in `LLamaContext`, which can stash some extra binary data in the header
* Added the ability to save/load a `Conversation` to an in-memory state, instead of to a file.
* Moved the new save/load methods out to an extension class specifically for the batched executor.
* Removed unnecessary spaces.
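In short: a `Conversation` can now be round-tripped through a file or through memory. A condensed sketch of the new API (adapted from the example added in this commit; `executor`, `conversation` and `lastToken` are assumed to be set up as in that example, with inference already run so that `RequiresInference` is false):

    // Save to a file, dispose the conversation, then restore it
    conversation.Save("demo_conversation.state");
    conversation.Dispose();
    conversation = executor.Load("demo_conversation.state");

    // Or save into system memory instead of a file
    using (Conversation.State state = conversation.Save())
    {
        conversation.Dispose();
        conversation = executor.Load(state);
    }

    // Either way, the loaded conversation must be prompted again before generation can continue
    conversation.Rewind(1);
    conversation.Prompt(lastToken);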
parent f01c13ee54
commit ccc49eb1e0
@@ -26,6 +26,7 @@ public class ExampleRunner
        { "Semantic Kernel: Prompt", SemanticKernelPrompt.Run },
        { "Semantic Kernel: Chat", SemanticKernelChat.Run },
        { "Semantic Kernel: Store", SemanticKernelMemory.Run },
        { "Batched Executor: Save/Load", BatchedExecutorSaveAndLoad.Run },
        { "Batched Executor: Fork", BatchedExecutorFork.Run },
        { "Batched Executor: Rewind", BatchedExecutorRewind.Run },
        { "Batched Executor: Guidance", BatchedExecutorGuidance.Run },
@@ -0,0 +1,108 @@
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates saving the state of a conversation to a file (or to memory) and later loading it back into a batched executor
/// </summary>
public class BatchedExecutorSaveAndLoad
{
    private const int n_len = 18;

    public static async Task Run()
    {
        string modelPath = UserSettings.GetModelPath();

        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");

        // Create an executor that can evaluate a batch of conversations together
        using var executor = new BatchedExecutor(model, parameters);

        // Print some info
        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Create a conversation
        var conversation = executor.Create();
        conversation.Prompt(prompt);

        // Run inference loop
        var decoder = new StreamingTokenDecoder(executor.Context);
        var sampler = new DefaultSamplingPipeline();
        var lastToken = await GenerateTokens(executor, conversation, sampler, decoder, n_len);

        // Can't save a conversation while RequiresInference is true
        if (conversation.RequiresInference)
            await executor.Infer();

        // Save this conversation to a file and dispose it
        conversation.Save("demo_conversation.state");
        conversation.Dispose();
        AnsiConsole.WriteLine($"Saved state: {new FileInfo("demo_conversation.state").Length} bytes");

        // Now create a new conversation by loading that state back from the file
        conversation = executor.Load("demo_conversation.state");
        AnsiConsole.WriteLine("Loaded state");

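        // NOTE: the saved state contains the tokens already in the KV cache, but not the logits from
        // the last evaluation, so a freshly loaded conversation must be prompted before sampling.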
        // Prompt it again with the last token, so we can continue generating
        conversation.Rewind(1);
        conversation.Prompt(lastToken);

        // Continue generating text
        lastToken = await GenerateTokens(executor, conversation, sampler, decoder, n_len);

        // Can't save a conversation while RequiresInference is true
        if (conversation.RequiresInference)
            await executor.Infer();

        // Save the conversation again, this time into system memory
        using (var state = conversation.Save())
        {
            conversation.Dispose();
            AnsiConsole.WriteLine($"Saved state to memory: {state.Size} bytes");

            // Now create a new conversation by loading that in-memory state
            conversation = executor.Load(state);
            AnsiConsole.WriteLine("Loaded state");
        }

        // Prompt it again with the last token, so we can continue generating
        conversation.Rewind(1);
        conversation.Prompt(lastToken);

        // Continue generating text
        await GenerateTokens(executor, conversation, sampler, decoder, n_len);

        // Display final output
        AnsiConsole.MarkupLine($"[red]{prompt}{decoder.Read()}[/]");
    }

    private static async Task<LLamaToken> GenerateTokens(BatchedExecutor executor, Conversation conversation, ISamplingPipeline sampler, StreamingTokenDecoder decoder, int count = 15)
    {
        var token = (LLamaToken)0;

        for (var i = 0; i < count; i++)
        {
            // Run inference
            await executor.Infer();

            // Use sampling pipeline to pick a token
            token = sampler.Sample(executor.Context.NativeHandle, conversation.Sample(), ReadOnlySpan<LLamaToken>.Empty);

            // Add it to the decoder, so it can be converted into text later
            decoder.Add(token);

            // Prompt the conversation with the token
            conversation.Prompt(token);
        }

        return token;
    }
}
@@ -84,6 +84,39 @@ public sealed class BatchedExecutor
        return new Conversation(this, GetNextSequenceId());
    }

    /// <summary>
    /// Load a conversation that was previously saved to a file. Once loaded the conversation will
    /// need to be prompted.
    /// </summary>
    /// <param name="filepath"></param>
    /// <returns></returns>
    /// <exception cref="ObjectDisposedException"></exception>
    public Conversation Load(string filepath)
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        var conversation = Create();
        conversation.Load(filepath);
        return conversation;
    }

    /// <summary>
    /// Load a conversation that was previously saved into memory. Once loaded the conversation will need to be prompted.
    /// </summary>
    /// <param name="state"></param>
    /// <returns></returns>
    /// <exception cref="ObjectDisposedException"></exception>
    public Conversation Load(Conversation.State state)
    {
        if (IsDisposed)
            throw new ObjectDisposedException(nameof(BatchedExecutor));

        var conversation = Create();
        conversation.Load(state);
        return conversation;
    }

    /// <summary>
    /// Run inference for all conversations in the batch which have pending tokens.
    ///
@@ -2,6 +2,7 @@
using System.Buffers;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text.Json;
using LLama.Native;

namespace LLama.Batched;

@@ -14,7 +15,7 @@ public sealed class Conversation
{
    private ulong _requiredEpoch;
    private LLamaPos _end;
-   private int _batchIndex;
+   private int _batchSampleIndex;
    private bool _disposed;
    private bool _forked;

@@ -107,7 +108,7 @@ public sealed class Conversation
            // logits, so sampling one conversation may mess up the fork! Setting the "forked" flag on both sequences ensures
            // they both copy the logits before the next sampling run, to fix this issue.
            _requiredEpoch = _requiredEpoch,
-           _batchIndex = _batchIndex,
+           _batchSampleIndex = _batchSampleIndex,
            _forked = true,

            _end = _end,
@@ -140,7 +141,7 @@ public sealed class Conversation
        if (_requiredEpoch > Executor.Epoch)
            throw new CannotSampleRequiresInferenceException();

-       var span = Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
+       var span = Executor.Context.NativeHandle.GetLogitsIth(_batchSampleIndex);

        // If necessary copy the span, to protect it from modification. This is only done when
        // this conversation has been forked in this epoch.
@@ -220,7 +221,7 @@ public sealed class Conversation

        // Add the prompt to the batch
        for (var i = 0; i < tokens.Length; i++)
-           _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);
+           _batchSampleIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);

        // Mark this conversation as needing inference/sampling
        _requiredEpoch = Executor.Epoch + 1;
@@ -350,4 +351,168 @@ public sealed class Conversation
    /// <returns>The new end token position</returns>
    public delegate LLamaPos ModifyKvCache(LLamaPos end, KvAccessor kv);
    #endregion

    #region save/load
    private void AssertCanLoad()
    {
        AssertNotDisposed();
        if (_end.Value > 0)
            throw new InvalidOperationException("Cannot load into a non-empty conversation");
    }

    private void AssertCanSave()
    {
        AssertNotDisposed();
        if (RequiresInference)
            throw new CannotSaveWhileRequiresInferenceException();
    }

    /// <summary>
    /// Save the complete state of this conversation to a file. If the file already exists it will be overwritten.
    /// </summary>
    /// <param name="filepath"></param>
    /// <exception cref="CannotSaveWhileRequiresInferenceException"></exception>
    public void Save(string filepath)
    {
        AssertCanSave();

        // Prepare extra state to put into the file header
        var state = GetState();
        var bytes = JsonSerializer.SerializeToUtf8Bytes(state);

        // Save extra state along with the KV cache
        Executor.Context.SaveState(filepath, ConversationId, bytes);
    }

    /// <summary>
    /// Save the complete state of this conversation in system memory.
    /// </summary>
    /// <returns></returns>
    public State Save()
    {
        AssertCanSave();

        return new PrivateState(
            Executor.Context.GetState(ConversationId),
            GetState()
        );
    }
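
    // NOTE: the State returned by Save() owns a copy of the sequence state, so callers are
    // responsible for disposing it (e.g. with a `using` block, as in the example).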

    /// <summary>
    /// Load state from a file.
    /// This should only ever be called by the BatchedExecutor, on a newly created conversation object!
    /// </summary>
    /// <param name="filepath"></param>
    /// <exception cref="InvalidOperationException"></exception>
    internal void Load(string filepath)
    {
        AssertCanLoad();

        // Load the state from the file into the KV cache
        Executor.Context.LoadState(filepath, ConversationId, out var header);

        // Deserialize the extra state in the file header
        var state = JsonSerializer.Deserialize<SerializableConversationState>(header);
        if (state == null)
        {
            Dispose();
            throw new InvalidOperationException("Failed to deserialize - deserialized header state was null");
        }

        Load(state);
    }

    /// <summary>
    /// Load state from a previously saved state.
    /// This should only ever be called by the BatchedExecutor, on a newly created conversation object!
    /// </summary>
    /// <param name="state"></param>
    internal void Load(State state)
    {
        AssertCanLoad();

        // There is only one class that extends State and it is PrivateState, so this cast is safe.
        var priv = (PrivateState)state;

        // Load the in-memory state into the KV cache
        Executor.Context.LoadState(priv.SequenceState, ConversationId);

        Load(priv.ConversationState);
    }

    private void Load(SerializableConversationState state)
    {
        if (state.Version != 1)
            throw new InvalidOperationException("Failed to deserialize - mismatched version number");

        // Load extra conversation state
        _end = state.TokenCount;
    }

    private SerializableConversationState GetState()
    {
        return new SerializableConversationState(
            Version: 1,
            TokenCount: TokenCount
        );
    }
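
    // This record holds the extra per-conversation data that is saved alongside the raw KV cache
    // state; it is serialized to JSON and stashed in the file header by Save(filepath). The Version
    // field guards against loading a state saved in an incompatible format (checked in Load above).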
    private record SerializableConversationState(int Version, int TokenCount);

    private sealed class PrivateState
        : State
    {
        public readonly LLamaContext.SequenceState SequenceState;
        public readonly SerializableConversationState ConversationState;

        public override ulong Size => SequenceState.Size;

        public PrivateState(LLamaContext.SequenceState sequenceState, SerializableConversationState conversationState)
        {
            SequenceState = sequenceState;
            ConversationState = conversationState;
        }

        /// <inheritdoc />
        public override void Dispose()
        {
            if (IsDisposed)
                throw new ObjectDisposedException(nameof(State));
            IsDisposed = true;

            SequenceState.Dispose();
        }
    }

    /// <summary>
    /// In-memory saved state of a <see cref="Conversation"/>
    /// </summary>
    public abstract class State
        : IDisposable
    {
        /// <summary>
        /// Indicates if this state has been disposed
        /// </summary>
        public bool IsDisposed { get; protected set; }

        /// <summary>
        /// Get the size in bytes of this state object
        /// </summary>
        public abstract ulong Size { get; }

        /// <inheritdoc />
        public abstract void Dispose();

        /// <summary>
        /// Internal constructor prevents anyone outside of LLamaSharp from extending this class
        /// </summary>
        internal State()
        {
        }
    }
    #endregion
}
@@ -56,18 +56,6 @@ public class CannotSampleRequiresPromptException
    }
}

-/// <summary>
-/// This exception is thrown when <see cref="Conversation.Fork"/> is called when <see cref="Conversation.RequiresInference"/> = true
-/// </summary>
-public class CannotForkWhileRequiresInferenceException
-    : ExperimentalBatchedExecutorException
-{
-    internal CannotForkWhileRequiresInferenceException()
-        : base("Cannot `Fork()` a conversation while RequiresInference is true")
-    {
-    }
-}
-
/// <summary>
/// This exception is thrown when <see cref="Conversation.Modify"/> is called when <see cref="Conversation.RequiresInference"/> = true
/// </summary>

@@ -78,4 +66,18 @@ public class CannotModifyWhileRequiresInferenceException
        : base("Cannot `Modify()` a conversation while RequiresInference is true")
    {
    }
}

/// <summary>
/// This exception is thrown when Save() is called on a <see cref="Conversation"/> which has
/// already been prompted, before <see cref="BatchedExecutor"/>.Infer() has been called.
/// </summary>
public class CannotSaveWhileRequiresInferenceException
    : ExperimentalBatchedExecutorException
{
    internal CannotSaveWhileRequiresInferenceException()
        : base("Must call `Infer()` before saving this Conversation")
    {
    }
}
@@ -0,0 +1,117 @@
using System;
using System.Buffers.Binary;
using System.IO;
using System.IO.MemoryMappedFiles;
using LLama.Native;

namespace LLama.Batched;

internal static class LLamaContextExtensions
{
    private const uint FileHeaderMagic = 3430400180;

    /// <summary>
    /// Save the state of a particular sequence to the specified path. Also saves some extra data which will be returned when loading.
    /// Data saved with this method <b>must</b> be loaded with <see cref="LoadState(LLamaContext, string, LLamaSeqId, out byte[])"/>
    /// </summary>
    /// <param name="context"></param>
    /// <param name="filename"></param>
    /// <param name="sequence"></param>
    /// <param name="header"></param>
    internal static void SaveState(this LLamaContext context, string filename, LLamaSeqId sequence, ReadOnlySpan<byte> header)
    {
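        // File layout, as written below:
        //   [4 bytes]  magic number (big endian)
        //   [4 bytes]  header length N (big endian)
        //   [N bytes]  caller-supplied header
        //   [the rest] raw llama.cpp sequence state
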
        // Delete the file if it already exists, before overwriting it
        if (File.Exists(filename))
            File.Delete(filename);

        // Estimate size of state to write to disk. This is always equal to or greater than the actual size.
        var estimatedStateSize = checked((long)context.NativeHandle.GetStateSize(sequence));

        // Space for the "extra" header bytes plus an 8 byte prefix (magic number and header length)
        var prefixSize = header.Length + 8;

        // Total space to allocate: the prefix plus the (over)estimated state size
        var totalFileSize = prefixSize + estimatedStateSize;

        // Map the file and write the bytes directly to it.
        long writtenBytes = 0;
        using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Create, null, totalFileSize))
        {
            using (var view = file.CreateViewAccessor(0, totalFileSize))
            {
                unsafe
                {
                    byte* ptr = null;
                    view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
                    try
                    {
                        // Write prefix data
                        BinaryPrimitives.WriteUInt32BigEndian(new Span<byte>(ptr + writtenBytes, 4), FileHeaderMagic);
                        writtenBytes += 4;
                        BinaryPrimitives.WriteUInt32BigEndian(new Span<byte>(ptr + writtenBytes, 4), (uint)header.Length);
                        writtenBytes += 4;
                        header.CopyTo(new Span<byte>(ptr + writtenBytes, header.Length));
                        writtenBytes += header.Length;

                        // Write state data
                        writtenBytes += (long)context.NativeHandle.GetState(ptr + writtenBytes, (ulong)estimatedStateSize, sequence);
                    }
                    finally
                    {
                        view.SafeMemoryMappedViewHandle.ReleasePointer();
                    }
                }
            }
        }

        // Truncate the file to the actual size of data that was written
        using (var fileStream = new FileStream(filename, FileMode.Open))
            fileStream.SetLength(writtenBytes);
    }

    /// <summary>
    /// Load the state from the specified path into a particular sequence, also reading back the header data. Must only be used with
    /// data previously saved with <see cref="SaveState(LLamaContext, string, LLamaSeqId, ReadOnlySpan{byte})"/>
    /// </summary>
    /// <param name="context"></param>
    /// <param name="filename"></param>
    /// <param name="sequence"></param>
    /// <param name="header"></param>
    /// <exception cref="InvalidOperationException"></exception>
    internal static void LoadState(this LLamaContext context, string filename, LLamaSeqId sequence, out byte[] header)
    {
        // Map the state file into memory and pass that pointer directly to `llama_set_state_data` to load from
        using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Open, null))
        using (var view = file.CreateViewAccessor())
        {
            unsafe
            {
                byte* ptr = null;
                view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
                try
                {
                    var readBytes = 0;

                    // Read header
                    var magic = BinaryPrimitives.ReadUInt32BigEndian(new ReadOnlySpan<byte>(ptr + readBytes, 4));
                    readBytes += 4;
                    if (magic != FileHeaderMagic)
                        throw new InvalidOperationException("Invalid file header");

                    var headerLength = checked((int)BinaryPrimitives.ReadUInt32BigEndian(new ReadOnlySpan<byte>(ptr + readBytes, 4)));
                    readBytes += 4;

                    header = new byte[headerLength];
                    new Span<byte>(ptr + readBytes, headerLength).CopyTo(header);
                    readBytes += headerLength;

                    context.NativeHandle.SetState(ptr + readBytes, sequence);
                }
                finally
                {
                    view.SafeMemoryMappedViewHandle.ReleasePointer();
                }
            }
        }
    }
}