Update the StatefulChatService to use new ChatSession integration

This commit is contained in:
Philipp Bauer 2023-12-10 09:34:32 -06:00
parent f669a4f5a7
commit 29c5c6e93c
1 changed files with 19 additions and 16 deletions

View File

@ -11,8 +11,7 @@ public class StatefulChatService : IDisposable
private readonly LLamaContext _context; private readonly LLamaContext _context;
private bool _continue = false; private bool _continue = false;
private const string SystemPrompt = "Transcript of a dialog, where the User interacts with an Assistant. Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.\n\n" private const string SystemPrompt = "Transcript of a dialog, where the User interacts with an Assistant. Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.";
+ "User: ";
public StatefulChatService(IConfiguration configuration) public StatefulChatService(IConfiguration configuration)
{ {
@ -25,7 +24,9 @@ public class StatefulChatService : IDisposable
using var weights = LLamaWeights.LoadFromFile(@params); using var weights = LLamaWeights.LoadFromFile(@params);
_context = new LLamaContext(weights, @params); _context = new LLamaContext(weights, @params);
_session = new ChatSession(new InteractiveExecutor(_context)); _session = new ChatSession(new InteractiveExecutor(_context));
_session.History.AddMessage(Common.AuthorRole.System, SystemPrompt);
} }
public void Dispose() public void Dispose()
@ -35,10 +36,8 @@ public class StatefulChatService : IDisposable
public async Task<string> Send(SendMessageInput input) public async Task<string> Send(SendMessageInput input)
{ {
var userInput = input.Text;
if (!_continue) if (!_continue)
{ {
userInput = SystemPrompt + userInput;
Console.Write(SystemPrompt); Console.Write(SystemPrompt);
_continue = true; _continue = true;
} }
@ -47,11 +46,14 @@ public class StatefulChatService : IDisposable
Console.Write(input.Text); Console.Write(input.Text);
Console.ForegroundColor = ConsoleColor.White; Console.ForegroundColor = ConsoleColor.White;
var outputs = _session.ChatAsync(userInput, new Common.InferenceParams() var outputs = _session.ChatAsync(
new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text),
new Common.InferenceParams()
{ {
RepeatPenalty = 1.0f, RepeatPenalty = 1.0f,
AntiPrompts = new string[] { "User:" }, AntiPrompts = new string[] { "User:" },
}); });
var result = ""; var result = "";
await foreach (var output in outputs) await foreach (var output in outputs)
{ {
@ -64,10 +66,8 @@ public class StatefulChatService : IDisposable
public async IAsyncEnumerable<string> SendStream(SendMessageInput input) public async IAsyncEnumerable<string> SendStream(SendMessageInput input)
{ {
var userInput = input.Text;
if (!_continue) if (!_continue)
{ {
userInput = SystemPrompt + userInput;
Console.Write(SystemPrompt); Console.Write(SystemPrompt);
_continue = true; _continue = true;
} }
@ -76,11 +76,14 @@ public class StatefulChatService : IDisposable
Console.Write(input.Text); Console.Write(input.Text);
Console.ForegroundColor = ConsoleColor.White; Console.ForegroundColor = ConsoleColor.White;
var outputs = _session.ChatAsync(userInput, new Common.InferenceParams() var outputs = _session.ChatAsync(
new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text)
, new Common.InferenceParams()
{ {
RepeatPenalty = 1.0f, RepeatPenalty = 1.0f,
AntiPrompts = new string[] { "User:" }, AntiPrompts = new string[] { "User:" },
}); });
await foreach (var output in outputs) await foreach (var output in outputs)
{ {
Console.Write(output); Console.Write(output);