Polymorphic serialization for executor state and transforms

This commit is contained in:
eublefar 2024-03-17 15:34:36 +01:00
parent 6f76d77350
commit a31391edd7
10 changed files with 302 additions and 50 deletions

View File

@ -61,6 +61,12 @@ public class ChatSessionWithHistory
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Session saved.");
}
else if (userInput == "load")
{
session.LoadSession("Assets/chat-with-bob");
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Session loaded.");
}
else if (userInput == "regenerate")
{
Console.ForegroundColor = ConsoleColor.Yellow;

View File

@ -37,7 +37,8 @@ public class ChatSessionWithRestart
};
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The chat session has started.");
Console.WriteLine("The chat session has started. Write `save` to save session in memory."
+ " Write `reset` to start from the last saved checkpoint");
// show the prompt
Console.ForegroundColor = ConsoleColor.Green;
@ -48,13 +49,13 @@ public class ChatSessionWithRestart
if(userInput == "reset")
{
session.LoadSession(resetState);
Console.WriteLine($"History: {session.HistoryTransform.HistoryToText(session.History)}");
Console.WriteLine($"Reset to history:\n{session.HistoryTransform.HistoryToText(session.History)}");
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Session reset.");
}
else if (userInput == "save")
{
session.SaveSession("Assets/chat-with-bob");
resetState = session.GetSessionState();
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Session saved.");
}

View File

@ -1,10 +1,12 @@
using LLama.Common;
using System.Text.Json.Serialization;
namespace LLama.Abstractions
{
/// <summary>
/// Transform history to plain text and vice versa.
/// </summary>
[JsonConverter(typeof(PolymorphicJSONConverter<IHistoryTransform>))]
public interface IHistoryTransform
{
/// <summary>

View File

@ -1,10 +1,13 @@
using System.Collections.Generic;
using LLama.Common;
using System.Collections.Generic;
using System.Text.Json.Serialization;
namespace LLama.Abstractions
{
/// <summary>
/// Takes a stream of tokens and transforms them.
/// </summary>
[JsonConverter(typeof(PolymorphicJSONConverter<ITextStreamTransform>))]
public interface ITextStreamTransform
{
/// <summary>

View File

@ -1,4 +1,7 @@
namespace LLama.Abstractions
using System.Text.Json.Serialization;
using LLama.Common;
namespace LLama.Abstractions
{
/// <summary>
/// An interface for text transformations.
@ -9,6 +12,7 @@
/// - Trimming
/// - etc.
/// </summary>
[JsonConverter(typeof(PolymorphicJSONConverter<ITextTransform>))]
public interface ITextTransform
{
/// <summary>

View File

@ -8,7 +8,6 @@ using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Common;
using static LLama.Common.ChatHistory;
using static LLama.InteractiveExecutor;
using static LLama.LLamaContext;
using static LLama.StatefulExecutorBase;
@ -20,9 +19,30 @@ namespace LLama;
/// </summary>
public class ChatSession
{
private const string _modelStateFilename = "ModelState.st";
private const string _executorStateFilename = "ExecutorState.json";
private const string _hsitoryFilename = "ChatHistory.json";
/// <summary>
/// The filename for the serialized model state (KV cache, etc).
/// </summary>
public const string MODEL_STATE_FILENAME = "ModelState.st";
/// <summary>
/// The filename for the serialized executor state.
/// </summary>
public const string EXECUTOR_STATE_FILENAME = "ExecutorState.json";
/// <summary>
/// The filename for the serialized chat history.
/// </summary>
public const string HISTORY_STATE_FILENAME = "ChatHistory.json";
/// <summary>
/// The filename for the serialized input transform pipeline.
/// </summary>
public const string INPUT_TRANSFORM_FILENAME = "InputTransform.json";
/// <summary>
/// The filename for the serialized output transform.
/// </summary>
public const string OUTPUT_TRANSFORM_FILENAME = "OutputTransform.json";
/// <summary>
/// The filename for the serialized history transform.
/// </summary>
public const string HISTORY_TRANSFORM_FILENAME = "HistoryTransform.json";
/// <summary>
/// The executor for this session.
@ -134,26 +154,7 @@ public class ChatSession
/// <exception cref="ArgumentException"></exception>
public void SaveSession(string path)
{
if (string.IsNullOrWhiteSpace(path))
{
throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
}
if (Directory.Exists(path))
{
Directory.Delete(path, recursive: true);
}
Directory.CreateDirectory(path);
string modelStateFilePath = Path.Combine(path, _modelStateFilename);
Executor.Context.SaveState(modelStateFilePath);
string executorStateFilepath = Path.Combine(path, _executorStateFilename);
((StatefulExecutorBase)Executor).SaveState(executorStateFilepath);
string historyFilepath = Path.Combine(path, _hsitoryFilename);
File.WriteAllText(historyFilepath, History.ToJson());
GetSessionState().Save(path);
}
/// <summary>
@ -202,26 +203,14 @@ public class ChatSession
/// <exception cref="ArgumentException"></exception>
public void LoadSession(string path)
{
if (string.IsNullOrWhiteSpace(path))
var state = SessionState.Load(path);
// Handle non-polymorphic serialization of executor state
if (state.ExecutorState is ExecutorBaseState)
{
throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
var executorPath = Path.Combine(path, EXECUTOR_STATE_FILENAME);
((StatefulExecutorBase) Executor).LoadState(filename: executorPath);
}
if (!Directory.Exists(path))
{
throw new ArgumentException("Directory does not exist", nameof(path));
}
string modelStateFilePath = Path.Combine(path, _modelStateFilename);
Executor.Context.LoadState(modelStateFilePath);
string executorStateFilepath = Path.Combine(path, _executorStateFilename);
((StatefulExecutorBase)Executor).LoadState(executorStateFilepath);
string historyFilepath = Path.Combine(path, _hsitoryFilename);
string historyJson = File.ReadAllText(historyFilepath);
History = ChatHistory.FromJson(historyJson)
?? throw new ArgumentException("History file is invalid", nameof(path));
LoadSession(state);
}
/// <summary>
@ -615,7 +604,7 @@ public record SessionState
/// <summary>
/// The the chat history messages for this session.
/// </summary>
public Message[] History { get; set; } = Array.Empty<Message>();
public ChatHistory.Message[] History { get; set; } = Array.Empty<ChatHistory.Message>();
/// <summary>
/// Create a new session state.
@ -638,4 +627,124 @@ public record SessionState
OutputTransform = outputTransform.Clone();
HistoryTransform = historyTransform.Clone();
}
/// <summary>
/// Save the session state to folder.
/// </summary>
/// <param name="path"></param>
public void Save(string path)
{
if (string.IsNullOrWhiteSpace(path))
{
throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
}
if (Directory.Exists(path))
{
Directory.Delete(path, recursive: true);
}
Directory.CreateDirectory(path);
string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
var bytes = ContextState.ToByteArray();
File.WriteAllBytes(modelStateFilePath, bytes);
string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME);
File.WriteAllText(executorStateFilepath, JsonSerializer.Serialize(ExecutorState));
string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME);
File.WriteAllText(historyFilepath, new ChatHistory(History).ToJson());
string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME);
File.WriteAllText(inputTransformFilepath, JsonSerializer.Serialize(InputTransformPipeline));
string outputTransformFilepath = Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME);
File.WriteAllText(outputTransformFilepath, JsonSerializer.Serialize(OutputTransform));
string historyTransformFilepath = Path.Combine(path, ChatSession.HISTORY_TRANSFORM_FILENAME);
File.WriteAllText(historyTransformFilepath, JsonSerializer.Serialize(HistoryTransform));
}
/// <summary>
/// Load the session state from folder.
/// </summary>
/// <param name="path"></param>
/// <returns></returns>
/// <exception cref="ArgumentException">Throws when session state is incorrect</exception>
public static SessionState Load(string path)
{
if (string.IsNullOrWhiteSpace(path))
{
throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
}
if (!Directory.Exists(path))
{
throw new ArgumentException("Directory does not exist", nameof(path));
}
string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
var contextState = State.FromByteArray(File.ReadAllBytes(modelStateFilePath));
string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME);
var executorState = JsonSerializer.Deserialize<ExecutorBaseState>(File.ReadAllText(executorStateFilepath))
?? throw new ArgumentException("Executor state file is invalid", nameof(path));
string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME);
string historyJson = File.ReadAllText(historyFilepath);
var history = ChatHistory.FromJson(historyJson)
?? throw new ArgumentException("History file is invalid", nameof(path));
string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME);
ITextTransform[] inputTransforms;
try
{
inputTransforms = File.Exists(inputTransformFilepath) ?
(JsonSerializer.Deserialize<ITextTransform[]>(File.ReadAllText(inputTransformFilepath))
?? throw new ArgumentException("Input transform file is invalid", nameof(path)))
: Array.Empty<ITextTransform>();
}
catch (JsonException)
{
throw new ArgumentException("Input transform file is invalid", nameof(path));
}
string outputTransformFilepath = Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME);
ITextStreamTransform outputTransform;
try
{
outputTransform = File.Exists(outputTransformFilepath) ?
(JsonSerializer.Deserialize<ITextStreamTransform>(File.ReadAllText(outputTransformFilepath))
?? throw new ArgumentException("Output transform file is invalid", nameof(path)))
: new LLamaTransforms.EmptyTextOutputStreamTransform();
}
catch (JsonException)
{
throw new ArgumentException("Output transform file is invalid", nameof(path));
}
string historyTransformFilepath = Path.Combine(path, ChatSession.HISTORY_TRANSFORM_FILENAME);
IHistoryTransform historyTransform;
try
{
historyTransform = File.Exists(historyTransformFilepath) ?
(JsonSerializer.Deserialize<IHistoryTransform>(File.ReadAllText(historyTransformFilepath))
?? throw new ArgumentException("History transform file is invalid", nameof(path)))
: new LLamaTransforms.DefaultHistoryTransform();
}
catch (JsonException)
{
throw new ArgumentException("History transform file is invalid", nameof(path));
}
return new SessionState(
contextState,
executorState,
history,
inputTransforms.ToList(),
outputTransform,
historyTransform);
}
}

View File

@ -0,0 +1,57 @@
using LLama.Abstractions;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
namespace LLama.Common
{
internal class PolymorphicJSONConverter<T> : JsonConverter<T>
{
public override T? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
if (reader.TokenType != JsonTokenType.StartObject)
throw new JsonException();
reader.Read();
if (reader.TokenType != JsonTokenType.PropertyName)
throw new JsonException();
string? propertyName = reader.GetString();
if (propertyName != "Name")
throw new JsonException();
reader.Read();
if (reader.TokenType != JsonTokenType.String)
throw new JsonException();
string? name = reader.GetString() ?? throw new JsonException();
var inheritedTypes = Assembly.GetExecutingAssembly().GetTypes().Where(
t => typeof(T).IsAssignableFrom(t) && !t.IsAbstract && !t.IsInterface
);
var type = inheritedTypes.FirstOrDefault(t => t.Name == name);
if (type == null)
throw new JsonException();
reader.Read();
if (reader.TokenType != JsonTokenType.PropertyName)
throw new JsonException();
propertyName = reader.GetString();
if (propertyName != "Data")
throw new JsonException();
var data = JsonSerializer.Deserialize(ref reader, type, options);
if (data == null)
throw new JsonException();
reader.Read();
reader.Read();
return (T)data;
}
public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options)
{
writer.WriteStartObject();
writer.WriteString("Name", value.GetType().Name);
writer.WritePropertyName("Data");
JsonSerializer.Serialize(writer, value, value.GetType(), options);
writer.WriteEndObject();
}
}
}

View File

@ -166,7 +166,7 @@ namespace LLama
memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);
// Wrap memory in a "state"
var state = new State(memory);
var state = new State(memory, actualSize);
// Set memory to zero, to prevent it being freed in finally block
memory = IntPtr.Zero;
@ -384,9 +384,12 @@ namespace LLama
public class State
: SafeLLamaHandleBase
{
internal State(IntPtr memory)
private ulong _size;
internal State(IntPtr memory, ulong size)
: base(memory, true)
{
_size = size;
}
/// <inheritdoc />
@ -395,6 +398,29 @@ namespace LLama
Marshal.FreeHGlobal(handle);
return true;
}
/// <summary>
/// Convert this state to a byte array
/// </summary>
/// <returns></returns>
public byte[] ToByteArray()
{
var bytes = new byte[_size];
Marshal.Copy(handle, bytes, 0, (int)_size);
return bytes;
}
/// <summary>
/// Load state from a byte array
/// </summary>
/// <param name="bytes"></param>
/// <returns></returns>
public static State FromByteArray(byte[] bytes)
{
var memory = Marshal.AllocHGlobal(bytes.Length);
Marshal.Copy(bytes, 0, memory, bytes.Length);
return new State(memory, (ulong)bytes.Length);
}
}
}
}

View File

@ -370,6 +370,7 @@ namespace LLama
public bool NeedToSaveSession { get; set; }
}
[JsonConverter(typeof(PolymorphicJSONConverter<ExecutorBaseState>))]
public class ExecutorBaseState
{
[JsonPropertyName("n_past")]

View File

@ -3,6 +3,7 @@ using LLama.Common;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.Json.Serialization;
namespace LLama
{
@ -29,6 +30,12 @@ namespace LLama
private readonly string _unknownName;
private readonly bool _isInstructMode;
public string UserName => _userName;
public string AssistantName => _assistantName;
public string SystemName => _systemName;
public string UnknownName => _unknownName;
public bool IsInstructMode => _isInstructMode;
/// <summary>
///
/// </summary>
@ -158,6 +165,42 @@ namespace LLama
private readonly int _maxKeywordLength;
private readonly bool _removeAllMatchedTokens;
/// <summary>
/// Keywords that you want to remove from the response.
/// This property is used for JSON serialization.
/// </summary>
[JsonPropertyName("keywords")]
public HashSet<string> Keywords => _keywords;
/// <summary>
/// Maximum length of the keywords.
/// This property is used for JSON serialization.
/// </summary>
[JsonPropertyName("maxKeywordLength")]
public int MaxKeywordLength => _maxKeywordLength;
/// <summary>
/// If set to true, when getting a matched keyword, all the related tokens will be removed.
/// Otherwise only the part of keyword will be removed.
/// This property is used for JSON serialization.
/// </summary>
[JsonPropertyName("removeAllMatchedTokens")]
public bool RemoveAllMatchedTokens => _removeAllMatchedTokens;
/// <summary>
/// JSON constructor.
/// </summary>
[JsonConstructor]
public KeywordTextOutputStreamTransform(
HashSet<string> keywords,
int maxKeywordLength,
bool removeAllMatchedTokens)
{
_keywords = new(keywords);
_maxKeywordLength = maxKeywordLength;
_removeAllMatchedTokens = removeAllMatchedTokens;
}
/// <summary>
///
/// </summary>