// LLamaSharp/LLama/ChatSession.cs

using LLama.Common;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
namespace LLama
{
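/// <summary>
/// A higher-level wrapper around an <see cref="ILLamaExecutor"/> that keeps a chat history
/// and flattens it into role-prefixed prompt text for the underlying model.
/// </summary>
/// <example>
/// A minimal usage sketch (assumes an executor has already been created from a loaded model):
/// <code>
/// var session = new ChatSession(executor);
/// foreach (var token in session.Chat("Hello!"))
/// {
///     Console.Write(token);
/// }
/// </code>
/// </example>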
public class ChatSession
{
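// Role names used when the corresponding name is not provided in the session parameters.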
private readonly string defaultUserName = "User";
private readonly string defaultAssistantName = "Assistant";
private readonly string defaultSystemName = "System";
private readonly string defaultUnknownName = "??";
private ILLamaExecutor _executor;
private ChatHistory _history;
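/// <summary>
/// The executor used to run inference for this session.
/// </summary>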
public ILLamaExecutor Executor => _executor;
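/// <summary>
/// The chat history recorded by this session.
/// </summary>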
public ChatHistory History => _history;
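/// <summary>
/// The session parameters, including the role names used when building prompt text.
/// </summary>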
public SessionParams Params { get; set; }
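/// <summary>
/// Create a chat session that uses the given executor for inference.
/// </summary>
/// <param name="executor">The executor used to run inference.</param>
/// <param name="sessionParams">Optional session parameters; defaults are used when null.</param>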
public ChatSession(ILLamaExecutor executor, SessionParams? sessionParams = null)
{
_executor = executor;
_history = new ChatHistory();
Params = sessionParams ?? new SessionParams();
}
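/// <summary>
/// Flatten a chat history into a single prompt string, prefixing each message with its role name.
/// </summary>
/// <param name="history">The history to convert.</param>
/// <returns>The role-prefixed prompt text.</returns>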
public virtual string BuildTextFromHistory(ChatHistory history)
{
StringBuilder sb = new();
var userName = Params.UserName ?? defaultUserName;
var assistantName = Params.AssistantName ?? defaultAssistantName;
var systemName = Params.SystemName ?? defaultSystemName;
foreach (var message in history.Messages)
{
if (message.AuthorRole == AuthorRole.User)
{
sb.AppendLine($"{userName}: {message.Content}");
}
else if (message.AuthorRole == AuthorRole.System)
{
sb.AppendLine($"{systemName}: {message.Content}");
}
else if (message.AuthorRole == AuthorRole.Unknown)
{
sb.AppendLine($"{defaultUnknownName}: {message.Content}");
}
else if (message.AuthorRole == AuthorRole.Assistant)
{
sb.AppendLine($"{assistantName}: {message.Content}");
}
}
return sb.ToString();
}
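/// <summary>
/// Remove the role name prefix or suffix (and, when the executor is an <see cref="InstructExecutor"/>,
/// the trailing instruct prompt) from the given text.
/// </summary>
/// <param name="text">The text to clean.</param>
/// <param name="role">The role whose name should be stripped.</param>
/// <returns>The cleaned text.</returns>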
public virtual string CropNameFromText(string text, AuthorRole role)
{
if (!string.IsNullOrEmpty(Params.UserName) && role == AuthorRole.User && text.StartsWith($"{Params.UserName}:"))
{
text = text.Substring($"{Params.UserName}:".Length).TrimStart();
}
else if (!string.IsNullOrEmpty(Params.AssistantName) && role == AuthorRole.Assistant && text.EndsWith($"{Params.AssistantName}:"))
{
text = text.Substring(0, text.Length - $"{Params.AssistantName}:".Length).TrimEnd();
}
if (_executor is InstructExecutor && role == AuthorRole.Assistant && text.EndsWith("\n> "))
{
text = text.Substring(0, text.Length - "\n> ".Length).TrimEnd();
}
return text;
}
/// <summary>
/// Get the response from the LLama model with the given chat history.
/// </summary>
/// <param name="history"></param>
/// <param name="inferenceParams"></param>
/// <returns></returns>
public IEnumerable<string> Chat(ChatHistory history, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
{
var prompt = BuildTextFromHistory(history);
History.AddMessage(AuthorRole.User, prompt);
StringBuilder sb = new();
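// Stream each generated piece to the caller while accumulating it for the history.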
foreach (var result in _executor.Infer(prompt, inferenceParams, cancellationToken))
{
yield return result;
sb.Append(result);
}
History.AddMessage(AuthorRole.Assistant, CropNameFromText(sb.ToString(), AuthorRole.Assistant));
}
/// <summary>
/// Get the response from the LLama model. Note that the prompt is not limited to preset words;
/// it can also be the question you want to ask.
/// </summary>
/// <param name="prompt"></param>
/// <param name="inferenceParams"></param>
/// <returns></returns>
public IEnumerable<string> Chat(string prompt, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
{
History.AddMessage(AuthorRole.User, prompt);
StringBuilder sb = new();
foreach (var result in _executor.Infer(prompt, inferenceParams, cancellationToken))
{
yield return result;
sb.Append(result);
}
History.AddMessage(AuthorRole.Assistant, CropNameFromText(sb.ToString(), AuthorRole.Assistant));
}
/// <summary>
/// Get the response from the LLama model asynchronously with the given chat history.
/// </summary>
/// <param name="history"></param>
/// <param name="inferenceParams"></param>
/// <returns></returns>
public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var prompt = BuildTextFromHistory(history);
History.AddMessage(AuthorRole.User, prompt);
StringBuilder sb = new();
await foreach (var result in _executor.InferAsync(prompt, inferenceParams, cancellationToken))
{
yield return result;
sb.Append(result);
}
History.AddMessage(AuthorRole.Assistant, CropNameFromText(sb.ToString(), AuthorRole.Assistant));
}
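/// <summary>
/// Get the response from the LLama model asynchronously. Note that the prompt is not limited to
/// preset words; it can also be the question you want to ask.
/// </summary>
/// <param name="prompt"></param>
/// <param name="inferenceParams"></param>
/// <returns></returns>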
public async IAsyncEnumerable<string> ChatAsync(string prompt, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
History.AddMessage(AuthorRole.User, prompt);
StringBuilder sb = new();
await foreach (var result in _executor.InferAsync(prompt, inferenceParams, cancellationToken))
{
yield return result;
sb.Append(result);
}
History.AddMessage(AuthorRole.Assistant, CropNameFromText(sb.ToString(), AuthorRole.Assistant));
}
}
}