diff --git a/LLama.Web/Common/ISessionConfig.cs b/LLama.Web/Common/ISessionConfig.cs new file mode 100644 index 00000000..09bddc2d --- /dev/null +++ b/LLama.Web/Common/ISessionConfig.cs @@ -0,0 +1,13 @@ +namespace LLama.Web.Common +{ + public interface ISessionConfig + { + string AntiPrompt { get; set; } + List AntiPrompts { get; set; } + LLamaExecutorType ExecutorType { get; set; } + string Model { get; set; } + string OutputFilter { get; set; } + List OutputFilters { get; set; } + string Prompt { get; set; } + } +} \ No newline at end of file diff --git a/LLama.Web/Common/SessionOptions.cs b/LLama.Web/Common/SessionConfig.cs similarity index 89% rename from LLama.Web/Common/SessionOptions.cs rename to LLama.Web/Common/SessionConfig.cs index 34386955..f0a2d22b 100644 --- a/LLama.Web/Common/SessionOptions.cs +++ b/LLama.Web/Common/SessionConfig.cs @@ -1,6 +1,6 @@ namespace LLama.Web.Common { - public class SessionOptions + public class SessionConfig : ISessionConfig { public string Model { get; set; } public string Prompt { get; set; } diff --git a/LLama.Web/Extensioms.cs b/LLama.Web/Extensions.cs similarity index 88% rename from LLama.Web/Extensioms.cs rename to LLama.Web/Extensions.cs index 50bb55c4..99f745dd 100644 --- a/LLama.Web/Extensioms.cs +++ b/LLama.Web/Extensions.cs @@ -2,14 +2,14 @@ namespace LLama.Web { - public static class Extensioms + public static class Extensions { /// /// Combines the AntiPrompts list and AntiPrompt csv /// /// The session configuration. /// Combined AntiPrompts with duplicates removed - public static List GetAntiPrompts(this Common.SessionOptions sessionConfig) + public static List GetAntiPrompts(this ISessionConfig sessionConfig) { return CombineCSV(sessionConfig.AntiPrompts, sessionConfig.AntiPrompt); } @@ -19,7 +19,7 @@ namespace LLama.Web /// /// The session configuration. /// Combined OutputFilters with duplicates removed - public static List GetOutputFilters(this Common.SessionOptions sessionConfig) + public static List GetOutputFilters(this ISessionConfig sessionConfig) { return CombineCSV(sessionConfig.OutputFilters, sessionConfig.OutputFilter); } diff --git a/LLama.Web/Hubs/SessionConnectionHub.cs b/LLama.Web/Hubs/SessionConnectionHub.cs index 730d4e87..24457683 100644 --- a/LLama.Web/Hubs/SessionConnectionHub.cs +++ b/LLama.Web/Hubs/SessionConnectionHub.cs @@ -37,7 +37,7 @@ namespace LLama.Web.Hubs [HubMethodName("LoadModel")] - public async Task OnLoadModel(Common.SessionOptions sessionConfig, InferenceOptions inferenceConfig) + public async Task OnLoadModel(ISessionConfig sessionConfig, InferenceOptions inferenceConfig) { _logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}", Context.ConnectionId); await _modelSessionService.CloseAsync(Context.ConnectionId); diff --git a/LLama.Web/Models/ModelSession.cs b/LLama.Web/Models/ModelSession.cs index 35413f92..91c8920f 100644 --- a/LLama.Web/Models/ModelSession.cs +++ b/LLama.Web/Models/ModelSession.cs @@ -9,21 +9,21 @@ namespace LLama.Web.Models private readonly LLamaModel _model; private readonly LLamaContext _context; private readonly ILLamaExecutor _executor; - private readonly Common.SessionOptions _sessionParams; + private readonly ISessionConfig _sessionConfig; private readonly ITextStreamTransform _outputTransform; private readonly InferenceOptions _defaultInferenceConfig; private CancellationTokenSource _cancellationTokenSource; - public ModelSession(LLamaModel model, LLamaContext context, string sessionId, Common.SessionOptions sessionOptions, InferenceOptions inferenceOptions = null) + public ModelSession(LLamaModel model, LLamaContext context, string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceOptions = null) { _model = model; _context = context; _sessionId = sessionId; - _sessionParams = sessionOptions; + _sessionConfig = sessionConfig; _defaultInferenceConfig = inferenceOptions ?? new InferenceOptions(); - _outputTransform = CreateOutputFilter(_sessionParams); - _executor = CreateExecutor(_model, _context, _sessionParams); + _outputTransform = CreateOutputFilter(); + _executor = CreateExecutor(); } /// @@ -34,7 +34,7 @@ namespace LLama.Web.Models /// /// Gets the name of the model. /// - public string ModelName => _sessionParams.Model; + public string ModelName => _sessionConfig.Model; /// /// Gets the context. @@ -44,7 +44,7 @@ namespace LLama.Web.Models /// /// Gets the session configuration. /// - public Common.SessionOptions SessionConfig => _sessionParams; + public ISessionConfig SessionConfig => _sessionConfig; /// /// Gets the inference parameters. @@ -60,16 +60,16 @@ namespace LLama.Web.Models /// The cancellation token. internal async Task InitializePrompt(InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default) { - if (_sessionParams.ExecutorType == LLamaExecutorType.Stateless) + if (_sessionConfig.ExecutorType == LLamaExecutorType.Stateless) return; - if (string.IsNullOrEmpty(_sessionParams.Prompt)) + if (string.IsNullOrEmpty(_sessionConfig.Prompt)) return; // Run Initial prompt var inferenceParams = ConfigureInferenceParams(inferenceConfig); _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - await foreach (var _ in _executor.InferAsync(_sessionParams.Prompt, inferenceParams, _cancellationTokenSource.Token)) + await foreach (var _ in _executor.InferAsync(_sessionConfig.Prompt, inferenceParams, _cancellationTokenSource.Token)) { // We dont really need the response of the initial prompt, so exit on first token break; @@ -114,13 +114,13 @@ namespace LLama.Web.Models private IInferenceParams ConfigureInferenceParams(InferenceOptions inferenceConfig) { var inferenceParams = inferenceConfig ?? _defaultInferenceConfig; - inferenceParams.AntiPrompts = _sessionParams.GetAntiPrompts(); + inferenceParams.AntiPrompts = _sessionConfig.GetAntiPrompts(); return inferenceParams; } - private ITextStreamTransform CreateOutputFilter(Common.SessionOptions sessionConfig) + private ITextStreamTransform CreateOutputFilter() { - var outputFilters = sessionConfig.GetOutputFilters(); + var outputFilters = _sessionConfig.GetOutputFilters(); if (outputFilters.Count > 0) return new LLamaTransforms.KeywordTextOutputStreamTransform(outputFilters); @@ -128,9 +128,9 @@ namespace LLama.Web.Models } - private ILLamaExecutor CreateExecutor(LLamaModel model, LLamaContext context, Common.SessionOptions sessionConfig) + private ILLamaExecutor CreateExecutor() { - return sessionConfig.ExecutorType switch + return _sessionConfig.ExecutorType switch { LLamaExecutorType.Interactive => new InteractiveExecutor(_context), LLamaExecutorType.Instruct => new InstructExecutor(_context), diff --git a/LLama.Web/Pages/Index.cshtml b/LLama.Web/Pages/Index.cshtml index 55512603..3df4b699 100644 --- a/LLama.Web/Pages/Index.cshtml +++ b/LLama.Web/Pages/Index.cshtml @@ -24,11 +24,11 @@
Model - @Html.DropDownListFor(m => m.SessionOptions.Model, new SelectList(Model.Options.Models, "Name", "Name"), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"}) + @Html.DropDownListFor(m => m.SessionConfig.Model, new SelectList(Model.Options.Models, "Name", "Name"), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"})
Inference Type - @Html.DropDownListFor(m => m.SessionOptions.ExecutorType, Html.GetEnumSelectList(), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"}) + @Html.DropDownListFor(m => m.SessionConfig.ExecutorType, Html.GetEnumSelectList(), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"})
/// The session identifier. - /// The session configuration. + /// The session configuration. /// The default inference configuration, will be used for all inference where no infer configuration is supplied. /// The cancellation token. /// @@ -33,7 +33,7 @@ namespace LLama.Web.Services /// or /// Failed to create model session /// - Task CreateAsync(string sessionId, Common.SessionOptions sessionOptions, InferenceOptions inferenceOptions = null, CancellationToken cancellationToken = default); + Task CreateAsync(string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceOptions = null, CancellationToken cancellationToken = default); /// diff --git a/LLama.Web/Services/ModelSessionService.cs b/LLama.Web/Services/ModelSessionService.cs index e808e630..84070d94 100644 --- a/LLama.Web/Services/ModelSessionService.cs +++ b/LLama.Web/Services/ModelSessionService.cs @@ -65,7 +65,7 @@ namespace LLama.Web.Services /// or /// Failed to create model session /// - public async Task CreateAsync(string sessionId, Common.SessionOptions sessionConfig, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default) + public async Task CreateAsync(string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default) { if (_modelSessions.TryGetValue(sessionId, out _)) throw new Exception($"Session with id {sessionId} already exists");