🔧 Refactor chat completion implementation

- Refactored the chat completion implementation in `LLamaSharpChatCompletion.cs` to use `StatelessExecutor` instead of `InteractiveExecutor`.
- Updated the initial chat prompt in the chat completion example (under `LLama.Examples`) to open with "This is a conversation between the assistant and the user." before the librarian instruction.
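- Modified the `HistoryTransform` class in `HistoryTransform.cs` to end the chat history prompt with an assistant prefix built from `AuthorRole.Assistant` (followed by `": "`) instead of the hard-coded `"\nAssistant:"` suffix.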
- Updated the constructor of `LLamaSharpChatCompletion` to accept optional parameters for `historyTransform` and `outputTransform`.
- Modified the `GetChatCompletionsAsync` and `GetChatCompletions` methods in `LLamaSharpChatCompletion.cs` to use the new `StatelessExecutor` and `outputTransform`.
- Updated the `ExtensionMethods.cs` file to include the assistant and system roles in the list of anti-prompts.
This commit is contained in:
xbotter 2023-12-01 21:39:31 +08:00
parent 884f5ade13
commit a2b26faa7a
No known key found for this signature in database
GPG Key ID: D299220A7FE5CF1E
4 changed files with 32 additions and 29 deletions

View File

@ -16,12 +16,11 @@ namespace LLama.Examples.Examples
// Load weights into memory
var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var ex = new InteractiveExecutor(context);
var ex = new StatelessExecutor(model, parameters);
var chatGPT = new LLamaSharpChatCompletion(ex);
var chatHistory = chatGPT.CreateNewChat("You are a librarian, expert about books");
var chatHistory = chatGPT.CreateNewChat("This is a conversation between the assistant and the user. \n\n You are a librarian, expert about books. ");
Console.WriteLine("Chat content:");
Console.WriteLine("------------------------");

View File

@ -1,4 +1,6 @@
using static LLama.LLamaTransforms;
using LLama.Common;
using System.Text;
using static LLama.LLamaTransforms;
namespace LLamaSharp.SemanticKernel.ChatCompletion;
@ -10,8 +12,6 @@ public class HistoryTransform : DefaultHistoryTransform
/// <inheritdoc/>
public override string HistoryToText(global::LLama.Common.ChatHistory history)
{
var prompt = base.HistoryToText(history);
return prompt + "\nAssistant:";
return base.HistoryToText(history) + $"{AuthorRole.Assistant}: ";
}
}

View File

@ -1,7 +1,9 @@
using LLama;
using LLama.Abstractions;
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.AI.ChatCompletion;
using System.Runtime.CompilerServices;
using static LLama.LLamaTransforms;
namespace LLamaSharp.SemanticKernel.ChatCompletion;
@ -10,10 +12,10 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion;
/// </summary>
public sealed class LLamaSharpChatCompletion : IChatCompletion
{
private const string UserRole = "user:";
private const string AssistantRole = "assistant:";
private ChatSession session;
private readonly StatelessExecutor _model;
private ChatRequestSettings defaultRequestSettings;
private readonly IHistoryTransform historyTransform;
private readonly ITextStreamTransform outputTransform;
private readonly Dictionary<string, string> _attributes = new();
@ -30,18 +32,17 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
};
}
public LLamaSharpChatCompletion(InteractiveExecutor model, ChatRequestSettings? defaultRequestSettings = default)
public LLamaSharpChatCompletion(StatelessExecutor model,
ChatRequestSettings? defaultRequestSettings = default,
IHistoryTransform? historyTransform = null,
ITextStreamTransform? outputTransform = null)
{
this.session = new ChatSession(model)
.WithHistoryTransform(new HistoryTransform())
.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { UserRole, AssistantRole }));
this.defaultRequestSettings = defaultRequestSettings ??= GetDefaultSettings();
}
public LLamaSharpChatCompletion(ChatSession session, ChatRequestSettings? defaultRequestSettings = default)
{
this.session = session;
this.defaultRequestSettings = defaultRequestSettings ??= GetDefaultSettings();
this._model = model;
this.defaultRequestSettings = defaultRequestSettings ?? GetDefaultSettings();
this.historyTransform = historyTransform ?? new HistoryTransform();
this.outputTransform = outputTransform ?? new KeywordTextOutputStreamTransform(new[] { $"{LLama.Common.AuthorRole.User}:",
$"{LLama.Common.AuthorRole.Assistant}:",
$"{LLama.Common.AuthorRole.System}:"});
}
/// <inheritdoc/>
@ -63,11 +64,11 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
var settings = requestSettings != null
? ChatRequestSettings.FromRequestSettings(requestSettings)
: defaultRequestSettings;
var prompt = historyTransform.HistoryToText(chat.ToLLamaSharpChatHistory());
// This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable.
var result = this.session.ChatAsync(chat.ToLLamaSharpChatHistory(), settings.ToLLamaSharpInferenceParams(), cancellationToken);
var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);
return Task.FromResult<IReadOnlyList<IChatResult>>(new List<IChatResult> { new LLamaSharpChatResult(result) }.AsReadOnly());
return Task.FromResult<IReadOnlyList<IChatResult>>(new List<IChatResult> { new LLamaSharpChatResult(outputTransform.TransformAsync(result)) }.AsReadOnly());
}
/// <inheritdoc/>
@ -78,10 +79,10 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
var settings = requestSettings != null
? ChatRequestSettings.FromRequestSettings(requestSettings)
: defaultRequestSettings;
var prompt = historyTransform.HistoryToText(chat.ToLLamaSharpChatHistory());
// This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable.
var result = this.session.ChatAsync(chat.ToLLamaSharpChatHistory(), settings.ToLLamaSharpInferenceParams(), cancellationToken);
var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);
yield return new LLamaSharpChatResult(result);
yield return new LLamaSharpChatResult(outputTransform.TransformAsync(result));
}
}

View File

@ -35,7 +35,10 @@ public static class ExtensionMethods
throw new ArgumentNullException(nameof(requestSettings));
}
var antiPrompts = new List<string>(requestSettings.StopSequences) { AuthorRole.User.ToString() + ":" };
var antiPrompts = new List<string>(requestSettings.StopSequences)
{ LLama.Common.AuthorRole.User.ToString() + ":" ,
LLama.Common.AuthorRole.Assistant.ToString() + ":",
LLama.Common.AuthorRole.System.ToString() + ":"};
return new global::LLama.Common.InferenceParams
{
Temperature = (float)requestSettings.Temperature,