LLamaSharp/LLama.Web/Services/ConnectionSessionService.cs

95 lines
4.0 KiB
C#

using LLama.Abstractions;
using LLama.Web.Common;
using LLama.Web.Models;
using Microsoft.Extensions.Options;
using System.Collections.Concurrent;
using System.Drawing;
namespace LLama.Web.Services
{
/// <summary>
/// Example Service for handling a model session for a websockets connection lifetime
/// Each websocket connection will create its own unique session and context allowing you to use multiple tabs to compare prompts etc
/// </summary>
public class ConnectionSessionService : IModelSessionService
{
private readonly LLamaOptions _options;
private readonly ILogger<ConnectionSessionService> _logger;
private readonly ConcurrentDictionary<string, ModelSession> _modelSessions;
public ConnectionSessionService(ILogger<ConnectionSessionService> logger, IOptions<LLamaOptions> options)
{
_logger = logger;
_options = options.Value;
_modelSessions = new ConcurrentDictionary<string, ModelSession>();
}
public Task<ModelSession> GetAsync(string connectionId)
{
_modelSessions.TryGetValue(connectionId, out var modelSession);
return Task.FromResult(modelSession);
}
public Task<IServiceResult<ModelSession>> CreateAsync(LLamaExecutorType executorType, string connectionId, string modelName, string promptName, string parameterName)
{
var modelOption = _options.Models.FirstOrDefault(x => x.Name == modelName);
if (modelOption is null)
return Task.FromResult(ServiceResult.FromError<ModelSession>($"Model option '{modelName}' not found"));
var promptOption = _options.Prompts.FirstOrDefault(x => x.Name == promptName);
if (promptOption is null)
return Task.FromResult(ServiceResult.FromError<ModelSession>($"Prompt option '{promptName}' not found"));
var parameterOption = _options.Parameters.FirstOrDefault(x => x.Name == parameterName);
if (parameterOption is null)
return Task.FromResult(ServiceResult.FromError<ModelSession>($"Parameter option '{parameterName}' not found"));
//Max instance
var currentInstances = _modelSessions.Count(x => x.Value.ModelName == modelOption.Name);
if (modelOption.MaxInstances > -1 && currentInstances >= modelOption.MaxInstances)
return Task.FromResult(ServiceResult.FromError<ModelSession>("Maximum model instances reached"));
// Create model
var llamaModel = new LLamaContext(modelOption);
// Create executor
ILLamaExecutor executor = executorType switch
{
LLamaExecutorType.Interactive => new InteractiveExecutor(llamaModel),
LLamaExecutorType.Instruct => new InstructExecutor(llamaModel),
LLamaExecutorType.Stateless => new StatelessExecutor(llamaModel),
_ => default
};
// Create session
var modelSession = new ModelSession(executor, modelOption, promptOption, parameterOption);
if (!_modelSessions.TryAdd(connectionId, modelSession))
return Task.FromResult(ServiceResult.FromError<ModelSession>("Failed to create model session"));
return Task.FromResult(ServiceResult.FromValue(modelSession));
}
public Task<bool> RemoveAsync(string connectionId)
{
if (_modelSessions.TryRemove(connectionId, out var modelSession))
{
modelSession.CancelInfer();
modelSession.Dispose();
return Task.FromResult(true);
}
return Task.FromResult(false);
}
public Task<bool> CancelAsync(string connectionId)
{
if (_modelSessions.TryGetValue(connectionId, out var modelSession))
{
modelSession.CancelInfer();
return Task.FromResult(true);
}
return Task.FromResult(false);
}
}
}