diff --git a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
index 7fda3d4f..4571aec8 100644
--- a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
+++ b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
@@ -1,13 +1,6 @@
 using LLama;
 using Microsoft.SemanticKernel.AI.ChatCompletion;
-using System;
-using System.Collections.Generic;
-using System.IO;
-using System.Linq;
 using System.Runtime.CompilerServices;
-using System.Text;
-using System.Threading;
-using System.Threading.Tasks;
 
 namespace LLamaSharp.SemanticKernel.ChatCompletion;
 
@@ -19,12 +12,32 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
     private const string UserRole = "user:";
     private const string AssistantRole = "assistant:";
     private ChatSession session;
+    private ChatRequestSettings defaultRequestSettings;
 
-    public LLamaSharpChatCompletion(InteractiveExecutor model)
+    public LLamaSharpChatCompletion(InteractiveExecutor model, ChatRequestSettings? defaultRequestSettings = default)
     {
         this.session = new ChatSession(model)
             .WithHistoryTransform(new HistoryTransform())
             .WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { UserRole, AssistantRole }));
+        this.defaultRequestSettings = defaultRequestSettings ??= new ChatRequestSettings()
+        {
+            MaxTokens = 256,
+            Temperature = 0,
+            TopP = 0,
+            StopSequences = new List<string> { }
+        };
+    }
+
+    public LLamaSharpChatCompletion(ChatSession session, ChatRequestSettings? defaultRequestSettings = default)
+    {
+        this.session = session;
+        this.defaultRequestSettings = defaultRequestSettings ??= new ChatRequestSettings()
+        {
+            MaxTokens = 256,
+            Temperature = 0,
+            TopP = 0,
+            StopSequences = new List<string> { }
+        };
     }
 
     /// <inheritdoc/>
@@ -43,13 +56,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
     /// <inheritdoc/>
     public async Task<IReadOnlyList<IChatResult>> GetChatCompletionsAsync(ChatHistory chat, ChatRequestSettings? requestSettings = null, CancellationToken cancellationToken = default)
     {
-        requestSettings ??= new ChatRequestSettings()
-        {
-            MaxTokens = 256,
-            Temperature = 0,
-            TopP = 0,
-            StopSequences = new List<string> { }
-        };
+        requestSettings = requestSettings ?? this.defaultRequestSettings;
 
         var result = this.session.ChatAsync(chat.ToLLamaSharpChatHistory(), requestSettings.ToLLamaSharpInferenceParams(), cancellationToken);
 
@@ -59,13 +66,7 @@ public sealed class LLamaSharpChatCompletion : IChatCompletion
     /// <inheritdoc/>
    public async IAsyncEnumerable<IChatStreamingResult> GetStreamingChatCompletionsAsync(ChatHistory chat, ChatRequestSettings? requestSettings = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
-        requestSettings ??= new ChatRequestSettings()
-        {
-            MaxTokens = 256,
-            Temperature = 0,
-            TopP = 0,
-            StopSequences = new List<string> { }
-        };
+        requestSettings = requestSettings ?? this.defaultRequestSettings;
 
         var result = this.session.ChatAsync(chat.ToLLamaSharpChatHistory(), requestSettings.ToLLamaSharpInferenceParams(), cancellationToken);
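
For context, a minimal usage sketch of the new optional `defaultRequestSettings` constructor parameter. This is not part of the diff; it assumes the LLamaSharp 0.x loading API (`ModelParams`, `LLamaWeights`, `InteractiveExecutor`) and the pre-1.0 Semantic Kernel `ChatRequestSettings` type, and the model path and parameter values are placeholders.

```csharp
using System.Collections.Generic;
using LLama;
using LLama.Common;
using LLamaSharp.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.AI.ChatCompletion;

// Load a model and create an interactive executor (path is a placeholder).
var parameters = new ModelParams("path/to/model.gguf");
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

// Settings passed here become the fallback used whenever a call site
// does not supply its own ChatRequestSettings.
var chatCompletion = new LLamaSharpChatCompletion(executor, new ChatRequestSettings
{
    MaxTokens = 512,
    Temperature = 0.7,
    StopSequences = new List<string> { "user:" }
});

var history = chatCompletion.CreateNewChat("You are a helpful assistant.");
history.AddUserMessage("Hello!");

// No per-call settings supplied, so the constructor-provided defaults apply.
var results = await chatCompletion.GetChatCompletionsAsync(history);
```

The second new constructor accepts a preconfigured `ChatSession` directly, which lets callers keep custom history and output transforms while still getting the same default-settings fallback.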