diff --git a/LLama.WebAPI/Services/StatefulChatService.cs b/LLama.WebAPI/Services/StatefulChatService.cs index f1eb3538..f45c98ee 100644 --- a/LLama.WebAPI/Services/StatefulChatService.cs +++ b/LLama.WebAPI/Services/StatefulChatService.cs @@ -11,8 +11,7 @@ public class StatefulChatService : IDisposable private readonly LLamaContext _context; private bool _continue = false; - private const string SystemPrompt = "Transcript of a dialog, where the User interacts with an Assistant. Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.\n\n" - + "User: "; + private const string SystemPrompt = "Transcript of a dialog, where the User interacts with an Assistant. Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision."; public StatefulChatService(IConfiguration configuration) { @@ -25,7 +24,9 @@ public class StatefulChatService : IDisposable using var weights = LLamaWeights.LoadFromFile(@params); _context = new LLamaContext(weights, @params); + _session = new ChatSession(new InteractiveExecutor(_context)); + _session.History.AddMessage(Common.AuthorRole.System, SystemPrompt); } public void Dispose() @@ -35,10 +36,8 @@ public class StatefulChatService : IDisposable public async Task Send(SendMessageInput input) { - var userInput = input.Text; if (!_continue) { - userInput = SystemPrompt + userInput; Console.Write(SystemPrompt); _continue = true; } @@ -47,11 +46,14 @@ public class StatefulChatService : IDisposable Console.Write(input.Text); Console.ForegroundColor = ConsoleColor.White; - var outputs = _session.ChatAsync(userInput, new Common.InferenceParams() - { - RepeatPenalty = 1.0f, - AntiPrompts = new string[] { "User:" }, - }); + var outputs = _session.ChatAsync( + new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text), + new Common.InferenceParams() + { + RepeatPenalty = 1.0f, + AntiPrompts = new string[] { "User:" }, + }); + var result = ""; await foreach (var output in outputs) { @@ -64,10 +66,8 @@ public class StatefulChatService : IDisposable public async IAsyncEnumerable SendStream(SendMessageInput input) { - var userInput = input.Text; if (!_continue) { - userInput = SystemPrompt + userInput; Console.Write(SystemPrompt); _continue = true; } @@ -76,11 +76,14 @@ public class StatefulChatService : IDisposable Console.Write(input.Text); Console.ForegroundColor = ConsoleColor.White; - var outputs = _session.ChatAsync(userInput, new Common.InferenceParams() - { - RepeatPenalty = 1.0f, - AntiPrompts = new string[] { "User:" }, - }); + var outputs = _session.ChatAsync( + new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text) + , new Common.InferenceParams() + { + RepeatPenalty = 1.0f, + AntiPrompts = new string[] { "User:" }, + }); + await foreach (var output in outputs) { Console.Write(output);