497 lines
17 KiB
C#
497 lines
17 KiB
C#
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Common;
using static LLama.InteractiveExecutor;
|
|
|
|
namespace LLama;
|
|
|
|
/// <summary>
/// The main chat session class.
/// </summary>
/// <remarks>
/// A session ties an <see cref="ILLamaExecutor"/> to a <see cref="ChatHistory"/> and to the
/// transforms applied to input text, history and output text. The configuration and
/// history-editing methods return the session itself so that calls can be chained.
/// </remarks>
public class ChatSession
{
    private const string _modelStateFilename = "ModelState.st";
    private const string _executorStateFilename = "ExecutorState.json";
    private const string _historyFilename = "ChatHistory.json";

    /// <summary>
    /// The executor for this session.
    /// </summary>
    public ILLamaExecutor Executor { get; private set; }

    /// <summary>
    /// The chat history for this session.
    /// </summary>
    public ChatHistory History { get; private set; } = new();

    /// <summary>
    /// The history transform used in this session.
    /// </summary>
    public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();

    /// <summary>
    /// The input transform pipeline used in this session.
    /// Transforms are applied to user message text in list order.
    /// </summary>
    public List<ITextTransform> InputTransformPipeline { get; set; } = new();

    /// <summary>
    /// The output transform used in this session.
    /// </summary>
    // Kept as a public field (not a property) so existing callers are unaffected.
    public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();

    /// <summary>
    /// Create a new chat session.
    /// </summary>
    /// <param name="executor">The executor for this session. It must derive from <see cref="StatefulExecutorBase"/>.</param>
    /// <exception cref="ArgumentException">
    /// <paramref name="executor"/> is null or is not a <see cref="StatefulExecutorBase"/>.
    /// </exception>
    public ChatSession(ILLamaExecutor executor)
    {
        // Saving, loading and chatting all cast the executor to StatefulExecutorBase,
        // so anything else (including null, which fails the pattern) is rejected up front.
        if (executor is not StatefulExecutorBase)
        {
            throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor));
        }

        Executor = executor;
    }

    /// <summary>
    /// Create a new chat session with a custom history.
    /// </summary>
    /// <param name="executor">The executor for this session. It must derive from <see cref="StatefulExecutorBase"/>.</param>
    /// <param name="history">The history the session starts with.</param>
    /// <exception cref="ArgumentException">
    /// <paramref name="executor"/> is null or is not a <see cref="StatefulExecutorBase"/>.
    /// </exception>
    /// <exception cref="ArgumentNullException"><paramref name="history"/> is null.</exception>
    public ChatSession(ILLamaExecutor executor, ChatHistory history)
        : this(executor)
    {
        // Fail here rather than with a NullReferenceException on first use of History.
        History = history ?? throw new ArgumentNullException(nameof(history));
    }

    /// <summary>
    /// Use a custom history transform.
    /// </summary>
    /// <param name="transform">The transform that replaces <see cref="HistoryTransform"/>.</param>
    /// <returns>This session, to allow chaining.</returns>
    public ChatSession WithHistoryTransform(IHistoryTransform transform)
    {
        HistoryTransform = transform;
        return this;
    }

    /// <summary>
    /// Add a text transform to the input transform pipeline.
    /// </summary>
    /// <param name="transform">The transform appended to <see cref="InputTransformPipeline"/>.</param>
    /// <returns>This session, to allow chaining.</returns>
    public ChatSession AddInputTransform(ITextTransform transform)
    {
        InputTransformPipeline.Add(transform);
        return this;
    }

    /// <summary>
    /// Use a custom output transform.
    /// </summary>
    /// <param name="transform">The transform that replaces <see cref="OutputTransform"/>.</param>
    /// <returns>This session, to allow chaining.</returns>
    public ChatSession WithOutputTransform(ITextStreamTransform transform)
    {
        OutputTransform = transform;
        return this;
    }

    /// <summary>
    /// Save a session to a directory.
    /// </summary>
    /// <remarks>
    /// If the directory already exists it is deleted recursively and recreated before
    /// the model state, executor state and chat history files are written into it.
    /// </remarks>
    /// <param name="path">The directory to save the session into.</param>
    /// <exception cref="ArgumentException"><paramref name="path"/> is null or whitespace.</exception>
    public void SaveSession(string path)
    {
        if (string.IsNullOrWhiteSpace(path))
        {
            throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
        }

        // Start from an empty directory so no files from an older save survive.
        if (Directory.Exists(path))
        {
            Directory.Delete(path, recursive: true);
        }

        Directory.CreateDirectory(path);

        string modelStateFilePath = Path.Combine(path, _modelStateFilename);
        Executor.Context.SaveState(modelStateFilePath);

        string executorStateFilepath = Path.Combine(path, _executorStateFilename);
        ((StatefulExecutorBase)Executor).SaveState(executorStateFilepath);

        string historyFilepath = Path.Combine(path, _historyFilename);
        File.WriteAllText(historyFilepath, History.ToJson());
    }

    /// <summary>
    /// Load a session from a directory.
    /// </summary>
    /// <param name="path">The directory a session was previously saved into.</param>
    /// <exception cref="ArgumentException">
    /// <paramref name="path"/> is null or whitespace, the directory does not exist,
    /// or the history file could not be converted back into a <see cref="ChatHistory"/>.
    /// </exception>
    public void LoadSession(string path)
    {
        if (string.IsNullOrWhiteSpace(path))
        {
            throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
        }

        if (!Directory.Exists(path))
        {
            throw new ArgumentException("Directory does not exist", nameof(path));
        }

        string modelStateFilePath = Path.Combine(path, _modelStateFilename);
        Executor.Context.LoadState(modelStateFilePath);

        string executorStateFilepath = Path.Combine(path, _executorStateFilename);
        ((StatefulExecutorBase)Executor).LoadState(executorStateFilepath);

        string historyFilepath = Path.Combine(path, _historyFilename);
        string historyJson = File.ReadAllText(historyFilepath);
        History = ChatHistory.FromJson(historyJson)
            ?? throw new ArgumentException("History file is invalid", nameof(path));
    }

    /// <summary>
    /// Add a message to the chat history.
    /// </summary>
    /// <remarks>
    /// Ordering rules enforced here: a system message is only accepted while the history is
    /// empty, a user message may not directly follow another user message, and an assistant
    /// message must directly follow a user message.
    /// </remarks>
    /// <param name="message">The message whose role and content are appended to the history.</param>
    /// <returns>This session, to allow chaining.</returns>
    /// <exception cref="ArgumentException">The message violates the ordering rules.</exception>
    public ChatSession AddMessage(ChatHistory.Message message)
    {
        // A system message is only allowed as the very first message.
        if (message.AuthorRole == AuthorRole.System && History.Messages.Count > 0)
        {
            throw new ArgumentException("Cannot add a system message after another message", nameof(message));
        }

        ChatHistory.Message? lastMessage = History.Messages.LastOrDefault();

        // A user message may start the history or follow a system/assistant message,
        // but never another user message.
        if (message.AuthorRole == AuthorRole.User
            && lastMessage is not null
            && lastMessage.AuthorRole == AuthorRole.User)
        {
            throw new ArgumentException("Cannot add a user message after another user message", nameof(message));
        }

        // An assistant message is always a reply, so a user message must precede it.
        if (message.AuthorRole == AuthorRole.Assistant
            && (lastMessage is null || lastMessage.AuthorRole != AuthorRole.User))
        {
            throw new ArgumentException("Assistant message must be preceded with a user message", nameof(message));
        }

        History.AddMessage(message.AuthorRole, message.Content);
        return this;
    }

    /// <summary>
    /// Add a system message to the chat history.
    /// </summary>
    /// <param name="content">The text of the system message.</param>
    /// <returns>This session, to allow chaining.</returns>
    public ChatSession AddSystemMessage(string content)
        => AddMessage(new ChatHistory.Message(AuthorRole.System, content));

    /// <summary>
    /// Add an assistant message to the chat history.
    /// </summary>
    /// <param name="content">The text of the assistant message.</param>
    /// <returns>This session, to allow chaining.</returns>
    public ChatSession AddAssistantMessage(string content)
        => AddMessage(new ChatHistory.Message(AuthorRole.Assistant, content));

    /// <summary>
    /// Add a user message to the chat history.
    /// </summary>
    /// <param name="content">The text of the user message.</param>
    /// <returns>This session, to allow chaining.</returns>
    public ChatSession AddUserMessage(string content)
        => AddMessage(new ChatHistory.Message(AuthorRole.User, content));

    /// <summary>
    /// Remove the last message from the chat history.
    /// </summary>
    /// <remarks>
    /// The history must contain at least one message. With an empty history the removal
    /// index is -1, which the underlying message collection rejects.
    /// </remarks>
    /// <returns>This session, to allow chaining.</returns>
    public ChatSession RemoveLastMessage()
    {
        History.Messages.RemoveAt(History.Messages.Count - 1);
        return this;
    }

    /// <summary>
    /// Replace a user message with a new message and remove all messages after the new message.
    /// This is useful when the user wants to edit a message and regenerate the response.
    /// </summary>
    /// <param name="oldMessage">The user message currently in the history that should be replaced.</param>
    /// <param name="newMessage">The user message that takes its place.</param>
    /// <returns>This session, to allow chaining.</returns>
    /// <exception cref="ArgumentException">
    /// Either message is not a user message, or <paramref name="oldMessage"/> is not in the history.
    /// </exception>
    public ChatSession ReplaceUserMessage(
        ChatHistory.Message oldMessage,
        ChatHistory.Message newMessage)
    {
        if (oldMessage.AuthorRole != AuthorRole.User)
        {
            throw new ArgumentException("Old message must be a user message", nameof(oldMessage));
        }

        if (newMessage.AuthorRole != AuthorRole.User)
        {
            throw new ArgumentException("New message must be a user message", nameof(newMessage));
        }

        int index = History.Messages.IndexOf(oldMessage);
        if (index == -1)
        {
            throw new ArgumentException("Old message does not exist in history", nameof(oldMessage));
        }

        History.Messages[index] = newMessage;

        // Remove all messages after the new message.
        History.Messages.RemoveRange(index + 1, History.Messages.Count - index - 1);

        return this;
    }

    /// <summary>
    /// Chat with the model.
    /// </summary>
    /// <remarks>
    /// The user message is appended to the history before inference starts; the assistant
    /// reply is appended once the token stream has been fully enumerated. The passed
    /// <paramref name="message"/> object itself is not modified.
    /// </remarks>
    /// <param name="message">The user message to send.</param>
    /// <param name="applyInputTransformPipeline">
    /// Whether to run the message text through <see cref="InputTransformPipeline"/> first.
    /// </param>
    /// <param name="inferenceParams">Optional inference parameters forwarded to the executor.</param>
    /// <param name="cancellationToken">Token used to cancel the inference.</param>
    /// <returns>The reply text, streamed piece by piece after the output transform.</returns>
    /// <exception cref="ArgumentException">
    /// <paramref name="message"/> is not a user message, or it cannot follow the current history.
    /// </exception>
    public async IAsyncEnumerable<string> ChatAsync(
        ChatHistory.Message message,
        bool applyInputTransformPipeline,
        IInferenceParams? inferenceParams = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // The message must be a user message.
        if (message.AuthorRole != AuthorRole.User)
        {
            throw new ArgumentException("Message must be a user message", nameof(message));
        }

        // Transform a local copy of the text. Writing back into the caller's message would
        // apply the pipeline a second time if the same message object were reused.
        string content = message.Content;
        if (applyInputTransformPipeline)
        {
            foreach (ITextTransform inputTransform in InputTransformPipeline)
            {
                content = inputTransform.Transform(content);
            }
        }

        // Add the user's message to the history.
        AddUserMessage(content);

        // Check if the session history was restored from a previous session
        // or added as part of new chat session history.
        // NOTE: this cast requires the executor's state data to be an InteractiveExecutorState.
        InteractiveExecutorState state = (InteractiveExecutorState)((StatefulExecutorBase)Executor).GetStateData();

        string prompt;
        if (state.IsPromptRun)
        {
            // "IsPromptRun" is true, so the session was newly started: convert the complete
            // history, including the system message and manually added messages, to a prompt
            // using the template implemented by the HistoryTransform.
            prompt = HistoryTransform.HistoryToText(History);
        }
        else
        {
            // The session was restored from a previous session: convert only the current
            // message to a prompt using the template implemented by the HistoryTransform.
            ChatHistory singleMessageHistory = HistoryTransform.TextToHistory(message.AuthorRole, content);
            prompt = HistoryTransform.HistoryToText(singleMessageHistory);
        }

        // Collect the streamed tokens without re-copying the whole reply for each token.
        StringBuilder assistantMessage = new();

        await foreach (string textToken in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
        {
            assistantMessage.Append(textToken);
            yield return textToken;
        }

        // Add the assistant message to the history.
        AddAssistantMessage(assistantMessage.ToString());
    }

    /// <summary>
    /// Chat with the model, applying the input transform pipeline to the message.
    /// </summary>
    /// <param name="message">The user message to send.</param>
    /// <param name="inferenceParams">Optional inference parameters forwarded to the executor.</param>
    /// <param name="cancellationToken">Token used to cancel the inference.</param>
    /// <returns>The reply text, streamed piece by piece after the output transform.</returns>
    public IAsyncEnumerable<string> ChatAsync(
        ChatHistory.Message message,
        IInferenceParams? inferenceParams = null,
        CancellationToken cancellationToken = default)
    {
        return ChatAsync(
            message,
            applyInputTransformPipeline: true,
            inferenceParams,
            cancellationToken);
    }

    /// <summary>
    /// Chat with the model.
    /// </summary>
    /// <remarks>
    /// Every message except the last is appended to the session history immediately, subject
    /// to the ordering rules of <see cref="AddMessage"/>. The last message is then sent as the
    /// user message of the chat. The messages in <paramref name="history"/> are not modified.
    /// </remarks>
    /// <param name="history">The messages to add; the last one must be a user message.</param>
    /// <param name="applyInputTransformPipeline">
    /// Whether to run the text of user messages through <see cref="InputTransformPipeline"/>.
    /// </param>
    /// <param name="inferenceParams">Optional inference parameters forwarded to the executor.</param>
    /// <param name="cancellationToken">Token used to cancel the inference.</param>
    /// <returns>The reply text, streamed piece by piece after the output transform.</returns>
    /// <exception cref="ArgumentException">
    /// <paramref name="history"/> is empty, or one of its messages violates the ordering rules.
    /// </exception>
    public IAsyncEnumerable<string> ChatAsync(
        ChatHistory history,
        bool applyInputTransformPipeline,
        IInferenceParams? inferenceParams = null,
        CancellationToken cancellationToken = default)
    {
        ChatHistory.Message lastMessage = history.Messages.LastOrDefault()
            ?? throw new ArgumentException("History must contain at least one message", nameof(history));

        foreach (ChatHistory.Message message in history.Messages.Take(history.Messages.Count - 1))
        {
            // Transform a local copy so the caller's history keeps its original text.
            string content = message.Content;
            if (applyInputTransformPipeline && message.AuthorRole == AuthorRole.User)
            {
                foreach (ITextTransform inputTransform in InputTransformPipeline)
                {
                    content = inputTransform.Transform(content);
                }
            }

            AddMessage(new ChatHistory.Message(message.AuthorRole, content));
        }

        return ChatAsync(
            lastMessage,
            applyInputTransformPipeline,
            inferenceParams,
            cancellationToken);
    }

    /// <summary>
    /// Chat with the model, applying the input transform pipeline to user messages.
    /// </summary>
    /// <param name="history">The messages to add; the last one must be a user message.</param>
    /// <param name="inferenceParams">Optional inference parameters forwarded to the executor.</param>
    /// <param name="cancellationToken">Token used to cancel the inference.</param>
    /// <returns>The reply text, streamed piece by piece after the output transform.</returns>
    public IAsyncEnumerable<string> ChatAsync(
        ChatHistory history,
        IInferenceParams? inferenceParams = null,
        CancellationToken cancellationToken = default)
    {
        return ChatAsync(
            history,
            applyInputTransformPipeline: true,
            inferenceParams,
            cancellationToken);
    }

    /// <summary>
    /// Regenerate the last assistant message.
    /// </summary>
    /// <remarks>
    /// The history must end with a user message followed by an assistant message. Both are
    /// removed and the user message is sent again without re-applying the input transform
    /// pipeline. Nothing is removed if the history does not have that shape.
    /// </remarks>
    /// <param name="inferenceParams">Optional inference parameters forwarded to the executor.</param>
    /// <param name="cancellationToken">Token used to cancel the inference.</param>
    /// <returns>The new reply text, streamed piece by piece after the output transform.</returns>
    /// <exception cref="InvalidOperationException">
    /// The history does not end with a user message followed by an assistant message.
    /// </exception>
    public async IAsyncEnumerable<string> RegenerateAssistantMessageAsync(
        InferenceParams? inferenceParams = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // Make sure the last message is an assistant message (response from the LLM).
        ChatHistory.Message? lastAssistantMessage = History.Messages.LastOrDefault();

        if (lastAssistantMessage is null
            || lastAssistantMessage.AuthorRole != AuthorRole.Assistant)
        {
            throw new InvalidOperationException("Last message must be an assistant message");
        }

        // Validate the preceding user message BEFORE touching the history, so that a
        // failed check cannot leave the history with the assistant message already gone.
        int messageCount = History.Messages.Count;
        ChatHistory.Message? lastUserMessage = messageCount >= 2
            ? History.Messages[messageCount - 2]
            : null;

        if (lastUserMessage is null
            || lastUserMessage.AuthorRole != AuthorRole.User)
        {
            throw new InvalidOperationException("Last message must be a user message");
        }

        // Drop the user/assistant pair; ChatAsync adds the user message back.
        History.Messages.RemoveRange(messageCount - 2, 2);

        // Regenerate the assistant message.
        await foreach (
            string textToken
            in ChatAsync(
                lastUserMessage,
                applyInputTransformPipeline: false,
                inferenceParams,
                cancellationToken))
        {
            yield return textToken;
        }
    }

    /// <summary>
    /// Run inference for a prepared prompt and pass the result through <see cref="OutputTransform"/>.
    /// </summary>
    /// <param name="prompt">The prompt text handed to the executor.</param>
    /// <param name="inferenceParams">Optional inference parameters forwarded to the executor.</param>
    /// <param name="cancellationToken">Token used to cancel the inference.</param>
    /// <returns>The transformed output, streamed piece by piece.</returns>
    private async IAsyncEnumerable<string> ChatAsyncInternal(
        string prompt,
        IInferenceParams? inferenceParams = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        var results = Executor.InferAsync(prompt, inferenceParams, cancellationToken);

        await foreach (
            string textToken
            in OutputTransform
                .TransformAsync(results)
                .WithCancellation(cancellationToken))
        {
            yield return textToken;
        }
    }
}
|