783 lines
28 KiB
C#
783 lines
28 KiB
C#
using System;
|
|
using System.Collections.Generic;
|
|
using System.IO;
|
|
using System.Linq;
|
|
using System.Runtime.CompilerServices;
|
|
using System.Text.Json;
|
|
using System.Threading;
|
|
using System.Threading.Tasks;
|
|
using LLama.Abstractions;
|
|
using LLama.Common;
|
|
using static LLama.InteractiveExecutor;
|
|
using static LLama.LLamaContext;
|
|
using static LLama.StatefulExecutorBase;
|
|
|
|
namespace LLama;
|
|
|
|
/// <summary>
|
|
/// The main chat session class.
|
|
/// </summary>
|
|
public class ChatSession
|
|
{
|
|
/// <summary>
|
|
/// The filename for the serialized model state (KV cache, etc).
|
|
/// </summary>
|
|
public const string MODEL_STATE_FILENAME = "ModelState.st";
|
|
/// <summary>
|
|
/// The filename for the serialized executor state.
|
|
/// </summary>
|
|
public const string EXECUTOR_STATE_FILENAME = "ExecutorState.json";
|
|
/// <summary>
|
|
/// The filename for the serialized chat history.
|
|
/// </summary>
|
|
public const string HISTORY_STATE_FILENAME = "ChatHistory.json";
|
|
/// <summary>
|
|
/// The filename for the serialized input transform pipeline.
|
|
/// </summary>
|
|
public const string INPUT_TRANSFORM_FILENAME = "InputTransform.json";
|
|
/// <summary>
|
|
/// The filename for the serialized output transform.
|
|
/// </summary>
|
|
public const string OUTPUT_TRANSFORM_FILENAME = "OutputTransform.json";
|
|
/// <summary>
|
|
/// The filename for the serialized history transform.
|
|
/// </summary>
|
|
public const string HISTORY_TRANSFORM_FILENAME = "HistoryTransform.json";
|
|
|
|
/// <summary>
|
|
/// The executor for this session.
|
|
/// </summary>
|
|
public ILLamaExecutor Executor { get; private set; }
|
|
|
|
/// <summary>
|
|
/// The chat history for this session.
|
|
/// </summary>
|
|
public ChatHistory History { get; private set; } = new();
|
|
|
|
/// <summary>
|
|
/// The history transform used in this session.
|
|
/// </summary>
|
|
public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();
|
|
|
|
/// <summary>
|
|
/// The input transform pipeline used in this session.
|
|
/// </summary>
|
|
public List<ITextTransform> InputTransformPipeline { get; set; } = new();
|
|
|
|
/// <summary>
|
|
/// The output transform used in this session.
|
|
/// </summary>
|
|
public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();
|
|
|
|
/// <summary>
|
|
/// Create a new chat session and preprocess history.
|
|
/// </summary>
|
|
/// <param name="executor">The executor for this session</param>
|
|
/// <param name="history">History for this session</param>
|
|
/// <returns></returns>
|
|
public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
|
|
ILLamaExecutor executor, ChatHistory history)
|
|
{
|
|
if (executor is not StatefulExecutorBase statefulExecutor)
|
|
{
|
|
throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor));
|
|
}
|
|
var session = new ChatSession(executor, history);
|
|
await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history));
|
|
return session;
|
|
}
|
|
|
|
/// <summary>
|
|
/// Create a new chat session.
|
|
/// </summary>
|
|
/// <param name="executor">The executor for this session</param>
|
|
public ChatSession(ILLamaExecutor executor)
|
|
{
|
|
// Check if executor has StatefulExecutorBase as base class
|
|
if (executor is not StatefulExecutorBase)
|
|
{
|
|
throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor));
|
|
}
|
|
|
|
Executor = executor;
|
|
}
|
|
|
|
/// <summary>
|
|
/// Create a new chat session with a custom history.
|
|
/// </summary>
|
|
/// <param name="executor"></param>
|
|
/// <param name="history"></param>
|
|
public ChatSession(ILLamaExecutor executor, ChatHistory history)
|
|
: this(executor)
|
|
{
|
|
History = history;
|
|
}
|
|
|
|
/// <summary>
|
|
/// Use a custom history transform.
|
|
/// </summary>
|
|
/// <param name="transform"></param>
|
|
/// <returns></returns>
|
|
public ChatSession WithHistoryTransform(IHistoryTransform transform)
|
|
{
|
|
HistoryTransform = transform;
|
|
return this;
|
|
}
|
|
|
|
/// <summary>
|
|
/// Add a text transform to the input transform pipeline.
|
|
/// </summary>
|
|
/// <param name="transform"></param>
|
|
/// <returns></returns>
|
|
public ChatSession AddInputTransform(ITextTransform transform)
|
|
{
|
|
InputTransformPipeline.Add(transform);
|
|
return this;
|
|
}
|
|
|
|
/// <summary>
|
|
/// Use a custom output transform.
|
|
/// </summary>
|
|
/// <param name="transform"></param>
|
|
/// <returns></returns>
|
|
public ChatSession WithOutputTransform(ITextStreamTransform transform)
|
|
{
|
|
OutputTransform = transform;
|
|
return this;
|
|
}
|
|
|
|
/// <summary>
|
|
/// Save a session from a directory.
|
|
/// </summary>
|
|
/// <param name="path"></param>
|
|
/// <returns></returns>
|
|
/// <exception cref="ArgumentException"></exception>
|
|
public void SaveSession(string path)
|
|
{
|
|
GetSessionState().Save(path);
|
|
}
|
|
|
|
/// <summary>
|
|
/// Get the session state.
|
|
/// </summary>
|
|
/// <returns>SessionState object representing session state in-memory</returns>
|
|
public SessionState GetSessionState()
|
|
{
|
|
var executorState = ((StatefulExecutorBase)Executor).GetStateData();
|
|
return new SessionState(
|
|
executorState.PastTokensCount > 0
|
|
? Executor.Context.GetState() : null,
|
|
executorState,
|
|
History,
|
|
InputTransformPipeline,
|
|
OutputTransform,
|
|
HistoryTransform);
|
|
}
|
|
|
|
/// <summary>
|
|
/// Load a session from a session state.
|
|
/// </summary>
|
|
/// <param name="state"></param>
|
|
/// <param name="loadTransforms">If true loads transforms saved in the session state.</param>
|
|
/// <returns></returns>
|
|
/// <exception cref="ArgumentException"></exception>
|
|
public void LoadSession(SessionState state, bool loadTransforms = true)
|
|
{
|
|
if (Executor is StatefulExecutorBase statefulExecutor)
|
|
{
|
|
if (state.ExecutorState is not null)
|
|
{
|
|
statefulExecutor.LoadState(state.ExecutorState);
|
|
}
|
|
}
|
|
if (state.ContextState is null)
|
|
{
|
|
Executor.Context.NativeHandle.KvCacheClear();
|
|
}
|
|
else
|
|
{
|
|
Executor.Context.LoadState(state.ContextState);
|
|
}
|
|
History = new ChatHistory(state.History);
|
|
if (loadTransforms)
|
|
{
|
|
InputTransformPipeline = state.InputTransformPipeline.Select(t => t.Clone()).ToList();
|
|
OutputTransform = state.OutputTransform.Clone();
|
|
HistoryTransform = state.HistoryTransform.Clone();
|
|
}
|
|
}
|
|
|
|
/// <summary>
|
|
/// Load a session from a directory.
|
|
/// </summary>
|
|
/// <param name="path"></param>
|
|
/// <param name="loadTransforms">If true loads transforms saved in the session state.</param>
|
|
/// <returns></returns>
|
|
/// <exception cref="ArgumentException"></exception>
|
|
public void LoadSession(string path, bool loadTransforms = true)
|
|
{
|
|
var state = SessionState.Load(path);
|
|
// Handle non-polymorphic serialization of executor state
|
|
if (state.ExecutorState is null)
|
|
{
|
|
var executorPath = Path.Combine(path, EXECUTOR_STATE_FILENAME);
|
|
((StatefulExecutorBase) Executor).LoadState(filename: executorPath);
|
|
}
|
|
LoadSession(state, loadTransforms);
|
|
}
|
|
|
|
/// <summary>
|
|
/// Add a message to the chat history.
|
|
/// </summary>
|
|
/// <param name="message"></param>
|
|
/// <returns></returns>
|
|
public ChatSession AddMessage(ChatHistory.Message message)
|
|
{
|
|
// If current message is a system message, only allow the history to be empty
|
|
if (message.AuthorRole == AuthorRole.System && History.Messages.Count > 0)
|
|
{
|
|
throw new ArgumentException("Cannot add a system message after another message", nameof(message));
|
|
}
|
|
|
|
// If current message is a user message, only allow the history to be empty,
|
|
// or the previous message to be a system message or assistant message.
|
|
if (message.AuthorRole == AuthorRole.User)
|
|
{
|
|
ChatHistory.Message? lastMessage = History.Messages.LastOrDefault();
|
|
if (lastMessage is not null && lastMessage.AuthorRole == AuthorRole.User)
|
|
{
|
|
throw new ArgumentException("Cannot add a user message after another user message", nameof(message));
|
|
}
|
|
}
|
|
|
|
// If the current message is an assistant message,
|
|
// the previous message must be a user message.
|
|
if (message.AuthorRole == AuthorRole.Assistant)
|
|
{
|
|
ChatHistory.Message? lastMessage = History.Messages.LastOrDefault();
|
|
if (lastMessage is null
|
|
|| lastMessage.AuthorRole != AuthorRole.User)
|
|
{
|
|
throw new ArgumentException("Assistant message must be preceded with a user message", nameof(message));
|
|
}
|
|
}
|
|
|
|
History.AddMessage(message.AuthorRole, message.Content);
|
|
return this;
|
|
}
|
|
|
|
/// <summary>
|
|
/// Add a system message to the chat history.
|
|
/// </summary>
|
|
/// <param name="content"></param>
|
|
/// <returns></returns>
|
|
public ChatSession AddSystemMessage(string content)
|
|
=> AddMessage(new ChatHistory.Message(AuthorRole.System, content));
|
|
|
|
/// <summary>
|
|
/// Add an assistant message to the chat history.
|
|
/// </summary>
|
|
/// <param name="content"></param>
|
|
/// <returns></returns>
|
|
public ChatSession AddAssistantMessage(string content)
|
|
=> AddMessage(new ChatHistory.Message(AuthorRole.Assistant, content));
|
|
|
|
/// <summary>
|
|
/// Add a user message to the chat history.
|
|
/// </summary>
|
|
/// <param name="content"></param>
|
|
/// <returns></returns>
|
|
public ChatSession AddUserMessage(string content)
|
|
=> AddMessage(new ChatHistory.Message(AuthorRole.User, content));
|
|
|
|
/// <summary>
|
|
/// Remove the last message from the chat history.
|
|
/// </summary>
|
|
/// <returns></returns>
|
|
public ChatSession RemoveLastMessage()
|
|
{
|
|
History.Messages.RemoveAt(History.Messages.Count - 1);
|
|
return this;
|
|
}
|
|
|
|
/// <summary>
|
|
/// Compute KV cache for the message and add it to the chat history.
|
|
/// </summary>
|
|
/// <param name="message"></param>
|
|
/// <returns></returns>
|
|
public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message)
|
|
{
|
|
if (Executor is not StatefulExecutorBase statefulExecutor)
|
|
{
|
|
throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages.");
|
|
}
|
|
AddMessage(message);
|
|
var content = message.Content;
|
|
if (message.AuthorRole != AuthorRole.Assistant)
|
|
{
|
|
foreach (var inputTransform in InputTransformPipeline)
|
|
{
|
|
content = inputTransform.Transform(content);
|
|
}
|
|
}
|
|
|
|
await statefulExecutor.PrefillPromptAsync(content);
|
|
return this;
|
|
}
|
|
|
|
/// <summary>
|
|
/// Compute KV cache for the system message and add it to the chat history.
|
|
/// </summary>
|
|
public Task<ChatSession> AddAndProcessSystemMessage(string content)
|
|
=> AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content));
|
|
|
|
/// <summary>
|
|
/// Compute KV cache for the user message and add it to the chat history.
|
|
/// </summary>
|
|
public Task<ChatSession> AddAndProcessUserMessage(string content)
|
|
=> AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content));
|
|
|
|
/// <summary>
|
|
/// Compute KV cache for the assistant message and add it to the chat history.
|
|
/// </summary>
|
|
public Task<ChatSession> AddAndProcessAssistantMessage(string content)
|
|
=> AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content));
|
|
|
|
/// <summary>
|
|
/// Replace a user message with a new message and remove all messages after the new message.
|
|
/// This is useful when the user wants to edit a message. And regenerate the response.
|
|
/// </summary>
|
|
/// <param name="oldMessage"></param>
|
|
/// <param name="newMessage"></param>
|
|
/// <returns></returns>
|
|
public ChatSession ReplaceUserMessage(
|
|
ChatHistory.Message oldMessage,
|
|
ChatHistory.Message newMessage)
|
|
{
|
|
if (oldMessage.AuthorRole != AuthorRole.User)
|
|
{
|
|
throw new ArgumentException("Old message must be a user message", nameof(oldMessage));
|
|
}
|
|
|
|
if (newMessage.AuthorRole != AuthorRole.User)
|
|
{
|
|
throw new ArgumentException("New message must be a user message", nameof(newMessage));
|
|
}
|
|
|
|
int index = History.Messages.IndexOf(oldMessage);
|
|
if (index == -1)
|
|
{
|
|
throw new ArgumentException("Old message does not exist in history", nameof(oldMessage));
|
|
}
|
|
|
|
History.Messages[index] = newMessage;
|
|
|
|
// Remove all message after the new message
|
|
History.Messages.RemoveRange(index + 1, History.Messages.Count - index - 1);
|
|
|
|
return this;
|
|
}
|
|
|
|
/// <summary>
|
|
/// Chat with the model.
|
|
/// </summary>
|
|
/// <param name="message"></param>
|
|
/// <param name="inferenceParams"></param>
|
|
/// <param name="applyInputTransformPipeline"></param>
|
|
/// <param name="cancellationToken"></param>
|
|
/// <returns></returns>
|
|
/// <exception cref="ArgumentException"></exception>
|
|
public async IAsyncEnumerable<string> ChatAsync(
|
|
ChatHistory.Message message,
|
|
bool applyInputTransformPipeline,
|
|
IInferenceParams? inferenceParams = null,
|
|
[EnumeratorCancellation] CancellationToken cancellationToken = default)
|
|
{
|
|
// The message must be a user message
|
|
if (message.AuthorRole != AuthorRole.User)
|
|
{
|
|
throw new ArgumentException("Message must be a user message", nameof(message));
|
|
}
|
|
|
|
// Apply input transform pipeline
|
|
if (applyInputTransformPipeline)
|
|
{
|
|
foreach (var inputTransform in InputTransformPipeline)
|
|
{
|
|
message.Content = inputTransform.Transform(message.Content);
|
|
}
|
|
}
|
|
|
|
// Add the user's message to the history
|
|
AddUserMessage(message.Content);
|
|
|
|
// Prepare prompt variable
|
|
string prompt;
|
|
|
|
// Check if the session history was restored from a previous session
|
|
// or added as part of new chat session history.
|
|
InteractiveExecutorState state = (InteractiveExecutorState)((StatefulExecutorBase)Executor).GetStateData();
|
|
|
|
// If "IsPromptRun" is true, the session was newly started.
|
|
if (state.IsPromptRun)
|
|
{
|
|
// If the session history was added as part of new chat session history,
|
|
// convert the complete history includsing system message and manually added history
|
|
// to a prompt that adhere to the prompt template specified in the HistoryTransform class implementation.
|
|
prompt = HistoryTransform.HistoryToText(History);
|
|
}
|
|
else
|
|
{
|
|
// If the session was restored from a previous session,
|
|
// convert only the current message to the prompt with the prompt template
|
|
// specified in the HistoryTransform class implementation that is provided.
|
|
ChatHistory singleMessageHistory = HistoryTransform.TextToHistory(message.AuthorRole, message.Content);
|
|
prompt = HistoryTransform.HistoryToText(singleMessageHistory);
|
|
}
|
|
|
|
string assistantMessage = string.Empty;
|
|
|
|
await foreach (
|
|
string textToken
|
|
in ChatAsyncInternal(
|
|
prompt,
|
|
inferenceParams,
|
|
cancellationToken))
|
|
{
|
|
assistantMessage += textToken;
|
|
yield return textToken;
|
|
}
|
|
|
|
// Add the assistant message to the history
|
|
AddAssistantMessage(assistantMessage);
|
|
}
|
|
|
|
/// <summary>
|
|
/// Chat with the model.
|
|
/// </summary>
|
|
/// <param name="message"></param>
|
|
/// <param name="inferenceParams"></param>
|
|
/// <param name="cancellationToken"></param>
|
|
/// <returns></returns>
|
|
public IAsyncEnumerable<string> ChatAsync(
|
|
ChatHistory.Message message,
|
|
IInferenceParams? inferenceParams = null,
|
|
CancellationToken cancellationToken = default)
|
|
{
|
|
return ChatAsync(
|
|
message,
|
|
applyInputTransformPipeline: true,
|
|
inferenceParams,
|
|
cancellationToken);
|
|
}
|
|
|
|
/// <summary>
|
|
/// Chat with the model.
|
|
/// </summary>
|
|
/// <param name="history"></param>
|
|
/// <param name="applyInputTransformPipeline"></param>
|
|
/// <param name="inferenceParams"></param>
|
|
/// <param name="cancellationToken"></param>
|
|
/// <returns></returns>
|
|
/// <exception cref="ArgumentException"></exception>
|
|
public IAsyncEnumerable<string> ChatAsync(
|
|
ChatHistory history,
|
|
bool applyInputTransformPipeline,
|
|
IInferenceParams? inferenceParams = null,
|
|
CancellationToken cancellationToken = default)
|
|
{
|
|
ChatHistory.Message lastMessage = history.Messages.LastOrDefault()
|
|
?? throw new ArgumentException("History must contain at least one message", nameof(history));
|
|
|
|
foreach (
|
|
ChatHistory.Message message
|
|
in history.Messages.Take(history.Messages.Count - 1))
|
|
{
|
|
// Apply input transform pipeline
|
|
if (applyInputTransformPipeline
|
|
&& message.AuthorRole == AuthorRole.User)
|
|
{
|
|
foreach (
|
|
var inputTransform
|
|
in InputTransformPipeline)
|
|
{
|
|
message.Content = inputTransform.Transform(message.Content);
|
|
}
|
|
}
|
|
|
|
AddMessage(message);
|
|
}
|
|
|
|
return ChatAsync(
|
|
lastMessage,
|
|
applyInputTransformPipeline,
|
|
inferenceParams,
|
|
cancellationToken);
|
|
}
|
|
|
|
/// <summary>
|
|
/// Chat with the model.
|
|
/// </summary>
|
|
/// <param name="history"></param>
|
|
/// <param name="inferenceParams"></param>
|
|
/// <param name="cancellationToken"></param>
|
|
/// <returns></returns>
|
|
public IAsyncEnumerable<string> ChatAsync(
|
|
ChatHistory history,
|
|
IInferenceParams? inferenceParams = null,
|
|
CancellationToken cancellationToken = default)
|
|
{
|
|
return ChatAsync(
|
|
history,
|
|
applyInputTransformPipeline: true,
|
|
inferenceParams,
|
|
cancellationToken);
|
|
}
|
|
|
|
/// <summary>
|
|
/// Regenerate the last assistant message.
|
|
/// </summary>
|
|
/// <param name="inferenceParams"></param>
|
|
/// <param name="cancellationToken"></param>
|
|
/// <returns></returns>
|
|
/// <exception cref="InvalidOperationException"></exception>
|
|
public async IAsyncEnumerable<string> RegenerateAssistantMessageAsync(
|
|
InferenceParams? inferenceParams = null,
|
|
[EnumeratorCancellation] CancellationToken cancellationToken = default)
|
|
{
|
|
// Make sure the last message is an assistant message (reponse from the LLM).
|
|
ChatHistory.Message? lastAssistantMessage = History.Messages.LastOrDefault();
|
|
|
|
if (lastAssistantMessage is null
|
|
|| lastAssistantMessage.AuthorRole != AuthorRole.Assistant)
|
|
{
|
|
throw new InvalidOperationException("Last message must be an assistant message");
|
|
}
|
|
|
|
// Remove the last assistant message from the history.
|
|
RemoveLastMessage();
|
|
|
|
// Get the last user message.
|
|
ChatHistory.Message? lastUserMessage = History.Messages.LastOrDefault();
|
|
|
|
if (lastUserMessage is null
|
|
|| lastUserMessage.AuthorRole != AuthorRole.User)
|
|
{
|
|
throw new InvalidOperationException("Last message must be a user message");
|
|
}
|
|
|
|
// Remove the last user message from the history.
|
|
RemoveLastMessage();
|
|
|
|
// Regenerate the assistant message.
|
|
await foreach (
|
|
string textToken
|
|
in ChatAsync(
|
|
lastUserMessage,
|
|
applyInputTransformPipeline: false,
|
|
inferenceParams,
|
|
cancellationToken))
|
|
{
|
|
yield return textToken;
|
|
}
|
|
}
|
|
|
|
private async IAsyncEnumerable<string> ChatAsyncInternal(
|
|
string prompt,
|
|
IInferenceParams? inferenceParams = null,
|
|
[EnumeratorCancellation] CancellationToken cancellationToken = default)
|
|
{
|
|
var results = Executor.InferAsync(prompt, inferenceParams, cancellationToken);
|
|
|
|
await foreach (
|
|
string textToken
|
|
in OutputTransform
|
|
.TransformAsync(results)
|
|
.WithCancellation(cancellationToken))
|
|
{
|
|
yield return textToken;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// <summary>
|
|
/// The state of a chat session in-memory.
|
|
/// </summary>
|
|
public record SessionState
|
|
{
|
|
/// <summary>
|
|
/// Saved executor state for the session in JSON format.
|
|
/// </summary>
|
|
public ExecutorBaseState? ExecutorState { get; set; }
|
|
|
|
/// <summary>
|
|
/// Saved context state (KV cache) for the session.
|
|
/// </summary>
|
|
public State? ContextState { get; set; }
|
|
|
|
/// <summary>
|
|
/// The input transform pipeline used in this session.
|
|
/// </summary>
|
|
public ITextTransform[] InputTransformPipeline { get; set; } = Array.Empty<ITextTransform>();
|
|
|
|
/// <summary>
|
|
/// The output transform used in this session.
|
|
/// </summary>
|
|
public ITextStreamTransform OutputTransform { get; set; } = new LLamaTransforms.EmptyTextOutputStreamTransform();
|
|
|
|
/// <summary>
|
|
/// The history transform used in this session.
|
|
/// </summary>
|
|
public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();
|
|
|
|
/// <summary>
|
|
/// The the chat history messages for this session.
|
|
/// </summary>
|
|
public ChatHistory.Message[] History { get; set; } = Array.Empty<ChatHistory.Message>();
|
|
|
|
/// <summary>
|
|
/// Create a new session state.
|
|
/// </summary>
|
|
/// <param name="contextState"></param>
|
|
/// <param name="executorState"></param>
|
|
/// <param name="history"></param>
|
|
/// <param name="inputTransformPipeline"></param>
|
|
/// <param name="outputTransform"></param>
|
|
/// <param name="historyTransform"></param>
|
|
public SessionState(
|
|
State? contextState, ExecutorBaseState executorState,
|
|
ChatHistory history, List<ITextTransform> inputTransformPipeline,
|
|
ITextStreamTransform outputTransform, IHistoryTransform historyTransform)
|
|
{
|
|
ContextState = contextState;
|
|
ExecutorState = executorState;
|
|
History = history.Messages.ToArray();
|
|
InputTransformPipeline = inputTransformPipeline.Select(t => t.Clone()).ToArray();
|
|
OutputTransform = outputTransform.Clone();
|
|
HistoryTransform = historyTransform.Clone();
|
|
}
|
|
|
|
/// <summary>
|
|
/// Save the session state to folder.
|
|
/// </summary>
|
|
/// <param name="path"></param>
|
|
public void Save(string path)
|
|
{
|
|
if (string.IsNullOrWhiteSpace(path))
|
|
{
|
|
throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
|
|
}
|
|
|
|
if (Directory.Exists(path))
|
|
{
|
|
Directory.Delete(path, recursive: true);
|
|
}
|
|
|
|
Directory.CreateDirectory(path);
|
|
|
|
string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
|
|
var bytes = ContextState?.ToByteArray();
|
|
if (bytes is not null)
|
|
{
|
|
File.WriteAllBytes(modelStateFilePath, bytes);
|
|
}
|
|
|
|
string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME);
|
|
File.WriteAllText(executorStateFilepath, JsonSerializer.Serialize(ExecutorState));
|
|
|
|
string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME);
|
|
File.WriteAllText(historyFilepath, new ChatHistory(History).ToJson());
|
|
|
|
string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME);
|
|
File.WriteAllText(inputTransformFilepath, JsonSerializer.Serialize(InputTransformPipeline));
|
|
|
|
string outputTransformFilepath = Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME);
|
|
File.WriteAllText(outputTransformFilepath, JsonSerializer.Serialize(OutputTransform));
|
|
|
|
string historyTransformFilepath = Path.Combine(path, ChatSession.HISTORY_TRANSFORM_FILENAME);
|
|
File.WriteAllText(historyTransformFilepath, JsonSerializer.Serialize(HistoryTransform));
|
|
}
|
|
|
|
/// <summary>
|
|
/// Load the session state from folder.
|
|
/// </summary>
|
|
/// <param name="path"></param>
|
|
/// <returns></returns>
|
|
/// <exception cref="ArgumentException">Throws when session state is incorrect</exception>
|
|
public static SessionState Load(string path)
|
|
{
|
|
if (string.IsNullOrWhiteSpace(path))
|
|
{
|
|
throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
|
|
}
|
|
|
|
if (!Directory.Exists(path))
|
|
{
|
|
throw new ArgumentException("Directory does not exist", nameof(path));
|
|
}
|
|
|
|
string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
|
|
var contextState = File.Exists(modelStateFilePath) ?
|
|
State.FromByteArray(File.ReadAllBytes(modelStateFilePath))
|
|
: null;
|
|
|
|
string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME);
|
|
var executorState = JsonSerializer.Deserialize<ExecutorBaseState>(File.ReadAllText(executorStateFilepath));
|
|
|
|
string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME);
|
|
string historyJson = File.ReadAllText(historyFilepath);
|
|
var history = ChatHistory.FromJson(historyJson)
|
|
?? throw new ArgumentException("History file is invalid", nameof(path));
|
|
|
|
string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME);
|
|
ITextTransform[] inputTransforms;
|
|
try
|
|
{
|
|
inputTransforms = File.Exists(inputTransformFilepath) ?
|
|
(JsonSerializer.Deserialize<ITextTransform[]>(File.ReadAllText(inputTransformFilepath))
|
|
?? throw new ArgumentException("Input transform file is invalid", nameof(path)))
|
|
: Array.Empty<ITextTransform>();
|
|
}
|
|
catch (JsonException)
|
|
{
|
|
throw new ArgumentException("Input transform file is invalid", nameof(path));
|
|
}
|
|
|
|
string outputTransformFilepath = Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME);
|
|
|
|
ITextStreamTransform outputTransform;
|
|
try
|
|
{
|
|
outputTransform = File.Exists(outputTransformFilepath) ?
|
|
(JsonSerializer.Deserialize<ITextStreamTransform>(File.ReadAllText(outputTransformFilepath))
|
|
?? throw new ArgumentException("Output transform file is invalid", nameof(path)))
|
|
: new LLamaTransforms.EmptyTextOutputStreamTransform();
|
|
}
|
|
catch (JsonException)
|
|
{
|
|
throw new ArgumentException("Output transform file is invalid", nameof(path));
|
|
}
|
|
|
|
string historyTransformFilepath = Path.Combine(path, ChatSession.HISTORY_TRANSFORM_FILENAME);
|
|
IHistoryTransform historyTransform;
|
|
try
|
|
{
|
|
historyTransform = File.Exists(historyTransformFilepath) ?
|
|
(JsonSerializer.Deserialize<IHistoryTransform>(File.ReadAllText(historyTransformFilepath))
|
|
?? throw new ArgumentException("History transform file is invalid", nameof(path)))
|
|
: new LLamaTransforms.DefaultHistoryTransform();
|
|
}
|
|
catch (JsonException)
|
|
{
|
|
throw new ArgumentException("History transform file is invalid", nameof(path));
|
|
}
|
|
|
|
return new SessionState(
|
|
contextState,
|
|
executorState,
|
|
history,
|
|
inputTransforms.ToList(),
|
|
outputTransform,
|
|
historyTransform);
|
|
}
|
|
} |