Merge pull request #39 from xbotter/webapi-example

Update the WebAPI example
This commit is contained in:
Rinne 2023-07-03 23:33:49 +08:00 committed by GitHub
commit a53ede191e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 175 additions and 41 deletions

View File

@ -1,3 +1,4 @@
using LLama.Common;
using LLama.WebAPI.Models;
using LLama.WebAPI.Services;
using Microsoft.AspNetCore.Mvc;
@ -9,20 +10,44 @@ namespace LLama.WebAPI.Controllers
[Route("[controller]")]
// MVC controller exposing the chat endpoints of the WebAPI example.
// NOTE(review): this is a diff view — removed (old) and added (new) lines are
// interleaved below; the old ChatService field/constructor was replaced by
// per-action [FromServices] injection of the new services.
public class ChatController : ControllerBase
{
// Removed: the controller no longer holds a chat service field.
private readonly ChatService _service;
private readonly ILogger<ChatController> _logger;
// Removed: old constructor that took the singleton ChatService.
public ChatController(ILogger<ChatController> logger,
ChatService service)
// Added: only the logger is constructor-injected now.
public ChatController(ILogger<ChatController> logger)
{
_logger = logger;
_service = service;
}
// POST /Chat/Send — blocking single reply from the stateful (singleton) service.
[HttpPost("Send")]
public string SendMessage([FromBody] SendMessageInput input)
public string SendMessage([FromBody] SendMessageInput input, [FromServices] StatefulChatService _service)
{
return _service.Send(input);
}
// POST /Chat/Send/Stream — streams the reply as server-sent events ("data:" frames).
// NOTE(review): cancellationToken is honored by the response writes but is not
// passed to SendStream itself — confirm whether inference should be cancelable too.
[HttpPost("Send/Stream")]
public async Task SendMessageStream([FromBody] SendMessageInput input, [FromServices] StatefulChatService _service, CancellationToken cancellationToken)
{
Response.ContentType = "text/event-stream";
await foreach (var r in _service.SendStream(input))
{
await Response.WriteAsync("data:" + r + "\n\n", cancellationToken);
await Response.Body.FlushAsync(cancellationToken);
}
await Response.CompleteAsync();
}
// POST /Chat/History — stateless chat: the caller supplies the whole
// conversation; roles are parsed into LLama's AuthorRole enum.
[HttpPost("History")]
public async Task<string> SendHistory([FromBody] HistoryInput input, [FromServices] StatelessChatService _service)
{
var history = new ChatHistory();
var messages = input.Messages.Select(m => new ChatHistory.Message(Enum.Parse<AuthorRole>(m.Role), m.Content));
history.Messages.AddRange(messages);
return await _service.SendAsync(history);
}
}
}

View File

@ -7,6 +7,7 @@
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.VisualStudio.Validation" Version="17.6.11" />
<PackageReference Include="Swashbuckle.AspNetCore" Version="6.2.3" />
</ItemGroup>

View File

@ -4,3 +4,13 @@ public class SendMessageInput
{
public string Text { get; set; }
}
// Request body for the stateless /Chat/History endpoint: the complete
// conversation is sent with every request.
public class HistoryInput
{
// Ordered conversation messages.
public List<HistoryItem> Messages { get; set; }
public class HistoryItem
{
// Author of the message; must parse as LLama's AuthorRole enum
// (the controller calls Enum.Parse<AuthorRole>(Role)).
public string Role { get; set; }
// Message text.
public string Content { get; set; }
}
}

View File

@ -9,7 +9,8 @@ builder.Services.AddControllers();
// Swagger/OpenAPI support for the example endpoints.
builder.Services.AddEndpointsApiExplorer();
builder.Services.AddSwaggerGen();
// Removed: the old single ChatService registration.
builder.Services.AddSingleton<ChatService>();
// Added: the stateful service is a singleton (it owns the loaded model and the
// ongoing session); the stateless service is scoped per request.
builder.Services.AddSingleton<StatefulChatService>();
builder.Services.AddScoped<StatelessChatService>();
var app = builder.Build();

View File

@ -1,34 +0,0 @@
using LLama.OldVersion;
using LLama.WebAPI.Models;
namespace LLama.WebAPI.Services;
// Legacy chat service built on the LLama.OldVersion API (this file is deleted
// by the commit). Loads a local quantized model once and keeps a single
// interactive session for the whole application lifetime.
public class ChatService
{
    private readonly ChatSession<LLamaModel> _session;

    public ChatService()
    {
        // Model and session are wired once, at construction time.
        var parameters = new LLamaParams(model: @"ggml-model-q4_0.bin", n_ctx: 512, interactive: true, repeat_penalty: 1.0f, verbose_prompt: false);
        LLamaModel model = new(parameters);
        _session = new ChatSession<LLamaModel>(model)
            .WithPromptFile(@"Assets\chat-with-bob.txt")
            .WithAntiprompt(new[] { "User:" });
    }

    // Sends one message through the session and returns the concatenated reply.
    public string Send(SendMessageInput input)
    {
        // Echo the user's prompt in green, then switch back for model output.
        Console.ForegroundColor = ConsoleColor.Green;
        Console.WriteLine(input.Text);
        Console.ForegroundColor = ConsoleColor.White;

        var reply = string.Empty;
        foreach (var token in _session.Chat(input.Text))
        {
            Console.Write(token);
            reply += token;
        }
        return reply;
    }
}

View File

@ -0,0 +1,82 @@

using LLama.WebAPI.Models;
using Microsoft;
using System.Runtime.CompilerServices;
namespace LLama.WebAPI.Services;
/// <summary>
/// Chat service that keeps one model and one chat session alive for the whole
/// application (registered as a singleton in Program.cs). The very first
/// message is prefixed with a system prompt; later messages continue the
/// same session.
/// </summary>
public class StatefulChatService : IDisposable
{
    private readonly ChatSession _session;
    private readonly LLamaModel _model;
    // True once the first message (with the system prompt) has been sent.
    private bool _continue = false;

    private const string SystemPrompt = "Transcript of a dialog, where the User interacts with an Assistant. Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.\n\n"
        + "User: ";

    public StatefulChatService(IConfiguration configuration)
    {
        _model = new LLamaModel(new Common.ModelParams(configuration["ModelPath"], contextSize: 512));
        _session = new ChatSession(new InteractiveExecutor(_model));
    }

    /// <summary>Releases the native model resources.</summary>
    public void Dispose()
    {
        _model?.Dispose();
        // CA1816: no finalizer should run after explicit disposal.
        GC.SuppressFinalize(this);
    }

    /// <summary>Sends one message and blocks until the full reply is available.</summary>
    public string Send(SendMessageInput input)
    {
        var userInput = PreparePrompt(input);

        var result = "";
        foreach (var output in _session.Chat(userInput, CreateInferenceParams()))
        {
            Console.Write(output);
            result += output;
        }
        return result;
    }

    /// <summary>Sends one message and streams the reply token by token.</summary>
    public async IAsyncEnumerable<string> SendStream(SendMessageInput input)
    {
        var userInput = PreparePrompt(input);

        await foreach (var output in _session.ChatAsync(userInput, CreateInferenceParams()))
        {
            Console.Write(output);
            yield return output;
        }
    }

    // Prefixes the system prompt on the very first message and echoes the
    // user's text in green. Shared by Send and SendStream (previously
    // duplicated verbatim in both).
    private string PreparePrompt(SendMessageInput input)
    {
        var userInput = input.Text;
        if (!_continue)
        {
            userInput = SystemPrompt + userInput;
            Console.Write(SystemPrompt);
            _continue = true;
        }
        Console.ForegroundColor = ConsoleColor.Green;
        Console.Write(input.Text);
        Console.ForegroundColor = ConsoleColor.White;
        return userInput;
    }

    // Inference settings shared by both entry points.
    private static Common.InferenceParams CreateInferenceParams() => new Common.InferenceParams()
    {
        RepeatPenalty = 1.0f,
        AntiPrompts = new string[] { "User:" },
    };
}

View File

@ -0,0 +1,48 @@
using LLama.Common;
using Microsoft.AspNetCore.Http;
using System.Text;
using static LLama.LLamaTransforms;
namespace LLama.WebAPI.Services
{
/// <summary>
/// Chat service that keeps no conversation state between requests: the caller
/// supplies the full history on every call. Registered as scoped in
/// Program.cs, so the DI container disposes it when the request scope ends.
/// </summary>
public class StatelessChatService : IDisposable
{
    private readonly LLamaModel _model;
    private readonly ChatSession _session;

    public StatelessChatService(IConfiguration configuration)
    {
        _model = new LLamaModel(new ModelParams(configuration["ModelPath"], contextSize: 512));
        // TODO: replace with a stateless executor
        _session = new ChatSession(new InteractiveExecutor(_model))
            .WithOutputTransform(new KeywordTextOutputStreamTransform(new string[] { "User:", "Assistant:" }, redundancyLength: 8))
            .WithHistoryTransform(new HistoryTransform());
    }

    /// <summary>
    /// Releases the native model resources. Previously the model was never
    /// disposed, leaking it on every request scope.
    /// </summary>
    public void Dispose()
    {
        _model?.Dispose();
        GC.SuppressFinalize(this);
    }

    /// <summary>Runs inference over the supplied history and returns the complete reply.</summary>
    public async Task<string> SendAsync(ChatHistory history)
    {
        var result = _session.ChatAsync(history, new InferenceParams()
        {
            AntiPrompts = new string[] { "User:" },
        });

        var sb = new StringBuilder();
        await foreach (var r in result)
        {
            Console.Write(r);
            sb.Append(r);
        }
        return sb.ToString();
    }
}
// Appends an "Assistant:" cue after the rendered history so the model
// continues the transcript as the assistant.
public class HistoryTransform : DefaultHistoryTransform
{
    public override string HistoryToText(ChatHistory history)
        => base.HistoryToText(history) + "\n Assistant:";
}
}

View File

@ -5,5 +5,6 @@
"Microsoft.AspNetCore": "Warning"
}
},
"AllowedHosts": "*"
"AllowedHosts": "*",
"ModelPath": "..\\..\\LLamaModel\\ggml-model-f32-q4_0.bin"
}