diff --git a/LLama.Examples/Examples/ChatSessionWithHistory.cs b/LLama.Examples/Examples/ChatSessionWithHistory.cs
index 17908908..6a84d2fd 100644
--- a/LLama.Examples/Examples/ChatSessionWithHistory.cs
+++ b/LLama.Examples/Examples/ChatSessionWithHistory.cs
@@ -61,6 +61,12 @@ public class ChatSessionWithHistory
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Session saved.");
}
+ else if (userInput == "load")
+ {
+ session.LoadSession("Assets/chat-with-bob");
+ Console.ForegroundColor = ConsoleColor.Yellow;
+ Console.WriteLine("Session loaded.");
+ }
else if (userInput == "regenerate")
{
Console.ForegroundColor = ConsoleColor.Yellow;
diff --git a/LLama.Examples/Examples/ChatSessionWithRestart.cs b/LLama.Examples/Examples/ChatSessionWithRestart.cs
index 3462c506..234bac3c 100644
--- a/LLama.Examples/Examples/ChatSessionWithRestart.cs
+++ b/LLama.Examples/Examples/ChatSessionWithRestart.cs
@@ -37,7 +37,8 @@ public class ChatSessionWithRestart
};
Console.ForegroundColor = ConsoleColor.Yellow;
- Console.WriteLine("The chat session has started.");
+ Console.WriteLine("The chat session has started. Write `save` to save session in memory."
+ + " Write `reset` to start from the last saved checkpoint");
// show the prompt
Console.ForegroundColor = ConsoleColor.Green;
@@ -48,13 +49,13 @@ public class ChatSessionWithRestart
if(userInput == "reset")
{
session.LoadSession(resetState);
- Console.WriteLine($"History: {session.HistoryTransform.HistoryToText(session.History)}");
+ Console.WriteLine($"Reset to history:\n{session.HistoryTransform.HistoryToText(session.History)}");
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Session reset.");
}
else if (userInput == "save")
{
- session.SaveSession("Assets/chat-with-bob");
+ resetState = session.GetSessionState();
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Session saved.");
}
diff --git a/LLama/Abstractions/IHistoryTransform.cs b/LLama/Abstractions/IHistoryTransform.cs
index 5651343f..9644b3e1 100644
--- a/LLama/Abstractions/IHistoryTransform.cs
+++ b/LLama/Abstractions/IHistoryTransform.cs
@@ -1,10 +1,12 @@
using LLama.Common;
+using System.Text.Json.Serialization;
namespace LLama.Abstractions
{
///
/// Transform history to plain text and vice versa.
///
+ [JsonConverter(typeof(PolymorphicJSONConverter))]
public interface IHistoryTransform
{
///
diff --git a/LLama/Abstractions/ITextStreamTransform.cs b/LLama/Abstractions/ITextStreamTransform.cs
index 2b63299d..3ebdba67 100644
--- a/LLama/Abstractions/ITextStreamTransform.cs
+++ b/LLama/Abstractions/ITextStreamTransform.cs
@@ -1,10 +1,13 @@
-using System.Collections.Generic;
+using LLama.Common;
+using System.Collections.Generic;
+using System.Text.Json.Serialization;
namespace LLama.Abstractions
{
///
/// Takes a stream of tokens and transforms them.
///
+ [JsonConverter(typeof(PolymorphicJSONConverter))]
public interface ITextStreamTransform
{
///
diff --git a/LLama/Abstractions/ITextTransform.cs b/LLama/Abstractions/ITextTransform.cs
index 0bfeeb7f..f6f743f9 100644
--- a/LLama/Abstractions/ITextTransform.cs
+++ b/LLama/Abstractions/ITextTransform.cs
@@ -1,4 +1,7 @@
-namespace LLama.Abstractions
+using System.Text.Json.Serialization;
+using LLama.Common;
+
+namespace LLama.Abstractions
{
///
/// An interface for text transformations.
@@ -9,6 +12,7 @@
/// - Trimming
/// - etc.
///
+ [JsonConverter(typeof(PolymorphicJSONConverter))]
public interface ITextTransform
{
///
diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index b4117842..80298725 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -8,7 +8,6 @@ using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Common;
-using static LLama.Common.ChatHistory;
using static LLama.InteractiveExecutor;
using static LLama.LLamaContext;
using static LLama.StatefulExecutorBase;
@@ -20,9 +19,30 @@ namespace LLama;
///
public class ChatSession
{
- private const string _modelStateFilename = "ModelState.st";
- private const string _executorStateFilename = "ExecutorState.json";
- private const string _hsitoryFilename = "ChatHistory.json";
+ ///
+ /// The filename for the serialized model state (KV cache, etc).
+ ///
+ public const string MODEL_STATE_FILENAME = "ModelState.st";
+ ///
+ /// The filename for the serialized executor state.
+ ///
+ public const string EXECUTOR_STATE_FILENAME = "ExecutorState.json";
+ ///
+ /// The filename for the serialized chat history.
+ ///
+ public const string HISTORY_STATE_FILENAME = "ChatHistory.json";
+ ///
+ /// The filename for the serialized input transform pipeline.
+ ///
+ public const string INPUT_TRANSFORM_FILENAME = "InputTransform.json";
+ ///
+ /// The filename for the serialized output transform.
+ ///
+ public const string OUTPUT_TRANSFORM_FILENAME = "OutputTransform.json";
+ ///
+ /// The filename for the serialized history transform.
+ ///
+ public const string HISTORY_TRANSFORM_FILENAME = "HistoryTransform.json";
///
/// The executor for this session.
@@ -134,26 +154,7 @@ public class ChatSession
///
public void SaveSession(string path)
{
- if (string.IsNullOrWhiteSpace(path))
- {
- throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
- }
-
- if (Directory.Exists(path))
- {
- Directory.Delete(path, recursive: true);
- }
-
- Directory.CreateDirectory(path);
-
- string modelStateFilePath = Path.Combine(path, _modelStateFilename);
- Executor.Context.SaveState(modelStateFilePath);
-
- string executorStateFilepath = Path.Combine(path, _executorStateFilename);
- ((StatefulExecutorBase)Executor).SaveState(executorStateFilepath);
-
- string historyFilepath = Path.Combine(path, _hsitoryFilename);
- File.WriteAllText(historyFilepath, History.ToJson());
+ // Take a snapshot of the current session state and let the snapshot write itself
+ // to the folder; path validation and folder handling are done by its Save method.
+ GetSessionState().Save(path);
}
///
@@ -202,26 +203,14 @@ public class ChatSession
///
public void LoadSession(string path)
{
- if (string.IsNullOrWhiteSpace(path))
+ // Path validation and reading of the state files are done by SessionState.Load.
+ var state = SessionState.Load(path);
+ // Handle non-polymorphic serialization of executor state
+ // NOTE(review): this branch re-reads the executor state file through
+ // StatefulExecutorBase.LoadState before the loaded state is applied below;
+ // confirm it is still required once the state has been deserialized by SessionState.Load.
+ if (state.ExecutorState is ExecutorBaseState)
{
- throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
+ var executorPath = Path.Combine(path, EXECUTOR_STATE_FILENAME);
+ ((StatefulExecutorBase) Executor).LoadState(filename: executorPath);
}
-
- if (!Directory.Exists(path))
- {
- throw new ArgumentException("Directory does not exist", nameof(path));
- }
-
- string modelStateFilePath = Path.Combine(path, _modelStateFilename);
- Executor.Context.LoadState(modelStateFilePath);
-
- string executorStateFilepath = Path.Combine(path, _executorStateFilename);
- ((StatefulExecutorBase)Executor).LoadState(executorStateFilepath);
-
- string historyFilepath = Path.Combine(path, _hsitoryFilename);
- string historyJson = File.ReadAllText(historyFilepath);
- History = ChatHistory.FromJson(historyJson)
- ?? throw new ArgumentException("History file is invalid", nameof(path));
+ // Apply the loaded state to this session.
+ LoadSession(state);
}
///
@@ -615,7 +604,7 @@ public record SessionState
///
/// The chat history messages for this session.
///
- public Message[] History { get; set; } = Array.Empty();
+ public ChatHistory.Message[] History { get; set; } = Array.Empty();
///
/// Create a new session state.
@@ -638,4 +627,124 @@ public record SessionState
OutputTransform = outputTransform.Clone();
HistoryTransform = historyTransform.Clone();
}
+
+ /// <summary>
+ /// Write this session state to a folder, one file per component.
+ /// </summary>
+ /// <param name="path">Target folder. If it already exists it is deleted, with its contents, and recreated.</param>
+ /// <exception cref="ArgumentException">Thrown when <paramref name="path"/> is null or whitespace.</exception>
+ public void Save(string path)
+ {
+     if (string.IsNullOrWhiteSpace(path))
+         throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
+
+     // Always start from an empty folder so no files from an earlier save remain.
+     if (Directory.Exists(path))
+         Directory.Delete(path, recursive: true);
+     Directory.CreateDirectory(path);
+
+     // Resolves a state file name against the target folder.
+     string InFolder(string fileName) => Path.Combine(path, fileName);
+
+     // The model state is written as raw bytes; the remaining components are written as text.
+     File.WriteAllBytes(InFolder(ChatSession.MODEL_STATE_FILENAME), ContextState.ToByteArray());
+     File.WriteAllText(InFolder(ChatSession.EXECUTOR_STATE_FILENAME), JsonSerializer.Serialize(ExecutorState));
+     File.WriteAllText(InFolder(ChatSession.HISTORY_STATE_FILENAME), new ChatHistory(History).ToJson());
+     File.WriteAllText(InFolder(ChatSession.INPUT_TRANSFORM_FILENAME), JsonSerializer.Serialize(InputTransformPipeline));
+     File.WriteAllText(InFolder(ChatSession.OUTPUT_TRANSFORM_FILENAME), JsonSerializer.Serialize(OutputTransform));
+     File.WriteAllText(InFolder(ChatSession.HISTORY_TRANSFORM_FILENAME), JsonSerializer.Serialize(HistoryTransform));
+ }
+
+ /// <summary>
+ /// Load the session state from folder.
+ /// </summary>
+ /// <param name="path">Folder that contains the state files written by <see cref="Save"/>.</param>
+ /// <returns>The session state rebuilt from the files in <paramref name="path"/>.</returns>
+ /// <exception cref="ArgumentException">Throws when session state is incorrect</exception>
+ public static SessionState Load(string path)
+ {
+     if (string.IsNullOrWhiteSpace(path))
+     {
+         throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
+     }
+
+     if (!Directory.Exists(path))
+     {
+         throw new ArgumentException("Directory does not exist", nameof(path));
+     }
+
+     string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
+     var contextState = State.FromByteArray(File.ReadAllBytes(modelStateFilePath));
+
+     // Malformed JSON in the executor state file is reported as ArgumentException,
+     // the same way as for the transform files below.
+     string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME);
+     var executorState = DeserializeStateFile<ExecutorBaseState>(
+         executorStateFilepath, "Executor state file is invalid", nameof(path));
+
+     string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME);
+     string historyJson = File.ReadAllText(historyFilepath);
+     var history = ChatHistory.FromJson(historyJson)
+         ?? throw new ArgumentException("History file is invalid", nameof(path));
+
+     // The transform files are optional: when a file is missing, a default is used instead.
+     string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME);
+     ITextTransform[] inputTransforms = File.Exists(inputTransformFilepath)
+         ? DeserializeStateFile<ITextTransform[]>(
+             inputTransformFilepath, "Input transform file is invalid", nameof(path))
+         : Array.Empty<ITextTransform>();
+
+     string outputTransformFilepath = Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME);
+     ITextStreamTransform outputTransform = File.Exists(outputTransformFilepath)
+         ? DeserializeStateFile<ITextStreamTransform>(
+             outputTransformFilepath, "Output transform file is invalid", nameof(path))
+         : new LLamaTransforms.EmptyTextOutputStreamTransform();
+
+     string historyTransformFilepath = Path.Combine(path, ChatSession.HISTORY_TRANSFORM_FILENAME);
+     IHistoryTransform historyTransform = File.Exists(historyTransformFilepath)
+         ? DeserializeStateFile<IHistoryTransform>(
+             historyTransformFilepath, "History transform file is invalid", nameof(path))
+         : new LLamaTransforms.DefaultHistoryTransform();
+
+     return new SessionState(
+         contextState,
+         executorState,
+         history,
+         inputTransforms.ToList(),
+         outputTransform,
+         historyTransform);
+ }
+
+ // Reads a JSON state file and deserializes it as T. A null result or malformed JSON is
+ // reported as ArgumentException; the JsonException is kept as the inner exception so the
+ // position of the parse error is not lost.
+ private static T DeserializeStateFile<T>(string filePath, string errorMessage, string paramName)
+     where T : class
+ {
+     try
+     {
+         return JsonSerializer.Deserialize<T>(File.ReadAllText(filePath))
+             ?? throw new ArgumentException(errorMessage, paramName);
+     }
+     catch (JsonException ex)
+     {
+         throw new ArgumentException(errorMessage, paramName, ex);
+     }
+ }
}
\ No newline at end of file
diff --git a/LLama/Common/PolymorphicJSONConverter.cs b/LLama/Common/PolymorphicJSONConverter.cs
new file mode 100644
index 00000000..6cec2f27
--- /dev/null
+++ b/LLama/Common/PolymorphicJSONConverter.cs
@@ -0,0 +1,57 @@
+using LLama.Abstractions;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Reflection;
+using System.Text;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+
+namespace LLama.Common
+{
+ /// <summary>
+ /// JSON converter that writes a value as <c>{"Name": "TypeName", "Data": { ... }}</c>, where
+ /// <c>Name</c> is the simple name of the value's runtime type and <c>Data</c> is the value
+ /// serialized as that type. On read, the name is resolved against the concrete types of the
+ /// executing assembly that are assignable to <typeparamref name="T"/>.
+ /// </summary>
+ /// <typeparam name="T">The declared (interface or base) type this converter is attached to.</typeparam>
+ internal class PolymorphicJSONConverter<T> : JsonConverter<T>
+ {
+     /// <summary>
+     /// Read a <c>Name</c>/<c>Data</c> wrapper object and return the deserialized <c>Data</c> value.
+     /// </summary>
+     /// <exception cref="JsonException">
+     /// Thrown when the wrapper does not have the expected shape, the type name cannot be
+     /// resolved, or the data deserializes to null.
+     /// </exception>
+     public override T? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+     {
+         if (reader.TokenType != JsonTokenType.StartObject)
+             throw new JsonException();
+
+         // First property must be "Name" and hold the simple name of the concrete type.
+         reader.Read();
+         if (reader.TokenType != JsonTokenType.PropertyName)
+             throw new JsonException();
+         string? propertyName = reader.GetString();
+         if (propertyName != "Name")
+             throw new JsonException();
+         reader.Read();
+         if (reader.TokenType != JsonTokenType.String)
+             throw new JsonException();
+         string name = reader.GetString() ?? throw new JsonException();
+
+         // Only concrete types declared in this assembly are candidates.
+         var type = Assembly.GetExecutingAssembly().GetTypes().FirstOrDefault(
+             t => typeof(T).IsAssignableFrom(t) && !t.IsAbstract && !t.IsInterface && t.Name == name
+         );
+         if (type == null)
+             throw new JsonException();
+
+         // Second property must be "Data" and hold the serialized value.
+         reader.Read();
+         if (reader.TokenType != JsonTokenType.PropertyName)
+             throw new JsonException();
+         propertyName = reader.GetString();
+         if (propertyName != "Data")
+             throw new JsonException();
+
+         // Deserialize advances from the property name to the value and stops on the
+         // value's last token.
+         var data = JsonSerializer.Deserialize(ref reader, type, options);
+         if (data == null)
+             throw new JsonException();
+
+         // Advance exactly once, onto the wrapper's closing brace, and stop there. A converter
+         // must finish on the last token of its own value; reading further consumes the next
+         // array element or property of the enclosing JSON.
+         if (!reader.Read() || reader.TokenType != JsonTokenType.EndObject)
+             throw new JsonException();
+         return (T)data;
+     }
+
+     /// <summary>
+     /// Write <paramref name="value"/> wrapped in an object holding its runtime type name and data.
+     /// </summary>
+     public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options)
+     {
+         writer.WriteStartObject();
+         writer.WriteString("Name", value.GetType().Name);
+         writer.WritePropertyName("Data");
+         // Serialize with the runtime type so the members of the concrete class are written.
+         JsonSerializer.Serialize(writer, value, value.GetType(), options);
+         writer.WriteEndObject();
+     }
+ }
+}
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index d8b418c3..4a63be36 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -166,7 +166,7 @@ namespace LLama
memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);
// Wrap memory in a "state"
- var state = new State(memory);
+ var state = new State(memory, actualSize);
// Set memory to zero, to prevent it being freed in finally block
memory = IntPtr.Zero;
@@ -384,9 +384,12 @@ namespace LLama
public class State
: SafeLLamaHandleBase
{
- internal State(IntPtr memory)
+ // Size, in bytes, of the unmanaged buffer referenced by this handle.
+ private readonly ulong _size;
+
+ // Wraps an unmanaged buffer of the given size; the handle takes ownership of the memory.
+ internal State(IntPtr memory, ulong size)
: base(memory, true)
{
+ _size = size;
}
///
@@ -395,6 +398,29 @@ namespace LLama
Marshal.FreeHGlobal(handle);
return true;
}
+
+ /// <summary>
+ /// Copy the contents of this state into a newly allocated managed byte array.
+ /// </summary>
+ /// <returns>A byte array with the same length as the size recorded for this state.</returns>
+ public byte[] ToByteArray()
+ {
+     var buffer = new byte[_size];
+     Marshal.Copy(handle, buffer, startIndex: 0, length: (int)_size);
+     return buffer;
+ }
+
+ /// <summary>
+ /// Load state from a byte array
+ /// </summary>
+ /// <param name="bytes">The bytes to copy into a newly allocated unmanaged buffer.</param>
+ /// <returns>A state that owns an unmanaged copy of <paramref name="bytes"/>.</returns>
+ /// <exception cref="ArgumentNullException">Thrown when <paramref name="bytes"/> is null.</exception>
+ public static State FromByteArray(byte[] bytes)
+ {
+     if (bytes == null)
+         throw new ArgumentNullException(nameof(bytes));
+
+     var memory = Marshal.AllocHGlobal(bytes.Length);
+     try
+     {
+         Marshal.Copy(bytes, 0, memory, bytes.Length);
+     }
+     catch
+     {
+         // No State owns the buffer yet, so it has to be released here.
+         Marshal.FreeHGlobal(memory);
+         throw;
+     }
+
+     // From here on the State owns the buffer and frees it when released.
+     return new State(memory, (ulong)bytes.Length);
+ }
}
}
}
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index ea5616b5..ec72a25a 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -370,6 +370,7 @@ namespace LLama
public bool NeedToSaveSession { get; set; }
}
+ [JsonConverter(typeof(PolymorphicJSONConverter))]
public class ExecutorBaseState
{
[JsonPropertyName("n_past")]
diff --git a/LLama/LLamaTransforms.cs b/LLama/LLamaTransforms.cs
index 1ac0a79b..d74d9dda 100644
--- a/LLama/LLamaTransforms.cs
+++ b/LLama/LLamaTransforms.cs
@@ -3,6 +3,7 @@ using LLama.Common;
using System.Collections.Generic;
using System.Linq;
using System.Text;
+using System.Text.Json.Serialization;
namespace LLama
{
@@ -29,6 +30,12 @@ namespace LLama
private readonly string _unknownName;
private readonly bool _isInstructMode;
+ public string UserName => _userName;
+ public string AssistantName => _assistantName;
+ public string SystemName => _systemName;
+ public string UnknownName => _unknownName;
+ public bool IsInstructMode => _isInstructMode;
+
///
///
///
@@ -158,6 +165,42 @@ namespace LLama
private readonly int _maxKeywordLength;
private readonly bool _removeAllMatchedTokens;
+ ///
+ /// Keywords that you want to remove from the response.
+ /// This property is used for JSON serialization.
+ ///
+ [JsonPropertyName("keywords")]
+ public HashSet Keywords => _keywords;
+
+ ///
+ /// Maximum length of the keywords.
+ /// This property is used for JSON serialization.
+ ///
+ [JsonPropertyName("maxKeywordLength")]
+ public int MaxKeywordLength => _maxKeywordLength;
+
+ ///
+ /// If set to true, when getting a matched keyword, all the related tokens will be removed.
+ /// Otherwise only the part of keyword will be removed.
+ /// This property is used for JSON serialization.
+ ///
+ [JsonPropertyName("removeAllMatchedTokens")]
+ public bool RemoveAllMatchedTokens => _removeAllMatchedTokens;
+
+ /// <summary>
+ /// JSON constructor. Rebuilds the transform from the values exposed by the
+ /// <c>keywords</c>, <c>maxKeywordLength</c> and <c>removeAllMatchedTokens</c> JSON properties.
+ /// </summary>
+ /// <param name="keywords">Keywords that you want to remove from the response. The set is copied.</param>
+ /// <param name="maxKeywordLength">Maximum length of the keywords.</param>
+ /// <param name="removeAllMatchedTokens">
+ /// If set to true, when getting a matched keyword, all the related tokens will be removed.
+ /// Otherwise only the part of keyword will be removed.
+ /// </param>
+ [JsonConstructor]
+ public KeywordTextOutputStreamTransform(
+ HashSet keywords,
+ int maxKeywordLength,
+ bool removeAllMatchedTokens)
+ {
+ // Copy the set so later changes to the caller's instance do not affect this transform.
+ _keywords = new(keywords);
+ _maxKeywordLength = maxKeywordLength;
+ _removeAllMatchedTokens = removeAllMatchedTokens;
+ }
+
///
///
///