diff --git a/LLama.Examples/Examples/SemanticKernelChat.cs b/LLama.Examples/Examples/SemanticKernelChat.cs
index da5eb5ce..39870f1b 100644
--- a/LLama.Examples/Examples/SemanticKernelChat.cs
+++ b/LLama.Examples/Examples/SemanticKernelChat.cs
@@ -16,12 +16,11 @@ namespace LLama.Examples.Examples
// Load weights into memory
var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);
- using var context = model.CreateContext(parameters);
- var ex = new InteractiveExecutor(context);
+ var ex = new StatelessExecutor(model, parameters);
var chatGPT = new LLamaSharpChatCompletion(ex);
- var chatHistory = chatGPT.CreateNewChat("You are a librarian, expert about books");
+ var chatHistory = chatGPT.CreateNewChat("This is a conversation between the assistant and the user. \n\n You are a librarian, expert about books. ");
Console.WriteLine("Chat content:");
Console.WriteLine("------------------------");
diff --git a/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs b/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs
index 759888d0..f1a0ebcb 100644
--- a/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs
+++ b/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs
@@ -1,4 +1,6 @@
-using static LLama.LLamaTransforms;
+using LLama.Common;
+using System.Text;
+using static LLama.LLamaTransforms;
namespace LLamaSharp.SemanticKernel.ChatCompletion;
@@ -10,8 +12,6 @@ public class HistoryTransform : DefaultHistoryTransform
///
public override string HistoryToText(global::LLama.Common.ChatHistory history)
{
- var prompt = base.HistoryToText(history);
- return prompt + "\nAssistant:";
-
+ return base.HistoryToText(history) + $"{AuthorRole.Assistant}: ";
}
}
diff --git a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
index 4fcb5baa..7e5425bb 100644
--- a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
+++ b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
@@ -1,7 +1,9 @@
using LLama;
+using LLama.Abstractions;
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.AI.ChatCompletion;
using System.Runtime.CompilerServices;
+using static LLama.LLamaTransforms;
namespace LLamaSharp.SemanticKernel.ChatCompletion;
@@ -10,10 +12,10 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion;
///
public sealed class LLamaSharpChatCompletion : IChatCompletion
{
- private const string UserRole = "user:";
- private const string AssistantRole = "assistant:";
- private ChatSession session;
+ private readonly StatelessExecutor _model;
private ChatRequestSettings defaultRequestSettings;
+ private readonly IHistoryTransform historyTransform;
+ private readonly ITextStreamTransform outputTransform;
private readonly Dictionary _attributes = new();
@@ -30,18 +32,17 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
};
}
- public LLamaSharpChatCompletion(InteractiveExecutor model, ChatRequestSettings? defaultRequestSettings = default)
+ public LLamaSharpChatCompletion(StatelessExecutor model,
+ ChatRequestSettings? defaultRequestSettings = default,
+ IHistoryTransform? historyTransform = null,
+ ITextStreamTransform? outputTransform = null)
{
- this.session = new ChatSession(model)
- .WithHistoryTransform(new HistoryTransform())
- .WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { UserRole, AssistantRole }));
- this.defaultRequestSettings = defaultRequestSettings ??= GetDefaultSettings();
- }
-
- public LLamaSharpChatCompletion(ChatSession session, ChatRequestSettings? defaultRequestSettings = default)
- {
- this.session = session;
- this.defaultRequestSettings = defaultRequestSettings ??= GetDefaultSettings();
+ this._model = model;
+ this.defaultRequestSettings = defaultRequestSettings ?? GetDefaultSettings();
+ this.historyTransform = historyTransform ?? new HistoryTransform();
+ this.outputTransform = outputTransform ?? new KeywordTextOutputStreamTransform(new[] { $"{LLama.Common.AuthorRole.User}:",
+ $"{LLama.Common.AuthorRole.Assistant}:",
+ $"{LLama.Common.AuthorRole.System}:"});
}
///
@@ -60,14 +61,14 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
///
public Task> GetChatCompletionsAsync(ChatHistory chat, AIRequestSettings? requestSettings = null, CancellationToken cancellationToken = default)
{
- var settings = requestSettings != null
+ var settings = requestSettings != null
? ChatRequestSettings.FromRequestSettings(requestSettings)
: defaultRequestSettings;
+ var prompt = historyTransform.HistoryToText(chat.ToLLamaSharpChatHistory());
- // This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable.
- var result = this.session.ChatAsync(chat.ToLLamaSharpChatHistory(), settings.ToLLamaSharpInferenceParams(), cancellationToken);
+ var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);
- return Task.FromResult>(new List { new LLamaSharpChatResult(result) }.AsReadOnly());
+ return Task.FromResult>(new List { new LLamaSharpChatResult(outputTransform.TransformAsync(result)) }.AsReadOnly());
}
///
@@ -78,10 +79,10 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
var settings = requestSettings != null
? ChatRequestSettings.FromRequestSettings(requestSettings)
: defaultRequestSettings;
-
+ var prompt = historyTransform.HistoryToText(chat.ToLLamaSharpChatHistory());
// This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable.
- var result = this.session.ChatAsync(chat.ToLLamaSharpChatHistory(), settings.ToLLamaSharpInferenceParams(), cancellationToken);
+ var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);
- yield return new LLamaSharpChatResult(result);
+ yield return new LLamaSharpChatResult(outputTransform.TransformAsync(result));
}
}
diff --git a/LLama.SemanticKernel/ExtensionMethods.cs b/LLama.SemanticKernel/ExtensionMethods.cs
index b3ff6a7b..6f39e373 100644
--- a/LLama.SemanticKernel/ExtensionMethods.cs
+++ b/LLama.SemanticKernel/ExtensionMethods.cs
@@ -35,7 +35,10 @@ public static class ExtensionMethods
throw new ArgumentNullException(nameof(requestSettings));
}
- var antiPrompts = new List(requestSettings.StopSequences) { AuthorRole.User.ToString() + ":" };
+ var antiPrompts = new List(requestSettings.StopSequences)
+ { LLama.Common.AuthorRole.User.ToString() + ":" ,
+ LLama.Common.AuthorRole.Assistant.ToString() + ":",
+ LLama.Common.AuthorRole.System.ToString() + ":"};
return new global::LLama.Common.InferenceParams
{
Temperature = (float)requestSettings.Temperature,