diff --git a/LLama.Examples/Examples/ChatSessionWithHistory.cs b/LLama.Examples/Examples/ChatSessionWithHistory.cs index 17908908..6a84d2fd 100644 --- a/LLama.Examples/Examples/ChatSessionWithHistory.cs +++ b/LLama.Examples/Examples/ChatSessionWithHistory.cs @@ -61,6 +61,12 @@ public class ChatSessionWithHistory Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Session saved."); } + else if (userInput == "load") + { + session.LoadSession("Assets/chat-with-bob"); + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Session loaded."); + } else if (userInput == "regenerate") { Console.ForegroundColor = ConsoleColor.Yellow; diff --git a/LLama.Examples/Examples/ChatSessionWithRestart.cs b/LLama.Examples/Examples/ChatSessionWithRestart.cs index 3462c506..234bac3c 100644 --- a/LLama.Examples/Examples/ChatSessionWithRestart.cs +++ b/LLama.Examples/Examples/ChatSessionWithRestart.cs @@ -37,7 +37,8 @@ public class ChatSessionWithRestart }; Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine("The chat session has started."); + Console.WriteLine("The chat session has started. Write `save` to save session in memory." + + " Write `reset` to start from the last saved checkpoint"); // show the prompt Console.ForegroundColor = ConsoleColor.Green; @@ -48,13 +49,13 @@ public class ChatSessionWithRestart if(userInput == "reset") { session.LoadSession(resetState); - Console.WriteLine($"History: {session.HistoryTransform.HistoryToText(session.History)}"); + Console.WriteLine($"Reset to history:\n{session.HistoryTransform.HistoryToText(session.History)}"); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Session reset."); } else if (userInput == "save") { - session.SaveSession("Assets/chat-with-bob"); + resetState = session.GetSessionState(); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Session saved."); } diff --git a/LLama/Abstractions/IHistoryTransform.cs b/LLama/Abstractions/IHistoryTransform.cs index 5651343f..9644b3e1 100644 --- a/LLama/Abstractions/IHistoryTransform.cs +++ b/LLama/Abstractions/IHistoryTransform.cs @@ -1,10 +1,12 @@ using LLama.Common; +using System.Text.Json.Serialization; namespace LLama.Abstractions { /// /// Transform history to plain text and vice versa. /// + [JsonConverter(typeof(PolymorphicJSONConverter))] public interface IHistoryTransform { /// diff --git a/LLama/Abstractions/ITextStreamTransform.cs b/LLama/Abstractions/ITextStreamTransform.cs index 2b63299d..3ebdba67 100644 --- a/LLama/Abstractions/ITextStreamTransform.cs +++ b/LLama/Abstractions/ITextStreamTransform.cs @@ -1,10 +1,13 @@ -using System.Collections.Generic; +using LLama.Common; +using System.Collections.Generic; +using System.Text.Json.Serialization; namespace LLama.Abstractions { /// /// Takes a stream of tokens and transforms them. /// + [JsonConverter(typeof(PolymorphicJSONConverter))] public interface ITextStreamTransform { /// diff --git a/LLama/Abstractions/ITextTransform.cs b/LLama/Abstractions/ITextTransform.cs index 0bfeeb7f..f6f743f9 100644 --- a/LLama/Abstractions/ITextTransform.cs +++ b/LLama/Abstractions/ITextTransform.cs @@ -1,4 +1,7 @@ -namespace LLama.Abstractions +using System.Text.Json.Serialization; +using LLama.Common; + +namespace LLama.Abstractions { /// /// An interface for text transformations. @@ -9,6 +12,7 @@ /// - Trimming /// - etc. /// + [JsonConverter(typeof(PolymorphicJSONConverter))] public interface ITextTransform { /// diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index b4117842..80298725 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -8,7 +8,6 @@ using System.Threading; using System.Threading.Tasks; using LLama.Abstractions; using LLama.Common; -using static LLama.Common.ChatHistory; using static LLama.InteractiveExecutor; using static LLama.LLamaContext; using static LLama.StatefulExecutorBase; @@ -20,9 +19,30 @@ namespace LLama; /// public class ChatSession { - private const string _modelStateFilename = "ModelState.st"; - private const string _executorStateFilename = "ExecutorState.json"; - private const string _hsitoryFilename = "ChatHistory.json"; + /// + /// The filename for the serialized model state (KV cache, etc). + /// + public const string MODEL_STATE_FILENAME = "ModelState.st"; + /// + /// The filename for the serialized executor state. + /// + public const string EXECUTOR_STATE_FILENAME = "ExecutorState.json"; + /// + /// The filename for the serialized chat history. + /// + public const string HISTORY_STATE_FILENAME = "ChatHistory.json"; + /// + /// The filename for the serialized input transform pipeline. + /// + public const string INPUT_TRANSFORM_FILENAME = "InputTransform.json"; + /// + /// The filename for the serialized output transform. + /// + public const string OUTPUT_TRANSFORM_FILENAME = "OutputTransform.json"; + /// + /// The filename for the serialized history transform. + /// + public const string HISTORY_TRANSFORM_FILENAME = "HistoryTransform.json"; /// /// The executor for this session. @@ -134,26 +154,7 @@ public class ChatSession /// public void SaveSession(string path) { - if (string.IsNullOrWhiteSpace(path)) - { - throw new ArgumentException("Path cannot be null or whitespace", nameof(path)); - } - - if (Directory.Exists(path)) - { - Directory.Delete(path, recursive: true); - } - - Directory.CreateDirectory(path); - - string modelStateFilePath = Path.Combine(path, _modelStateFilename); - Executor.Context.SaveState(modelStateFilePath); - - string executorStateFilepath = Path.Combine(path, _executorStateFilename); - ((StatefulExecutorBase)Executor).SaveState(executorStateFilepath); - - string historyFilepath = Path.Combine(path, _hsitoryFilename); - File.WriteAllText(historyFilepath, History.ToJson()); + GetSessionState().Save(path); } /// @@ -202,26 +203,14 @@ public class ChatSession /// public void LoadSession(string path) { - if (string.IsNullOrWhiteSpace(path)) + var state = SessionState.Load(path); + // Handle non-polymorphic serialization of executor state + if (state.ExecutorState is ExecutorBaseState) { - throw new ArgumentException("Path cannot be null or whitespace", nameof(path)); + var executorPath = Path.Combine(path, EXECUTOR_STATE_FILENAME); + ((StatefulExecutorBase) Executor).LoadState(filename: executorPath); } - - if (!Directory.Exists(path)) - { - throw new ArgumentException("Directory does not exist", nameof(path)); - } - - string modelStateFilePath = Path.Combine(path, _modelStateFilename); - Executor.Context.LoadState(modelStateFilePath); - - string executorStateFilepath = Path.Combine(path, _executorStateFilename); - ((StatefulExecutorBase)Executor).LoadState(executorStateFilepath); - - string historyFilepath = Path.Combine(path, _hsitoryFilename); - string historyJson = File.ReadAllText(historyFilepath); - History = ChatHistory.FromJson(historyJson) - ?? throw new ArgumentException("History file is invalid", nameof(path)); + LoadSession(state); } /// @@ -615,7 +604,7 @@ public record SessionState /// /// The the chat history messages for this session. /// - public Message[] History { get; set; } = Array.Empty(); + public ChatHistory.Message[] History { get; set; } = Array.Empty(); /// /// Create a new session state. @@ -638,4 +627,124 @@ public record SessionState OutputTransform = outputTransform.Clone(); HistoryTransform = historyTransform.Clone(); } + + /// + /// Save the session state to folder. + /// + /// + public void Save(string path) + { + if (string.IsNullOrWhiteSpace(path)) + { + throw new ArgumentException("Path cannot be null or whitespace", nameof(path)); + } + + if (Directory.Exists(path)) + { + Directory.Delete(path, recursive: true); + } + + Directory.CreateDirectory(path); + + string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME); + var bytes = ContextState.ToByteArray(); + File.WriteAllBytes(modelStateFilePath, bytes); + + string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME); + File.WriteAllText(executorStateFilepath, JsonSerializer.Serialize(ExecutorState)); + + string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME); + File.WriteAllText(historyFilepath, new ChatHistory(History).ToJson()); + + string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME); + File.WriteAllText(inputTransformFilepath, JsonSerializer.Serialize(InputTransformPipeline)); + + string outputTransformFilepath = Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME); + File.WriteAllText(outputTransformFilepath, JsonSerializer.Serialize(OutputTransform)); + + string historyTransformFilepath = Path.Combine(path, ChatSession.HISTORY_TRANSFORM_FILENAME); + File.WriteAllText(historyTransformFilepath, JsonSerializer.Serialize(HistoryTransform)); + } + + /// + /// Load the session state from folder. + /// + /// + /// + /// Throws when session state is incorrect + public static SessionState Load(string path) + { + if (string.IsNullOrWhiteSpace(path)) + { + throw new ArgumentException("Path cannot be null or whitespace", nameof(path)); + } + + if (!Directory.Exists(path)) + { + throw new ArgumentException("Directory does not exist", nameof(path)); + } + + string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME); + var contextState = State.FromByteArray(File.ReadAllBytes(modelStateFilePath)); + + string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME); + var executorState = JsonSerializer.Deserialize(File.ReadAllText(executorStateFilepath)) + ?? throw new ArgumentException("Executor state file is invalid", nameof(path)); + + string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME); + string historyJson = File.ReadAllText(historyFilepath); + var history = ChatHistory.FromJson(historyJson) + ?? throw new ArgumentException("History file is invalid", nameof(path)); + + string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME); + ITextTransform[] inputTransforms; + try + { + inputTransforms = File.Exists(inputTransformFilepath) ? + (JsonSerializer.Deserialize(File.ReadAllText(inputTransformFilepath)) + ?? throw new ArgumentException("Input transform file is invalid", nameof(path))) + : Array.Empty(); + } + catch (JsonException) + { + throw new ArgumentException("Input transform file is invalid", nameof(path)); + } + + string outputTransformFilepath = Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME); + + ITextStreamTransform outputTransform; + try + { + outputTransform = File.Exists(outputTransformFilepath) ? + (JsonSerializer.Deserialize(File.ReadAllText(outputTransformFilepath)) + ?? throw new ArgumentException("Output transform file is invalid", nameof(path))) + : new LLamaTransforms.EmptyTextOutputStreamTransform(); + } + catch (JsonException) + { + throw new ArgumentException("Output transform file is invalid", nameof(path)); + } + + string historyTransformFilepath = Path.Combine(path, ChatSession.HISTORY_TRANSFORM_FILENAME); + IHistoryTransform historyTransform; + try + { + historyTransform = File.Exists(historyTransformFilepath) ? + (JsonSerializer.Deserialize(File.ReadAllText(historyTransformFilepath)) + ?? throw new ArgumentException("History transform file is invalid", nameof(path))) + : new LLamaTransforms.DefaultHistoryTransform(); + } + catch (JsonException) + { + throw new ArgumentException("History transform file is invalid", nameof(path)); + } + + return new SessionState( + contextState, + executorState, + history, + inputTransforms.ToList(), + outputTransform, + historyTransform); + } } \ No newline at end of file diff --git a/LLama/Common/PolymorphicJSONConverter.cs b/LLama/Common/PolymorphicJSONConverter.cs new file mode 100644 index 00000000..6cec2f27 --- /dev/null +++ b/LLama/Common/PolymorphicJSONConverter.cs @@ -0,0 +1,57 @@ +using LLama.Abstractions; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace LLama.Common +{ + internal class PolymorphicJSONConverter : JsonConverter + { + public override T? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + if (reader.TokenType != JsonTokenType.StartObject) + throw new JsonException(); + reader.Read(); + if (reader.TokenType != JsonTokenType.PropertyName) + throw new JsonException(); + string? propertyName = reader.GetString(); + if (propertyName != "Name") + throw new JsonException(); + reader.Read(); + if (reader.TokenType != JsonTokenType.String) + throw new JsonException(); + string? name = reader.GetString() ?? throw new JsonException(); + var inheritedTypes = Assembly.GetExecutingAssembly().GetTypes().Where( + t => typeof(T).IsAssignableFrom(t) && !t.IsAbstract && !t.IsInterface + ); + var type = inheritedTypes.FirstOrDefault(t => t.Name == name); + if (type == null) + throw new JsonException(); + reader.Read(); + if (reader.TokenType != JsonTokenType.PropertyName) + throw new JsonException(); + propertyName = reader.GetString(); + if (propertyName != "Data") + throw new JsonException(); + var data = JsonSerializer.Deserialize(ref reader, type, options); + if (data == null) + throw new JsonException(); + reader.Read(); + reader.Read(); + return (T)data; + } + + public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options) + { + writer.WriteStartObject(); + writer.WriteString("Name", value.GetType().Name); + writer.WritePropertyName("Data"); + JsonSerializer.Serialize(writer, value, value.GetType(), options); + writer.WriteEndObject(); + } + } +} diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index d8b418c3..4a63be36 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -166,7 +166,7 @@ namespace LLama memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize); // Wrap memory in a "state" - var state = new State(memory); + var state = new State(memory, actualSize); // Set memory to zero, to prevent it being freed in finally block memory = IntPtr.Zero; @@ -384,9 +384,12 @@ namespace LLama public class State : SafeLLamaHandleBase { - internal State(IntPtr memory) + private ulong _size; + + internal State(IntPtr memory, ulong size) : base(memory, true) { + _size = size; } /// @@ -395,6 +398,29 @@ namespace LLama Marshal.FreeHGlobal(handle); return true; } + + /// + /// Convert this state to a byte array + /// + /// + public byte[] ToByteArray() + { + var bytes = new byte[_size]; + Marshal.Copy(handle, bytes, 0, (int)_size); + return bytes; + } + + /// + /// Load state from a byte array + /// + /// + /// + public static State FromByteArray(byte[] bytes) + { + var memory = Marshal.AllocHGlobal(bytes.Length); + Marshal.Copy(bytes, 0, memory, bytes.Length); + return new State(memory, (ulong)bytes.Length); + } } } } diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index ea5616b5..ec72a25a 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -370,6 +370,7 @@ namespace LLama public bool NeedToSaveSession { get; set; } } + [JsonConverter(typeof(PolymorphicJSONConverter))] public class ExecutorBaseState { [JsonPropertyName("n_past")] diff --git a/LLama/LLamaTransforms.cs b/LLama/LLamaTransforms.cs index 1ac0a79b..d74d9dda 100644 --- a/LLama/LLamaTransforms.cs +++ b/LLama/LLamaTransforms.cs @@ -3,6 +3,7 @@ using LLama.Common; using System.Collections.Generic; using System.Linq; using System.Text; +using System.Text.Json.Serialization; namespace LLama { @@ -29,6 +30,12 @@ namespace LLama private readonly string _unknownName; private readonly bool _isInstructMode; + public string UserName => _userName; + public string AssistantName => _assistantName; + public string SystemName => _systemName; + public string UnknownName => _unknownName; + public bool IsInstructMode => _isInstructMode; + /// /// /// @@ -158,6 +165,42 @@ namespace LLama private readonly int _maxKeywordLength; private readonly bool _removeAllMatchedTokens; + /// + /// Keywords that you want to remove from the response. + /// This property is used for JSON serialization. + /// + [JsonPropertyName("keywords")] + public HashSet Keywords => _keywords; + + /// + /// Maximum length of the keywords. + /// This property is used for JSON serialization. + /// + [JsonPropertyName("maxKeywordLength")] + public int MaxKeywordLength => _maxKeywordLength; + + /// + /// If set to true, when getting a matched keyword, all the related tokens will be removed. + /// Otherwise only the part of keyword will be removed. + /// This property is used for JSON serialization. + /// + [JsonPropertyName("removeAllMatchedTokens")] + public bool RemoveAllMatchedTokens => _removeAllMatchedTokens; + + /// + /// JSON constructor. + /// + [JsonConstructor] + public KeywordTextOutputStreamTransform( + HashSet keywords, + int maxKeywordLength, + bool removeAllMatchedTokens) + { + _keywords = new(keywords); + _maxKeywordLength = maxKeywordLength; + _removeAllMatchedTokens = removeAllMatchedTokens; + } + /// /// ///