Merge pull request #341 from xbotter/sk/chat-stateless
🔧 Refactor Semantic Kernel chat completion implementation
This commit is contained in:
commit
3019da9d22
|
@ -16,12 +16,11 @@ namespace LLama.Examples.Examples
|
|||
// Load weights into memory
|
||||
var parameters = new ModelParams(modelPath);
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var context = model.CreateContext(parameters);
|
||||
var ex = new InteractiveExecutor(context);
|
||||
var ex = new StatelessExecutor(model, parameters);
|
||||
|
||||
var chatGPT = new LLamaSharpChatCompletion(ex);
|
||||
|
||||
var chatHistory = chatGPT.CreateNewChat("You are a librarian, expert about books");
|
||||
var chatHistory = chatGPT.CreateNewChat("This is a conversation between the assistant and the user. \n\n You are a librarian, expert about books. ");
|
||||
|
||||
Console.WriteLine("Chat content:");
|
||||
Console.WriteLine("------------------------");
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
using static LLama.LLamaTransforms;
|
||||
using LLama.Common;
|
||||
using System.Text;
|
||||
using static LLama.LLamaTransforms;
|
||||
|
||||
namespace LLamaSharp.SemanticKernel.ChatCompletion;
|
||||
|
||||
|
@ -10,8 +12,6 @@ public class HistoryTransform : DefaultHistoryTransform
|
|||
/// <inheritdoc/>
|
||||
public override string HistoryToText(global::LLama.Common.ChatHistory history)
|
||||
{
|
||||
var prompt = base.HistoryToText(history);
|
||||
return prompt + "\nAssistant:";
|
||||
|
||||
return base.HistoryToText(history) + $"{AuthorRole.Assistant}: ";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
using LLama;
|
||||
using LLama.Abstractions;
|
||||
using Microsoft.SemanticKernel.AI;
|
||||
using Microsoft.SemanticKernel.AI.ChatCompletion;
|
||||
using System.Runtime.CompilerServices;
|
||||
using static LLama.LLamaTransforms;
|
||||
|
||||
namespace LLamaSharp.SemanticKernel.ChatCompletion;
|
||||
|
||||
|
@ -10,10 +12,10 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion;
|
|||
/// </summary>
|
||||
public sealed class LLamaSharpChatCompletion : IChatCompletion
|
||||
{
|
||||
private const string UserRole = "user:";
|
||||
private const string AssistantRole = "assistant:";
|
||||
private ChatSession session;
|
||||
private readonly StatelessExecutor _model;
|
||||
private ChatRequestSettings defaultRequestSettings;
|
||||
private readonly IHistoryTransform historyTransform;
|
||||
private readonly ITextStreamTransform outputTransform;
|
||||
|
||||
private readonly Dictionary<string, string> _attributes = new();
|
||||
|
||||
|
@ -30,18 +32,17 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
|
|||
};
|
||||
}
|
||||
|
||||
public LLamaSharpChatCompletion(InteractiveExecutor model, ChatRequestSettings? defaultRequestSettings = default)
|
||||
public LLamaSharpChatCompletion(StatelessExecutor model,
|
||||
ChatRequestSettings? defaultRequestSettings = default,
|
||||
IHistoryTransform? historyTransform = null,
|
||||
ITextStreamTransform? outputTransform = null)
|
||||
{
|
||||
this.session = new ChatSession(model)
|
||||
.WithHistoryTransform(new HistoryTransform())
|
||||
.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { UserRole, AssistantRole }));
|
||||
this.defaultRequestSettings = defaultRequestSettings ??= GetDefaultSettings();
|
||||
}
|
||||
|
||||
public LLamaSharpChatCompletion(ChatSession session, ChatRequestSettings? defaultRequestSettings = default)
|
||||
{
|
||||
this.session = session;
|
||||
this.defaultRequestSettings = defaultRequestSettings ??= GetDefaultSettings();
|
||||
this._model = model;
|
||||
this.defaultRequestSettings = defaultRequestSettings ?? GetDefaultSettings();
|
||||
this.historyTransform = historyTransform ?? new HistoryTransform();
|
||||
this.outputTransform = outputTransform ?? new KeywordTextOutputStreamTransform(new[] { $"{LLama.Common.AuthorRole.User}:",
|
||||
$"{LLama.Common.AuthorRole.Assistant}:",
|
||||
$"{LLama.Common.AuthorRole.System}:"});
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
|
@ -60,14 +61,14 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
|
|||
/// <inheritdoc/>
|
||||
public Task<IReadOnlyList<IChatResult>> GetChatCompletionsAsync(ChatHistory chat, AIRequestSettings? requestSettings = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var settings = requestSettings != null
|
||||
var settings = requestSettings != null
|
||||
? ChatRequestSettings.FromRequestSettings(requestSettings)
|
||||
: defaultRequestSettings;
|
||||
var prompt = historyTransform.HistoryToText(chat.ToLLamaSharpChatHistory());
|
||||
|
||||
// This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable.
|
||||
var result = this.session.ChatAsync(chat.ToLLamaSharpChatHistory(), settings.ToLLamaSharpInferenceParams(), cancellationToken);
|
||||
var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);
|
||||
|
||||
return Task.FromResult<IReadOnlyList<IChatResult>>(new List<IChatResult> { new LLamaSharpChatResult(result) }.AsReadOnly());
|
||||
return Task.FromResult<IReadOnlyList<IChatResult>>(new List<IChatResult> { new LLamaSharpChatResult(outputTransform.TransformAsync(result)) }.AsReadOnly());
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
|
@ -78,10 +79,10 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
|
|||
var settings = requestSettings != null
|
||||
? ChatRequestSettings.FromRequestSettings(requestSettings)
|
||||
: defaultRequestSettings;
|
||||
|
||||
var prompt = historyTransform.HistoryToText(chat.ToLLamaSharpChatHistory());
|
||||
// This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable.
|
||||
var result = this.session.ChatAsync(chat.ToLLamaSharpChatHistory(), settings.ToLLamaSharpInferenceParams(), cancellationToken);
|
||||
var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);
|
||||
|
||||
yield return new LLamaSharpChatResult(result);
|
||||
yield return new LLamaSharpChatResult(outputTransform.TransformAsync(result));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,7 +35,10 @@ public static class ExtensionMethods
|
|||
throw new ArgumentNullException(nameof(requestSettings));
|
||||
}
|
||||
|
||||
var antiPrompts = new List<string>(requestSettings.StopSequences) { AuthorRole.User.ToString() + ":" };
|
||||
var antiPrompts = new List<string>(requestSettings.StopSequences)
|
||||
{ LLama.Common.AuthorRole.User.ToString() + ":" ,
|
||||
LLama.Common.AuthorRole.Assistant.ToString() + ":",
|
||||
LLama.Common.AuthorRole.System.ToString() + ":"};
|
||||
return new global::LLama.Common.InferenceParams
|
||||
{
|
||||
Temperature = (float)requestSettings.Temperature,
|
||||
|
|
Loading…
Reference in New Issue