diff --git a/LLama.Examples/Examples/SemanticKernelChat.cs b/LLama.Examples/Examples/SemanticKernelChat.cs
index da5eb5ce..39870f1b 100644
--- a/LLama.Examples/Examples/SemanticKernelChat.cs
+++ b/LLama.Examples/Examples/SemanticKernelChat.cs
@@ -16,12 +16,11 @@ namespace LLama.Examples.Examples
             // Load weights into memory
             var parameters = new ModelParams(modelPath);
             using var model = LLamaWeights.LoadFromFile(parameters);
-            using var context = model.CreateContext(parameters);
-            var ex = new InteractiveExecutor(context);
+            var ex = new StatelessExecutor(model, parameters);
 
             var chatGPT = new LLamaSharpChatCompletion(ex);
 
-            var chatHistory = chatGPT.CreateNewChat("You are a librarian, expert about books");
+            var chatHistory = chatGPT.CreateNewChat("This is a conversation between the assistant and the user. \n\n You are a librarian, expert about books. ");
 
             Console.WriteLine("Chat content:");
             Console.WriteLine("------------------------");
diff --git a/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs b/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs
index 759888d0..f1a0ebcb 100644
--- a/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs
+++ b/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs
@@ -1,4 +1,6 @@
-using static LLama.LLamaTransforms;
+using LLama.Common;
+using System.Text;
+using static LLama.LLamaTransforms;
 
 namespace LLamaSharp.SemanticKernel.ChatCompletion;
 
@@ -10,8 +12,6 @@ public class HistoryTransform : DefaultHistoryTransform
     /// <inheritdoc/>
    public override string HistoryToText(global::LLama.Common.ChatHistory history)
    {
-        var prompt = base.HistoryToText(history);
-        return prompt + "\nAssistant:";
-
+        return base.HistoryToText(history) + $"{AuthorRole.Assistant}: ";
    }
 }
diff --git a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
index 4fcb5baa..7e5425bb 100644
--- a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
+++ b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
@@ -1,7 +1,9 @@
 using LLama;
+using LLama.Abstractions;
 using Microsoft.SemanticKernel.AI;
 using Microsoft.SemanticKernel.AI.ChatCompletion;
 using System.Runtime.CompilerServices;
+using static LLama.LLamaTransforms;
 
 namespace LLamaSharp.SemanticKernel.ChatCompletion;
 
@@ -10,10 +12,10 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion;
 /// </summary>
 public sealed class LLamaSharpChatCompletion : IChatCompletion
 {
-    private const string UserRole = "user:";
-    private const string AssistantRole = "assistant:";
-    private ChatSession session;
+    private readonly StatelessExecutor _model;
     private ChatRequestSettings defaultRequestSettings;
+    private readonly IHistoryTransform historyTransform;
+    private readonly ITextStreamTransform outputTransform;
 
     private readonly Dictionary<string, string> _attributes = new();
 
@@ -30,18 +32,17 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
         };
     }
 
-    public LLamaSharpChatCompletion(InteractiveExecutor model, ChatRequestSettings? defaultRequestSettings = default)
+    public LLamaSharpChatCompletion(StatelessExecutor model,
+        ChatRequestSettings? defaultRequestSettings = default,
+        IHistoryTransform? historyTransform = null,
+        ITextStreamTransform? outputTransform = null)
     {
-        this.session = new ChatSession(model)
-            .WithHistoryTransform(new HistoryTransform())
-            .WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { UserRole, AssistantRole }));
-        this.defaultRequestSettings = defaultRequestSettings ??= GetDefaultSettings();
-    }
-
-    public LLamaSharpChatCompletion(ChatSession session, ChatRequestSettings? defaultRequestSettings = default)
-    {
-        this.session = session;
-        this.defaultRequestSettings = defaultRequestSettings ??= GetDefaultSettings();
+        this._model = model;
+        this.defaultRequestSettings = defaultRequestSettings ?? GetDefaultSettings();
+        this.historyTransform = historyTransform ?? new HistoryTransform();
+        this.outputTransform = outputTransform ?? new KeywordTextOutputStreamTransform(new[] { $"{LLama.Common.AuthorRole.User}:",
+            $"{LLama.Common.AuthorRole.Assistant}:",
+            $"{LLama.Common.AuthorRole.System}:"});
     }
 
     /// <summary>
@@ -60,14 +61,14 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
     /// </summary>
     public Task<IReadOnlyList<IChatResult>> GetChatCompletionsAsync(ChatHistory chat, AIRequestSettings? requestSettings = null, CancellationToken cancellationToken = default)
     {
-        var settings = requestSettings != null 
+        var settings = requestSettings != null
             ? ChatRequestSettings.FromRequestSettings(requestSettings)
             : defaultRequestSettings;
+        var prompt = historyTransform.HistoryToText(chat.ToLLamaSharpChatHistory());
 
-        // This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable.
-        var result = this.session.ChatAsync(chat.ToLLamaSharpChatHistory(), settings.ToLLamaSharpInferenceParams(), cancellationToken);
+        var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);
 
-        return Task.FromResult<IReadOnlyList<IChatResult>>(new List<IChatResult> { new LLamaSharpChatResult(result) }.AsReadOnly());
+        return Task.FromResult<IReadOnlyList<IChatResult>>(new List<IChatResult> { new LLamaSharpChatResult(outputTransform.TransformAsync(result)) }.AsReadOnly());
     }
 
     /// <summary>
@@ -78,10 +79,10 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
         var settings = requestSettings != null
             ? ChatRequestSettings.FromRequestSettings(requestSettings)
             : defaultRequestSettings;
-
+        var prompt = historyTransform.HistoryToText(chat.ToLLamaSharpChatHistory());
         // This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable.
-        var result = this.session.ChatAsync(chat.ToLLamaSharpChatHistory(), settings.ToLLamaSharpInferenceParams(), cancellationToken);
+        var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);
 
-        yield return new LLamaSharpChatResult(result);
+        yield return new LLamaSharpChatResult(outputTransform.TransformAsync(result));
     }
 }
diff --git a/LLama.SemanticKernel/ExtensionMethods.cs b/LLama.SemanticKernel/ExtensionMethods.cs
index b3ff6a7b..6f39e373 100644
--- a/LLama.SemanticKernel/ExtensionMethods.cs
+++ b/LLama.SemanticKernel/ExtensionMethods.cs
@@ -35,7 +35,10 @@ public static class ExtensionMethods
             throw new ArgumentNullException(nameof(requestSettings));
         }
 
-        var antiPrompts = new List<string>(requestSettings.StopSequences) { AuthorRole.User.ToString() + ":" };
+        var antiPrompts = new List<string>(requestSettings.StopSequences)
+            { LLama.Common.AuthorRole.User.ToString() + ":" ,
+              LLama.Common.AuthorRole.Assistant.ToString() + ":",
+              LLama.Common.AuthorRole.System.ToString() + ":"};
         return new global::LLama.Common.InferenceParams
         {
             Temperature = (float)requestSettings.Temperature,
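For context, here is a minimal consumer-side sketch of the reworked API, pieced together from the example and the constructor shown in this diff. The model path is a placeholder, the explicit transform arguments simply restate the defaults the new constructor falls back to when they are null, and the result-reading calls assume the `IChatResult.GetChatMessageAsync()` shape of the Semantic Kernel packages referenced here.

```csharp
using System;
using LLama;
using LLama.Common;
using LLamaSharp.SemanticKernel.ChatCompletion;
using static LLama.LLamaTransforms;

// Placeholder path - point this at a real GGUF model file.
var modelPath = @"path/to/model.gguf";

// Load the weights once; the StatelessExecutor spins up a fresh context
// per call, so no long-lived LLamaContext is created here.
var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);
var executor = new StatelessExecutor(model, parameters);

// Both transforms are optional; passing them explicitly just mirrors the
// defaults the new constructor uses when these arguments are null.
var chat = new LLamaSharpChatCompletion(
    executor,
    historyTransform: new HistoryTransform(),
    outputTransform: new KeywordTextOutputStreamTransform(
        new[]
        {
            $"{AuthorRole.User}:",
            $"{AuthorRole.Assistant}:",
            $"{AuthorRole.System}:",
        }));

var history = chat.CreateNewChat("This is a conversation between the assistant and the user. \n\n You are a librarian, expert about books. ");
history.AddUserMessage("Hi, I'm looking for book suggestions.");

// Assumes IChatResult.GetChatMessageAsync() from the referenced Semantic
// Kernel version; the keyword-filtered output is exposed via Content.
var results = await chat.GetChatCompletionsAsync(history);
var reply = await results[0].GetChatMessageAsync();
Console.WriteLine(reply.Content);
```

Note the design shift: with `ChatSession` replaced by `StatelessExecutor`, every completion call rebuilds the full prompt from the Semantic Kernel `ChatHistory` via `IHistoryTransform`, so the executor itself carries no conversation state between calls.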