From 44f1b91c292eba68df285005e2e763484a45fc85 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Wed, 4 Oct 2023 12:57:15 +1300 Subject: [PATCH] Update Web to support version 0.5.1 --- LLama.Web/Async/AsyncGuard.cs | 107 +++++++++ LLama.Web/Common/InferenceOptions.cs | 101 ++++++++ LLama.Web/Common/LLamaOptions.cs | 9 - LLama.Web/Common/ParameterOptions.cs | 105 --------- LLama.Web/Common/PromptOptions.cs | 11 - LLama.Web/Common/SessionOptions.cs | 14 ++ LLama.Web/Extensioms.cs | 54 +++++ LLama.Web/Hubs/ISessionClient.cs | 1 - LLama.Web/Hubs/SessionConnectionHub.cs | 57 ++--- LLama.Web/LLama.Web.csproj | 4 + LLama.Web/{ => Models}/LLamaModel.cs | 4 +- LLama.Web/Models/ModelSession.cs | 140 +++++++++--- LLama.Web/Models/ResponseFragment.cs | 18 -- LLama.Web/Models/TokenModel.cs | 24 ++ LLama.Web/Pages/Executor/Instruct.cshtml | 96 -------- LLama.Web/Pages/Executor/Instruct.cshtml.cs | 34 --- LLama.Web/Pages/Executor/Instruct.cshtml.css | 4 - LLama.Web/Pages/Executor/Interactive.cshtml | 96 -------- .../Pages/Executor/Interactive.cshtml.cs | 34 --- .../Pages/Executor/Interactive.cshtml.css | 4 - LLama.Web/Pages/Executor/Stateless.cshtml | 97 -------- LLama.Web/Pages/Executor/Stateless.cshtml.cs | 34 --- LLama.Web/Pages/Executor/Stateless.cshtml.css | 4 - LLama.Web/Pages/Index.cshtml | 119 +++++++++- LLama.Web/Pages/Index.cshtml.cs | 25 +- LLama.Web/Pages/Shared/_ChatTemplates.cshtml | 24 +- LLama.Web/Pages/Shared/_Layout.cshtml | 32 +-- LLama.Web/Pages/Shared/_Parameters.cshtml | 137 +++++++++++ LLama.Web/Program.cs | 5 +- .../Services/ConnectionSessionService.cs | 94 -------- LLama.Web/Services/IModelService.cs | 1 + LLama.Web/Services/IModelSessionService.cs | 84 ++++++- LLama.Web/Services/ModelLoaderService.cs | 42 ++++ LLama.Web/Services/ModelService.cs | 1 + LLama.Web/Services/ModelSessionService.cs | 216 ++++++++++++++++++ LLama.Web/appsettings.json | 60 ++--- LLama.Web/wwwroot/css/site.css | 25 +- LLama.Web/wwwroot/js/sessionConnectionChat.js | 141 ++++++++---- LLama.Web/wwwroot/js/site.js | 8 +- 39 files changed, 1210 insertions(+), 856 deletions(-) create mode 100644 LLama.Web/Async/AsyncGuard.cs create mode 100644 LLama.Web/Common/InferenceOptions.cs delete mode 100644 LLama.Web/Common/ParameterOptions.cs delete mode 100644 LLama.Web/Common/PromptOptions.cs create mode 100644 LLama.Web/Common/SessionOptions.cs create mode 100644 LLama.Web/Extensioms.cs rename LLama.Web/{ => Models}/LLamaModel.cs (98%) delete mode 100644 LLama.Web/Models/ResponseFragment.cs create mode 100644 LLama.Web/Models/TokenModel.cs delete mode 100644 LLama.Web/Pages/Executor/Instruct.cshtml delete mode 100644 LLama.Web/Pages/Executor/Instruct.cshtml.cs delete mode 100644 LLama.Web/Pages/Executor/Instruct.cshtml.css delete mode 100644 LLama.Web/Pages/Executor/Interactive.cshtml delete mode 100644 LLama.Web/Pages/Executor/Interactive.cshtml.cs delete mode 100644 LLama.Web/Pages/Executor/Interactive.cshtml.css delete mode 100644 LLama.Web/Pages/Executor/Stateless.cshtml delete mode 100644 LLama.Web/Pages/Executor/Stateless.cshtml.cs delete mode 100644 LLama.Web/Pages/Executor/Stateless.cshtml.css create mode 100644 LLama.Web/Pages/Shared/_Parameters.cshtml delete mode 100644 LLama.Web/Services/ConnectionSessionService.cs create mode 100644 LLama.Web/Services/ModelLoaderService.cs create mode 100644 LLama.Web/Services/ModelSessionService.cs diff --git a/LLama.Web/Async/AsyncGuard.cs b/LLama.Web/Async/AsyncGuard.cs new file mode 100644 index 00000000..ff6b6c43 --- /dev/null +++ b/LLama.Web/Async/AsyncGuard.cs @@ 
-0,0 +1,107 @@
+using System.Collections.Concurrent;
+
+namespace LLama.Web.Async
+{
+    /// <summary>
+    /// Creates an async/thread-safe guard helper
+    /// </summary>
+    public class AsyncGuard : AsyncGuard<byte>
+    {
+        private readonly byte _key;
+        private readonly ConcurrentDictionary<byte, bool> _lockData;
+
+        /// <summary>
+        /// Initializes a new instance of the <see cref="AsyncGuard"/> class.
+        /// </summary>
+        public AsyncGuard()
+        {
+            _key = 0;
+            _lockData = new ConcurrentDictionary<byte, bool>();
+        }
+
+        /// <summary>
+        /// Guards this instance.
+        /// </summary>
+        /// <returns>true if able to enter a guard, false if already guarded</returns>
+        public bool Guard()
+        {
+            return _lockData.TryAdd(_key, true);
+        }
+
+        /// <summary>
+        /// Releases the guard.
+        /// </summary>
+        public bool Release()
+        {
+            return _lockData.TryRemove(_key, out _);
+        }
+
+        /// <summary>
+        /// Determines whether this instance is guarded.
+        /// </summary>
+        /// <returns>true if this instance is guarded; otherwise, false.</returns>
+        public bool IsGuarded()
+        {
+            return _lockData.ContainsKey(_key);
+        }
+    }
+
+    public class AsyncGuard<T>
+    {
+        private readonly ConcurrentDictionary<T, bool> _lockData;
+
+        /// <summary>
+        /// Initializes a new instance of the <see cref="AsyncGuard{T}"/> class.
+        /// </summary>
+        public AsyncGuard()
+        {
+            _lockData = new ConcurrentDictionary<T, bool>();
+        }
+
+        /// <summary>
+        /// Guards the specified value.
+        /// </summary>
+        /// <param name="value">The value.</param>
+        /// <returns>true if able to enter a guard for this value, false if this value is already guarded</returns>
+        public bool Guard(T value)
+        {
+            return _lockData.TryAdd(value, true);
+        }
+
+        /// <summary>
+        /// Releases the guard on the specified value.
+        /// </summary>
+        /// <param name="value">The value.</param>
+        public bool Release(T value)
+        {
+            return _lockData.TryRemove(value, out _);
+        }
+
+        /// <summary>
+        /// Determines whether the specified value is guarded.
+        /// </summary>
+        /// <param name="value">The value.</param>
+        /// <returns>true if the specified value is guarded; otherwise, false.</returns>
+        public bool IsGuarded(T value)
+        {
+            return _lockData.ContainsKey(value);
+        }
+    }
+}
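A minimal usage sketch of the keyed guard (illustrative only, not part of the patch); this is the pattern ModelSessionService below uses to allow only one running inference per session id:

    // Guard/Release around a unit of work keyed by session id
    var sessionGuard = new AsyncGuard<string>();
    if (!sessionGuard.Guard("session-1"))          // false while "session-1" is already guarded
        throw new Exception("Inference is already running for this session");
    try
    {
        // ... run the inference for "session-1" ...
    }
    finally
    {
        sessionGuard.Release("session-1");         // release so the session can run again
    }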
diff --git a/LLama.Web/Common/InferenceOptions.cs b/LLama.Web/Common/InferenceOptions.cs
new file mode 100644
index 00000000..c2420af3
--- /dev/null
+++ b/LLama.Web/Common/InferenceOptions.cs
@@ -0,0 +1,101 @@
+using LLama.Common;
+using LLama.Abstractions;
+using LLama.Native;
+
+namespace LLama.Web.Common
+{
+    public class InferenceOptions : IInferenceParams
+    {
+        /// <summary>
+        /// number of tokens to keep from initial prompt
+        /// </summary>
+        public int TokensKeep { get; set; } = 0;
+        /// <summary>
+        /// how many new tokens to predict (n_predict); set to -1 to generate indefinitely
+        /// until inference completes.
+        /// </summary>
+        public int MaxTokens { get; set; } = -1;
+        /// <summary>
+        /// logit bias for specific tokens
+        /// </summary>
+        public Dictionary<int, float>? LogitBias { get; set; } = null;
+        /// <summary>
+        /// Sequences where the model will stop generating further tokens.
+        /// </summary>
+        public IEnumerable<string> AntiPrompts { get; set; } = Array.Empty<string>();
+        /// <summary>
+        /// path to file for saving/loading model eval state
+        /// </summary>
+        public string PathSession { get; set; } = string.Empty;
+        /// <summary>
+        /// string to suffix user inputs with
+        /// </summary>
+        public string InputSuffix { get; set; } = string.Empty;
+        /// <summary>
+        /// string to prefix user inputs with
+        /// </summary>
+        public string InputPrefix { get; set; } = string.Empty;
+        /// <summary>
+        /// 0 or lower to use vocab size
+        /// </summary>
+        public int TopK { get; set; } = 40;
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float TopP { get; set; } = 0.95f;
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float TfsZ { get; set; } = 1.0f;
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float TypicalP { get; set; } = 1.0f;
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float Temperature { get; set; } = 0.8f;
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float RepeatPenalty { get; set; } = 1.1f;
+        /// <summary>
+        /// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n)
+        /// </summary>
+        public int RepeatLastTokensCount { get; set; } = 64;
+        /// <summary>
+        /// frequency penalty coefficient, 0.0 = disabled
+        /// </summary>
+        public float FrequencyPenalty { get; set; } = .0f;
+        /// <summary>
+        /// presence penalty coefficient, 0.0 = disabled
+        /// </summary>
+        public float PresencePenalty { get; set; } = .0f;
+        /// <summary>
+        /// Mirostat uses tokens instead of words; algorithm described in the paper https://arxiv.org/abs/2007.14966.
+        /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+        /// </summary>
+        public MirostatType Mirostat { get; set; } = MirostatType.Disable;
+        /// <summary>
+        /// target entropy
+        /// </summary>
+        public float MirostatTau { get; set; } = 5.0f;
+        /// <summary>
+        /// learning rate
+        /// </summary>
+        public float MirostatEta { get; set; } = 0.1f;
+        /// <summary>
+        /// consider newlines as a repeatable token (penalize_nl)
+        /// </summary>
+        public bool PenalizeNL { get; set; } = true;
+
+        /// <summary>
+        /// A grammar to constrain possible tokens
+        /// </summary>
+        public SafeLLamaGrammarHandle Grammar { get; set; } = null;
+    }
+}
diff --git a/LLama.Web/Common/LLamaOptions.cs b/LLama.Web/Common/LLamaOptions.cs
index a64b9635..4a1d6e0a 100644
--- a/LLama.Web/Common/LLamaOptions.cs
+++ b/LLama.Web/Common/LLamaOptions.cs
@@ -4,18 +4,9 @@
 {
     public ModelLoadType ModelLoadType { get; set; }
     public List<ModelOptions> Models { get; set; }
-    public List<PromptOptions> Prompts { get; set; } = new List<PromptOptions>();
-    public List<ParameterOptions> Parameters { get; set; } = new List<ParameterOptions>();
 
     public void Initialize()
     {
-        foreach (var prompt in Prompts)
-        {
-            if (File.Exists(prompt.Path))
-            {
-                prompt.Prompt = File.ReadAllText(prompt.Path).Trim();
-            }
-        }
     }
 }
diff --git a/LLama.Web/Common/ParameterOptions.cs b/LLama.Web/Common/ParameterOptions.cs
deleted file mode 100644
index f78aa861..00000000
--- a/LLama.Web/Common/ParameterOptions.cs
+++ /dev/null
@@ -1,105 +0,0 @@
-using LLama.Common;
-using LLama.Abstractions;
-using LLama.Native;
-
-namespace LLama.Web.Common
-{
-    public class ParameterOptions : IInferenceParams
-    {
-        public string Name { get; set; }
-
-        /// <summary>
-        /// number of tokens to keep from initial prompt
-        /// </summary>
-        public int TokensKeep { get; set; } = 0;
-        /// <summary>
-        /// how many new tokens to predict (n_predict), set to -1 to inifinitely generate response
-        /// until it complete.
-        /// </summary>
-        public int MaxTokens { get; set; } = -1;
-        /// <summary>
-        /// logit bias for specific tokens
-        /// </summary>
-        public Dictionary<int, float>? LogitBias { get; set; } = null;
-
-        /// <summary>
-        /// Sequences where the model will stop generating further tokens.
- /// - public IEnumerable AntiPrompts { get; set; } = Array.Empty(); - /// - /// path to file for saving/loading model eval state - /// - public string PathSession { get; set; } = string.Empty; - /// - /// string to suffix user inputs with - /// - public string InputSuffix { get; set; } = string.Empty; - /// - /// string to prefix user inputs with - /// - public string InputPrefix { get; set; } = string.Empty; - /// - /// 0 or lower to use vocab size - /// - public int TopK { get; set; } = 40; - /// - /// 1.0 = disabled - /// - public float TopP { get; set; } = 0.95f; - /// - /// 1.0 = disabled - /// - public float TfsZ { get; set; } = 1.0f; - /// - /// 1.0 = disabled - /// - public float TypicalP { get; set; } = 1.0f; - /// - /// 1.0 = disabled - /// - public float Temperature { get; set; } = 0.8f; - /// - /// 1.0 = disabled - /// - public float RepeatPenalty { get; set; } = 1.1f; - /// - /// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n) - /// - public int RepeatLastTokensCount { get; set; } = 64; - /// - /// frequency penalty coefficient - /// 0.0 = disabled - /// - public float FrequencyPenalty { get; set; } = .0f; - /// - /// presence penalty coefficient - /// 0.0 = disabled - /// - public float PresencePenalty { get; set; } = .0f; - /// - /// Mirostat uses tokens instead of words. - /// algorithm described in the paper https://arxiv.org/abs/2007.14966. - /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - /// - public MirostatType Mirostat { get; set; } = MirostatType.Disable; - /// - /// target entropy - /// - public float MirostatTau { get; set; } = 5.0f; - /// - /// learning rate - /// - public float MirostatEta { get; set; } = 0.1f; - /// - /// consider newlines as a repeatable token (penalize_nl) - /// - public bool PenalizeNL { get; set; } = true; - - /// - /// A grammar to constrain possible tokens - /// - public SafeLLamaGrammarHandle Grammar { get; set; } = null; - } -} diff --git a/LLama.Web/Common/PromptOptions.cs b/LLama.Web/Common/PromptOptions.cs deleted file mode 100644 index 4e44a5d1..00000000 --- a/LLama.Web/Common/PromptOptions.cs +++ /dev/null @@ -1,11 +0,0 @@ -namespace LLama.Web.Common -{ - public class PromptOptions - { - public string Name { get; set; } - public string Path { get; set; } - public string Prompt { get; set; } - public List AntiPrompt { get; set; } - public List OutputFilter { get; set; } - } -} diff --git a/LLama.Web/Common/SessionOptions.cs b/LLama.Web/Common/SessionOptions.cs new file mode 100644 index 00000000..34386955 --- /dev/null +++ b/LLama.Web/Common/SessionOptions.cs @@ -0,0 +1,14 @@ +namespace LLama.Web.Common +{ + public class SessionOptions + { + public string Model { get; set; } + public string Prompt { get; set; } + + public string AntiPrompt { get; set; } + public List AntiPrompts { get; set; } + public string OutputFilter { get; set; } + public List OutputFilters { get; set; } + public LLamaExecutorType ExecutorType { get; set; } + } +} diff --git a/LLama.Web/Extensioms.cs b/LLama.Web/Extensioms.cs new file mode 100644 index 00000000..50bb55c4 --- /dev/null +++ b/LLama.Web/Extensioms.cs @@ -0,0 +1,54 @@ +using LLama.Web.Common; + +namespace LLama.Web +{ + public static class Extensioms + { + /// + /// Combines the AntiPrompts list and AntiPrompt csv + /// + /// The session configuration. 
+    /// <returns>Combined AntiPrompts with duplicates removed</returns>
+    public static List<string> GetAntiPrompts(this Common.SessionOptions sessionConfig)
+    {
+        return CombineCSV(sessionConfig.AntiPrompts, sessionConfig.AntiPrompt);
+    }
+
+    /// <summary>
+    /// Combines the OutputFilters list and OutputFilter csv
+    /// </summary>
+    /// <param name="sessionConfig">The session configuration.</param>
+    /// <returns>Combined OutputFilters with duplicates removed</returns>
+    public static List<string> GetOutputFilters(this Common.SessionOptions sessionConfig)
+    {
+        return CombineCSV(sessionConfig.OutputFilters, sessionConfig.OutputFilter);
+    }
+
+    /// <summary>
+    /// Combines a string list and a csv and removes duplicates
+    /// </summary>
+    /// <param name="list">The list.</param>
+    /// <param name="csv">The CSV.</param>
+    /// <returns>Combined list with duplicates removed</returns>
+    private static List<string> CombineCSV(List<string> list, string csv)
+    {
+        // Treat a null list the same as an empty one so the CSV-only case cannot throw
+        var results = list is null || list.Count == 0
+            ? CommaSeperatedToList(csv)
+            : CommaSeperatedToList(csv).Concat(list);
+        return results
+            .Distinct()
+            .ToList();
+    }
+
+    private static List<string> CommaSeperatedToList(string value)
+    {
+        if (string.IsNullOrEmpty(value))
+            return new List<string>();
+
+        return value.Split(",", StringSplitOptions.RemoveEmptyEntries)
+            .Select(x => x.Trim())
+            .ToList();
+    }
+    }
+}
diff --git a/LLama.Web/Hubs/ISessionClient.cs b/LLama.Web/Hubs/ISessionClient.cs
index 9e9dc0f1..92302b21 100644
--- a/LLama.Web/Hubs/ISessionClient.cs
+++ b/LLama.Web/Hubs/ISessionClient.cs
@@ -6,7 +6,6 @@ namespace LLama.Web.Hubs
 public interface ISessionClient
 {
     Task OnStatus(string connectionId, SessionConnectionStatus status);
-    Task OnResponse(ResponseFragment fragment);
     Task OnError(string error);
 }
 }
diff --git a/LLama.Web/Hubs/SessionConnectionHub.cs b/LLama.Web/Hubs/SessionConnectionHub.cs
index 080866c6..730d4e87 100644
--- a/LLama.Web/Hubs/SessionConnectionHub.cs
+++ b/LLama.Web/Hubs/SessionConnectionHub.cs
@@ -2,16 +2,15 @@
 using LLama.Web.Models;
 using LLama.Web.Services;
 using Microsoft.AspNetCore.SignalR;
-using System.Diagnostics;
 
 namespace LLama.Web.Hubs
 {
     public class SessionConnectionHub : Hub<ISessionClient>
     {
         private readonly ILogger<SessionConnectionHub> _logger;
-        private readonly ConnectionSessionService _modelSessionService;
+        private readonly IModelSessionService _modelSessionService;
 
-        public SessionConnectionHub(ILogger<SessionConnectionHub> logger, ConnectionSessionService modelSessionService)
+        public SessionConnectionHub(ILogger<SessionConnectionHub> logger, IModelSessionService modelSessionService)
         {
             _logger = logger;
             _modelSessionService = modelSessionService;
         }
@@ -27,29 +26,27 @@ namespace LLama.Web.Hubs
         }
 
-        public override async Task OnDisconnectedAsync(Exception? exception)
+        public override async Task OnDisconnectedAsync(Exception exception)
         {
             _logger.Log(LogLevel.Information, "[OnDisconnectedAsync], Id: {0}", Context.ConnectionId);
 
             // Remove connection's session on disconnect
-            await _modelSessionService.RemoveAsync(Context.ConnectionId);
+            await _modelSessionService.CloseAsync(Context.ConnectionId);
             await base.OnDisconnectedAsync(exception);
         }
 
         [HubMethodName("LoadModel")]
-        public async Task OnLoadModel(LLamaExecutorType executorType, string modelName, string promptName, string parameterName)
+        public async Task OnLoadModel(Common.SessionOptions sessionConfig, InferenceOptions inferenceConfig)
         {
-            _logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}, Model: {1}, Prompt: {2}, Parameter: {3}", Context.ConnectionId, modelName, promptName, parameterName);
-
-            // Remove existing connections session
-            await _modelSessionService.RemoveAsync(Context.ConnectionId);
+            _logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}", Context.ConnectionId);
+            await _modelSessionService.CloseAsync(Context.ConnectionId);
 
             // Create model session
-            var modelSessionResult = await _modelSessionService.CreateAsync(executorType, Context.ConnectionId, modelName, promptName, parameterName);
-            if (modelSessionResult.HasError)
+            var modelSession = await _modelSessionService.CreateAsync(Context.ConnectionId, sessionConfig, inferenceConfig);
+            if (modelSession is null)
             {
-                await Clients.Caller.OnError(modelSessionResult.Error);
+                await Clients.Caller.OnError("Failed to create model session");
                 return;
             }
 
@@ -59,40 +56,12 @@
 
         [HubMethodName("SendPrompt")]
-        public async Task OnSendPrompt(string prompt)
+        public IAsyncEnumerable<TokenModel> OnSendPrompt(string prompt, InferenceOptions inferConfig, CancellationToken cancellationToken)
         {
             _logger.Log(LogLevel.Information, "[OnSendPrompt] - New prompt received, Connection: {0}", Context.ConnectionId);
 
-            // Get connections session
-            var modelSession = await _modelSessionService.GetAsync(Context.ConnectionId);
-            if (modelSession is null)
-            {
-                await Clients.Caller.OnError("No model has been loaded");
-                return;
-            }
-
-            // Create unique response id
-            var responseId = Guid.NewGuid().ToString();
-
-            // Send begin of response
-            await Clients.Caller.OnResponse(new ResponseFragment(responseId, isFirst: true));
-
-            // Send content of response
-            var stopwatch = Stopwatch.GetTimestamp();
-            await foreach (var fragment in modelSession.InferAsync(prompt, CancellationTokenSource.CreateLinkedTokenSource(Context.ConnectionAborted)))
-            {
-                await Clients.Caller.OnResponse(new ResponseFragment(responseId, fragment));
-            }
-
-            // Send end of response
-            var elapsedTime = Stopwatch.GetElapsedTime(stopwatch);
-            var signature = modelSession.IsInferCanceled()
-                ?
$"Inference cancelled after {elapsedTime.TotalSeconds:F0} seconds" - : $"Inference completed in {elapsedTime.TotalSeconds:F0} seconds"; - await Clients.Caller.OnResponse(new ResponseFragment(responseId, signature, isLast: true)); - _logger.Log(LogLevel.Information, "[OnSendPrompt] - Inference complete, Connection: {0}, Elapsed: {1}, Canceled: {2}", Context.ConnectionId, elapsedTime, modelSession.IsInferCanceled()); + var linkedCancelationToken = CancellationTokenSource.CreateLinkedTokenSource(Context.ConnectionAborted, cancellationToken); + return _modelSessionService.InferAsync(Context.ConnectionId, prompt, inferConfig, linkedCancelationToken.Token); } - } } diff --git a/LLama.Web/LLama.Web.csproj b/LLama.Web/LLama.Web.csproj index d0e15a62..5a46c5e8 100644 --- a/LLama.Web/LLama.Web.csproj +++ b/LLama.Web/LLama.Web.csproj @@ -14,4 +14,8 @@ + + + + diff --git a/LLama.Web/LLamaModel.cs b/LLama.Web/Models/LLamaModel.cs similarity index 98% rename from LLama.Web/LLamaModel.cs rename to LLama.Web/Models/LLamaModel.cs index e500ba04..71bb290e 100644 --- a/LLama.Web/LLamaModel.cs +++ b/LLama.Web/Models/LLamaModel.cs @@ -2,12 +2,12 @@ using LLama.Web.Common; using System.Collections.Concurrent; -namespace LLama.Web +namespace LLama.Web.Models { /// /// Wrapper class for LLamaSharp LLamaWeights /// - /// + /// public class LLamaModel : IDisposable { private readonly ModelOptions _config; diff --git a/LLama.Web/Models/ModelSession.cs b/LLama.Web/Models/ModelSession.cs index c53676f2..35413f92 100644 --- a/LLama.Web/Models/ModelSession.cs +++ b/LLama.Web/Models/ModelSession.cs @@ -3,46 +3,97 @@ using LLama.Web.Common; namespace LLama.Web.Models { - public class ModelSession : IDisposable + public class ModelSession { - private bool _isFirstInteraction = true; - private ModelOptions _modelOptions; - private PromptOptions _promptOptions; - private ParameterOptions _inferenceOptions; - private ITextStreamTransform _outputTransform; - private ILLamaExecutor _executor; + private readonly string _sessionId; + private readonly LLamaModel _model; + private readonly LLamaContext _context; + private readonly ILLamaExecutor _executor; + private readonly Common.SessionOptions _sessionParams; + private readonly ITextStreamTransform _outputTransform; + private readonly InferenceOptions _defaultInferenceConfig; + private CancellationTokenSource _cancellationTokenSource; - public ModelSession(ILLamaExecutor executor, ModelOptions modelOptions, PromptOptions promptOptions, ParameterOptions parameterOptions) + public ModelSession(LLamaModel model, LLamaContext context, string sessionId, Common.SessionOptions sessionOptions, InferenceOptions inferenceOptions = null) { - _executor = executor; - _modelOptions = modelOptions; - _promptOptions = promptOptions; - _inferenceOptions = parameterOptions; - - _inferenceOptions.AntiPrompts = _promptOptions.AntiPrompt?.Concat(_inferenceOptions.AntiPrompts ?? Enumerable.Empty()).Distinct() ?? _inferenceOptions.AntiPrompts; - if (_promptOptions.OutputFilter?.Count > 0) - _outputTransform = new LLamaTransforms.KeywordTextOutputStreamTransform(_promptOptions.OutputFilter, redundancyLength: 5); + _model = model; + _context = context; + _sessionId = sessionId; + _sessionParams = sessionOptions; + _defaultInferenceConfig = inferenceOptions ?? 
new InferenceOptions();
+            _outputTransform = CreateOutputFilter(_sessionParams);
+            _executor = CreateExecutor(_model, _context, _sessionParams);
         }
 
-        public string ModelName
-        {
-            get { return _modelOptions.Name; }
-        }
+        /// <summary>
+        /// Gets the session identifier.
+        /// </summary>
+        public string SessionId => _sessionId;
 
-        public IAsyncEnumerable<string> InferAsync(string message, CancellationTokenSource cancellationTokenSource)
+        /// <summary>
+        /// Gets the name of the model.
+        /// </summary>
+        public string ModelName => _sessionParams.Model;
+
+        /// <summary>
+        /// Gets the context.
+        /// </summary>
+        public LLamaContext Context => _context;
+
+        /// <summary>
+        /// Gets the session configuration.
+        /// </summary>
+        public Common.SessionOptions SessionConfig => _sessionParams;
+
+        /// <summary>
+        /// Gets the inference parameters.
+        /// </summary>
+        public InferenceOptions InferenceParams => _defaultInferenceConfig;
+
+        /// <summary>
+        /// Initializes the prompt.
+        /// </summary>
+        /// <param name="inferenceConfig">The inference configuration.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        internal async Task InitializePrompt(InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
         {
-            _cancellationTokenSource = cancellationTokenSource;
-            if (_isFirstInteraction)
+            if (_sessionParams.ExecutorType == LLamaExecutorType.Stateless)
+                return;
+
+            if (string.IsNullOrEmpty(_sessionParams.Prompt))
+                return;
+
+            // Run initial prompt
+            var inferenceParams = ConfigureInferenceParams(inferenceConfig);
+            _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
+            await foreach (var _ in _executor.InferAsync(_sessionParams.Prompt, inferenceParams, _cancellationTokenSource.Token))
             {
-                _isFirstInteraction = false;
-                message = _promptOptions.Prompt + message;
-            }
+                // We don't really need the response of the initial prompt, so exit on the first token
+                break;
+            }
+        }
+
+        /// <summary>
+        /// Runs inference on the model context
+        /// </summary>
+        /// <param name="message">The message.</param>
+        /// <param name="inferenceConfig">The inference configuration.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        internal IAsyncEnumerable<string> InferAsync(string message, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
+        {
+            var inferenceParams = ConfigureInferenceParams(inferenceConfig);
+            _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
 
+            var inferenceStream = _executor.InferAsync(message, inferenceParams, _cancellationTokenSource.Token);
             if (_outputTransform is not null)
-                return _outputTransform.TransformAsync(_executor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token));
+                return _outputTransform.TransformAsync(inferenceStream);
 
-            return _executor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token);
+            return inferenceStream;
         }
 
@@ -56,13 +107,36 @@
             return _cancellationTokenSource.IsCancellationRequested;
         }
 
-        public void Dispose()
+        /// <summary>
+        /// Configures the inference parameters.
+        /// </summary>
+        /// <param name="inferenceConfig">The inference configuration.</param>
+        private IInferenceParams ConfigureInferenceParams(InferenceOptions inferenceConfig)
         {
-            _inferenceOptions = null;
-            _outputTransform = null;
+            var inferenceParams = inferenceConfig ?? _defaultInferenceConfig;
+            inferenceParams.AntiPrompts = _sessionParams.GetAntiPrompts();
+            return inferenceParams;
+        }
 
-            _executor?.Context.Dispose();
-            _executor = null;
+        private ITextStreamTransform CreateOutputFilter(Common.SessionOptions sessionConfig)
+        {
+            var outputFilters = sessionConfig.GetOutputFilters();
+            if (outputFilters.Count > 0)
+                return new LLamaTransforms.KeywordTextOutputStreamTransform(outputFilters);
+
+            return null;
+        }
+
+        private ILLamaExecutor CreateExecutor(LLamaModel model, LLamaContext context, Common.SessionOptions sessionConfig)
+        {
+            return sessionConfig.ExecutorType switch
+            {
+                LLamaExecutorType.Interactive => new InteractiveExecutor(_context),
+                LLamaExecutorType.Instruct => new InstructExecutor(_context),
+                LLamaExecutorType.Stateless => new StatelessExecutor(_model.LLamaWeights, _model.ModelParams),
+                _ => default
+            };
+        }
     }
 }
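For context, a sketch of the lifecycle these members support, as ModelSessionService below drives it (illustrative only, not part of the patch; InitializePrompt/InferAsync/CancelInfer are internal to LLama.Web):

    // Assumes a LLamaModel/LLamaContext pair already supplied by IModelService
    var session = new ModelSession(model, context, "session-1", sessionOptions);
    await session.InitializePrompt();                          // feed the initial prompt (skipped for Stateless)
    await foreach (var text in session.InferAsync("Hello"))    // stream the filtered response fragments
        Console.Write(text);
    if (session.IsInferCanceled())
        Console.WriteLine("inference was cancelled");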
diff --git a/LLama.Web/Models/ResponseFragment.cs b/LLama.Web/Models/ResponseFragment.cs
deleted file mode 100644
index 02f27f13..00000000
--- a/LLama.Web/Models/ResponseFragment.cs
+++ /dev/null
@@ -1,18 +0,0 @@
-namespace LLama.Web.Models
-{
-    public class ResponseFragment
-    {
-        public ResponseFragment(string id, string content = null, bool isFirst = false, bool isLast = false)
-        {
-            Id = id;
-            IsLast = isLast;
-            IsFirst = isFirst;
-            Content = content;
-        }
-
-        public string Id { get; set; }
-        public string Content { get; set; }
-        public bool IsLast { get; set; }
-        public bool IsFirst { get; set; }
-    }
-}
diff --git a/LLama.Web/Models/TokenModel.cs b/LLama.Web/Models/TokenModel.cs
new file mode 100644
index 00000000..c95f9ec6
--- /dev/null
+++ b/LLama.Web/Models/TokenModel.cs
@@ -0,0 +1,24 @@
+namespace LLama.Web.Models
+{
+    public class TokenModel
+    {
+        public TokenModel(string id, string content = null, TokenType tokenType = TokenType.Content)
+        {
+            Id = id;
+            Content = content;
+            TokenType = tokenType;
+        }
+
+        public string Id { get; set; }
+        public string Content { get; set; }
+        public TokenType TokenType { get; set; }
+    }
+
+    public enum TokenType
+    {
+        Begin = 0,
+        Content = 2,
+        End = 4,
+        Cancel = 10
+    }
+}
diff --git a/LLama.Web/Pages/Executor/Instruct.cshtml b/LLama.Web/Pages/Executor/Instruct.cshtml
deleted file mode 100644
index 9f8cb2d8..00000000
--- a/LLama.Web/Pages/Executor/Instruct.cshtml
+++ /dev/null
@@ -1,96 +0,0 @@
-@page
-@model InstructModel
-@Html.AntiForgeryToken()
-    [Instruct.cshtml markup lost in extraction: an "Instruct" page header with a "Hub: Disconnected" status badge,
-    Model/Parameters/Prompt selectors, a scrollable chat output panel, and a prompt input with Send/Cancel/Clear controls.]
- -@{ await Html.RenderPartialAsync("_ChatTemplates"); } - -@section Scripts { - - -} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Instruct.cshtml.cs b/LLama.Web/Pages/Executor/Instruct.cshtml.cs deleted file mode 100644 index 18a58253..00000000 --- a/LLama.Web/Pages/Executor/Instruct.cshtml.cs +++ /dev/null @@ -1,34 +0,0 @@ -using LLama.Web.Common; -using LLama.Web.Models; -using LLama.Web.Services; -using Microsoft.AspNetCore.Mvc; -using Microsoft.AspNetCore.Mvc.RazorPages; -using Microsoft.Extensions.Options; - -namespace LLama.Web.Pages -{ - public class InstructModel : PageModel - { - private readonly ILogger _logger; - private readonly ConnectionSessionService _modelSessionService; - - public InstructModel(ILogger logger, IOptions options, ConnectionSessionService modelSessionService) - { - _logger = logger; - Options = options.Value; - _modelSessionService = modelSessionService; - } - - public LLamaOptions Options { get; set; } - - public void OnGet() - { - } - - public async Task OnPostCancel(CancelModel model) - { - await _modelSessionService.CancelAsync(model.ConnectionId); - return new JsonResult(default); - } - } -} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Instruct.cshtml.css b/LLama.Web/Pages/Executor/Instruct.cshtml.css deleted file mode 100644 index ed9a1d59..00000000 --- a/LLama.Web/Pages/Executor/Instruct.cshtml.css +++ /dev/null @@ -1,4 +0,0 @@ -.section-content { - flex: 1; - overflow-y: scroll; -} diff --git a/LLama.Web/Pages/Executor/Interactive.cshtml b/LLama.Web/Pages/Executor/Interactive.cshtml deleted file mode 100644 index 916b59ca..00000000 --- a/LLama.Web/Pages/Executor/Interactive.cshtml +++ /dev/null @@ -1,96 +0,0 @@ -@page -@model InteractiveModel -@{ - -} -@Html.AntiForgeryToken() -
-    [Interactive.cshtml markup lost in extraction: an "Interactive" page header with a "Hub: Disconnected" status badge,
-    Model/Parameters/Prompt selectors, a scrollable chat output panel, and a prompt input with Send/Cancel/Clear controls.]
- -@{ await Html.RenderPartialAsync("_ChatTemplates");} - -@section Scripts { - - -} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Interactive.cshtml.cs b/LLama.Web/Pages/Executor/Interactive.cshtml.cs deleted file mode 100644 index 7179a440..00000000 --- a/LLama.Web/Pages/Executor/Interactive.cshtml.cs +++ /dev/null @@ -1,34 +0,0 @@ -using LLama.Web.Common; -using LLama.Web.Models; -using LLama.Web.Services; -using Microsoft.AspNetCore.Mvc; -using Microsoft.AspNetCore.Mvc.RazorPages; -using Microsoft.Extensions.Options; - -namespace LLama.Web.Pages -{ - public class InteractiveModel : PageModel - { - private readonly ILogger _logger; - private readonly ConnectionSessionService _modelSessionService; - - public InteractiveModel(ILogger logger, IOptions options, ConnectionSessionService modelSessionService) - { - _logger = logger; - Options = options.Value; - _modelSessionService = modelSessionService; - } - - public LLamaOptions Options { get; set; } - - public void OnGet() - { - } - - public async Task OnPostCancel(CancelModel model) - { - await _modelSessionService.CancelAsync(model.ConnectionId); - return new JsonResult(default); - } - } -} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Interactive.cshtml.css b/LLama.Web/Pages/Executor/Interactive.cshtml.css deleted file mode 100644 index ed9a1d59..00000000 --- a/LLama.Web/Pages/Executor/Interactive.cshtml.css +++ /dev/null @@ -1,4 +0,0 @@ -.section-content { - flex: 1; - overflow-y: scroll; -} diff --git a/LLama.Web/Pages/Executor/Stateless.cshtml b/LLama.Web/Pages/Executor/Stateless.cshtml deleted file mode 100644 index b5d8eea3..00000000 --- a/LLama.Web/Pages/Executor/Stateless.cshtml +++ /dev/null @@ -1,97 +0,0 @@ -@page -@model StatelessModel -@{ - -} -@Html.AntiForgeryToken() -
-    [Stateless.cshtml markup lost in extraction: a "Stateless" page header with a "Hub: Disconnected" status badge,
-    Model/Parameters/Prompt selectors, a scrollable chat output panel, and a prompt input with Send/Cancel/Clear controls.]
- -@{ await Html.RenderPartialAsync("_ChatTemplates"); } - - -@section Scripts { - - -} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Stateless.cshtml.cs b/LLama.Web/Pages/Executor/Stateless.cshtml.cs deleted file mode 100644 index f88c4b83..00000000 --- a/LLama.Web/Pages/Executor/Stateless.cshtml.cs +++ /dev/null @@ -1,34 +0,0 @@ -using LLama.Web.Common; -using LLama.Web.Models; -using LLama.Web.Services; -using Microsoft.AspNetCore.Mvc; -using Microsoft.AspNetCore.Mvc.RazorPages; -using Microsoft.Extensions.Options; - -namespace LLama.Web.Pages -{ - public class StatelessModel : PageModel - { - private readonly ILogger _logger; - private readonly ConnectionSessionService _modelSessionService; - - public StatelessModel(ILogger logger, IOptions options, ConnectionSessionService modelSessionService) - { - _logger = logger; - Options = options.Value; - _modelSessionService = modelSessionService; - } - - public LLamaOptions Options { get; set; } - - public void OnGet() - { - } - - public async Task OnPostCancel(CancelModel model) - { - await _modelSessionService.CancelAsync(model.ConnectionId); - return new JsonResult(default); - } - } -} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Stateless.cshtml.css b/LLama.Web/Pages/Executor/Stateless.cshtml.css deleted file mode 100644 index ed9a1d59..00000000 --- a/LLama.Web/Pages/Executor/Stateless.cshtml.css +++ /dev/null @@ -1,4 +0,0 @@ -.section-content { - flex: 1; - overflow-y: scroll; -} diff --git a/LLama.Web/Pages/Index.cshtml b/LLama.Web/Pages/Index.cshtml index b5f0c15f..55512603 100644 --- a/LLama.Web/Pages/Index.cshtml +++ b/LLama.Web/Pages/Index.cshtml @@ -1,10 +1,121 @@ @page +@using LLama.Web.Common; + @model IndexModel @{ - ViewData["Title"] = "Home page"; + ViewData["Title"] = "Inference Demo"; } -
-    [Old Index.cshtml markup removed: the default "Welcome" page with "Learn about building Web apps with ASP.NET Core."]
+    [New Index.cshtml markup lost in extraction: an antiforgery token, a "@ViewData["Title"]" header with a
+    "Socket: Disconnected" status badge, tabbed Prompt/Parameters panes, a scrollable chat output panel, a loading
+    spinner, and a prompt input with Load/Unload, Send/Cancel/Clear controls. Recovered selector bindings:]
+    Model:          @Html.DropDownListFor(m => m.SessionOptions.Model, new SelectList(Model.Options.Models, "Name", "Name"), new { @class = "form-control prompt-control", required = "required", autocomplete = "off" })
+    Inference Type: @Html.DropDownListFor(m => m.SessionOptions.ExecutorType, Html.GetEnumSelectList<LLamaExecutorType>(), new { @class = "form-control prompt-control", required = "required", autocomplete = "off" })
+
+@{
+    await Html.RenderPartialAsync("_ChatTemplates");
+}
+
+@section Scripts {
+    [script block lost in extraction]
+}
\ No newline at end of file
diff --git a/LLama.Web/Pages/Index.cshtml.cs b/LLama.Web/Pages/Index.cshtml.cs
index 477c9bfb..3647dfec 100644
--- a/LLama.Web/Pages/Index.cshtml.cs
+++ b/LLama.Web/Pages/Index.cshtml.cs
@@ -1,5 +1,7 @@
-using Microsoft.AspNetCore.Mvc;
+using LLama.Web.Common;
+using Microsoft.AspNetCore.Mvc;
 using Microsoft.AspNetCore.Mvc.RazorPages;
+using Microsoft.Extensions.Options;
 
 namespace LLama.Web.Pages
 {
@@ -7,14 +9,33 @@
     {
         private readonly ILogger<IndexModel> _logger;
 
-        public IndexModel(ILogger<IndexModel> logger)
+        public IndexModel(ILogger<IndexModel> logger, IOptions<LLamaOptions> options)
         {
             _logger = logger;
+            Options = options.Value;
         }
 
+        public LLamaOptions Options { get; set; }
+
+        [BindProperty]
+        public Common.SessionOptions SessionOptions { get; set; }
+
+        [BindProperty]
+        public InferenceOptions InferenceOptions { get; set; }
+
         public void OnGet()
         {
+            SessionOptions = new Common.SessionOptions
+            {
+                Prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.",
+                AntiPrompt = "User:",
+                // OutputFilter = "User:, Response:"
+            };
+            InferenceOptions = new InferenceOptions
+            {
+                Temperature = 0.8f
+            };
+        }
     }
 }
\ No newline at end of file
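For context, a sketch of how the bound options above are consumed; the CSV AntiPrompt field and the AntiPrompts list are merged by the Extensioms helpers (illustrative only, not part of the patch):

    var sessionOptions = new Common.SessionOptions
    {
        Model = "WizardLM-7B",
        ExecutorType = LLamaExecutorType.Interactive,
        AntiPrompt = "User:, Response:",                   // CSV form, as bound from the page inputs
        AntiPrompts = new List<string> { "User:" }         // list form
    };
    var antiPrompts = sessionOptions.GetAntiPrompts();     // ["User:", "Response:"], duplicates removed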
diff --git a/LLama.Web/Pages/Shared/_ChatTemplates.cshtml b/LLama.Web/Pages/Shared/_ChatTemplates.cshtml
index 15644012..cd768f1f 100644
--- a/LLama.Web/Pages/Shared/_ChatTemplates.cshtml
+++ b/LLama.Web/Pages/Shared/_ChatTemplates.cshtml
@@ -12,7 +12,7 @@
    [Template markup lost in extraction: the {{text}} and {{date}} output templates are retouched, and the
    session-details template block is removed in favour of a new signature template rendered at the end of
    each bot response.]
- \ No newline at end of file diff --git a/LLama.Web/Pages/Shared/_Layout.cshtml b/LLama.Web/Pages/Shared/_Layout.cshtml index 23132bfa..16d6ad52 100644 --- a/LLama.Web/Pages/Shared/_Layout.cshtml +++ b/LLama.Web/Pages/Shared/_Layout.cshtml @@ -3,7 +3,7 @@ - @ViewData["Title"] - LLama.Web + @ViewData["Title"] - LLamaSharp.Web @@ -13,24 +13,26 @@
-
- @RenderBody() -
+
+ @RenderBody() +
- © 2023 - LLama.Web + © 2023 - LLamaSharp.Web
diff --git a/LLama.Web/Pages/Shared/_Parameters.cshtml b/LLama.Web/Pages/Shared/_Parameters.cshtml new file mode 100644 index 00000000..d6e476c4 --- /dev/null +++ b/LLama.Web/Pages/Shared/_Parameters.cshtml @@ -0,0 +1,137 @@ +@page +@using LLama.Common; +@model LLama.Abstractions.IInferenceParams +} + +
+    [_Parameters.cshtml layout markup lost in extraction; recovered slider controls and ranges:]
+    MaxTokens:             @Html.TextBoxFor(m => m.MaxTokens, new { @type = "range", @class = "slider", min = "-1", max = "2048", step = "1" })
+    TokensKeep:            @Html.TextBoxFor(m => m.TokensKeep, new { @type = "range", @class = "slider", min = "0", max = "2048", step = "1" })
+    TopK:                  @Html.TextBoxFor(m => m.TopK, new { @type = "range", @class = "slider", min = "-1", max = "100", step = "1" })
+    TopP:                  @Html.TextBoxFor(m => m.TopP, new { @type = "range", @class = "slider", min = "0.0", max = "1.0", step = "0.01" })
+    TypicalP:              @Html.TextBoxFor(m => m.TypicalP, new { @type = "range", @class = "slider", min = "0.0", max = "1.0", step = "0.01" })
+    Temperature:           @Html.TextBoxFor(m => m.Temperature, new { @type = "range", @class = "slider", min = "0.0", max = "1.5", step = "0.01" })
+    RepeatPenalty:         @Html.TextBoxFor(m => m.RepeatPenalty, new { @type = "range", @class = "slider", min = "0.0", max = "2.0", step = "0.01" })
+    RepeatLastTokensCount: @Html.TextBoxFor(m => m.RepeatLastTokensCount, new { @type = "range", @class = "slider", min = "0", max = "2048", step = "1" })
+    FrequencyPenalty:      @Html.TextBoxFor(m => m.FrequencyPenalty, new { @type = "range", @class = "slider", min = "0.0", max = "1.0", step = "0.01" })
+    PresencePenalty:       @Html.TextBoxFor(m => m.PresencePenalty, new { @type = "range", @class = "slider", min = "0.0", max = "1.0", step = "0.01" })
+    TfsZ:                  @Html.TextBoxFor(m => m.TfsZ, new { @type = "range", @class = "slider", min = "0.0", max = "1.0", step = "0.01" })
+    Sampler Type:          @Html.DropDownListFor(m => m.Mirostat, Html.GetEnumSelectList<MirostatType>(), new { @class = "form-control form-select" })
+    MirostatTau:           @Html.TextBoxFor(m => m.MirostatTau, new { @type = "range", @class = "slider", min = "0.0", max = "10.0", step = "0.01" })
+    MirostatEta:           @Html.TextBoxFor(m => m.MirostatEta, new { @type = "range", @class = "slider", min = "0.0", max = "1.0", step = "0.01" })
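For context, an InferenceOptions instance matching the slider defaults above (illustrative only, not part of the patch):

    var inferenceOptions = new InferenceOptions
    {
        MaxTokens = -1,                     // slider range -1..2048
        TopK = 40,                          // -1..100
        TopP = 0.95f,                       // 0.0..1.0
        Temperature = 0.8f,                 // 0.0..1.5; the Index page default set in OnGet()
        RepeatPenalty = 1.1f,               // 0.0..2.0
        Mirostat = MirostatType.Disable     // "Sampler Type" dropdown
    };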
\ No newline at end of file diff --git a/LLama.Web/Program.cs b/LLama.Web/Program.cs index 6db653a1..7c4583d2 100644 --- a/LLama.Web/Program.cs +++ b/LLama.Web/Program.cs @@ -1,6 +1,7 @@ using LLama.Web.Common; using LLama.Web.Hubs; using LLama.Web.Services; +using Microsoft.Extensions.DependencyInjection; namespace LLama.Web { @@ -20,7 +21,9 @@ namespace LLama.Web .BindConfiguration(nameof(LLamaOptions)); // Services DI - builder.Services.AddSingleton(); + builder.Services.AddHostedService(); + builder.Services.AddSingleton(); + builder.Services.AddSingleton(); var app = builder.Build(); diff --git a/LLama.Web/Services/ConnectionSessionService.cs b/LLama.Web/Services/ConnectionSessionService.cs deleted file mode 100644 index 7dfcde39..00000000 --- a/LLama.Web/Services/ConnectionSessionService.cs +++ /dev/null @@ -1,94 +0,0 @@ -using LLama.Abstractions; -using LLama.Web.Common; -using LLama.Web.Models; -using Microsoft.Extensions.Options; -using System.Collections.Concurrent; -using System.Drawing; - -namespace LLama.Web.Services -{ - /// - /// Example Service for handling a model session for a websockets connection lifetime - /// Each websocket connection will create its own unique session and context allowing you to use multiple tabs to compare prompts etc - /// - public class ConnectionSessionService : IModelSessionService - { - private readonly LLamaOptions _options; - private readonly ILogger _logger; - private readonly ConcurrentDictionary _modelSessions; - - public ConnectionSessionService(ILogger logger, IOptions options) - { - _logger = logger; - _options = options.Value; - _modelSessions = new ConcurrentDictionary(); - } - - public Task GetAsync(string connectionId) - { - _modelSessions.TryGetValue(connectionId, out var modelSession); - return Task.FromResult(modelSession); - } - - public Task> CreateAsync(LLamaExecutorType executorType, string connectionId, string modelName, string promptName, string parameterName) - { - var modelOption = _options.Models.FirstOrDefault(x => x.Name == modelName); - if (modelOption is null) - return Task.FromResult(ServiceResult.FromError($"Model option '{modelName}' not found")); - - var promptOption = _options.Prompts.FirstOrDefault(x => x.Name == promptName); - if (promptOption is null) - return Task.FromResult(ServiceResult.FromError($"Prompt option '{promptName}' not found")); - - var parameterOption = _options.Parameters.FirstOrDefault(x => x.Name == parameterName); - if (parameterOption is null) - return Task.FromResult(ServiceResult.FromError($"Parameter option '{parameterName}' not found")); - - - //Max instance - var currentInstances = _modelSessions.Count(x => x.Value.ModelName == modelOption.Name); - if (modelOption.MaxInstances > -1 && currentInstances >= modelOption.MaxInstances) - return Task.FromResult(ServiceResult.FromError("Maximum model instances reached")); - - // Create model - var llamaModel = new LLamaContext(modelOption); - - // Create executor - ILLamaExecutor executor = executorType switch - { - LLamaExecutorType.Interactive => new InteractiveExecutor(llamaModel), - LLamaExecutorType.Instruct => new InstructExecutor(llamaModel), - LLamaExecutorType.Stateless => new StatelessExecutor(llamaModel), - _ => default - }; - - // Create session - var modelSession = new ModelSession(executor, modelOption, promptOption, parameterOption); - if (!_modelSessions.TryAdd(connectionId, modelSession)) - return Task.FromResult(ServiceResult.FromError("Failed to create model session")); - - return 
Task.FromResult(ServiceResult.FromValue(modelSession)); - } - - public Task RemoveAsync(string connectionId) - { - if (_modelSessions.TryRemove(connectionId, out var modelSession)) - { - modelSession.CancelInfer(); - modelSession.Dispose(); - return Task.FromResult(true); - } - return Task.FromResult(false); - } - - public Task CancelAsync(string connectionId) - { - if (_modelSessions.TryGetValue(connectionId, out var modelSession)) - { - modelSession.CancelInfer(); - return Task.FromResult(true); - } - return Task.FromResult(false); - } - } -} diff --git a/LLama.Web/Services/IModelService.cs b/LLama.Web/Services/IModelService.cs index 0a98f8f4..ec9e4233 100644 --- a/LLama.Web/Services/IModelService.cs +++ b/LLama.Web/Services/IModelService.cs @@ -1,4 +1,5 @@ using LLama.Web.Common; +using LLama.Web.Models; namespace LLama.Web.Services { diff --git a/LLama.Web/Services/IModelSessionService.cs b/LLama.Web/Services/IModelSessionService.cs index 4ee0d483..8723d795 100644 --- a/LLama.Web/Services/IModelSessionService.cs +++ b/LLama.Web/Services/IModelSessionService.cs @@ -1,16 +1,88 @@ -using LLama.Abstractions; -using LLama.Web.Common; +using LLama.Web.Common; using LLama.Web.Models; namespace LLama.Web.Services { public interface IModelSessionService { + /// + /// Gets the ModelSession with the specified Id. + /// + /// The session identifier. + /// The ModelSession if exists, otherwise null Task GetAsync(string sessionId); - Task> CreateAsync(LLamaExecutorType executorType, string sessionId, string modelName, string promptName, string parameterName); - Task RemoveAsync(string sessionId); + + + /// + /// Gets all ModelSessions + /// + /// A collection oa all Model instances + Task> GetAllAsync(); + + + /// + /// Creates a new ModelSession + /// + /// The session identifier. + /// The session configuration. + /// The default inference configuration, will be used for all inference where no infer configuration is supplied. + /// The cancellation token. + /// + /// + /// Session with id {sessionId} already exists + /// or + /// Failed to create model session + /// + Task CreateAsync(string sessionId, Common.SessionOptions sessionOptions, InferenceOptions inferenceOptions = null, CancellationToken cancellationToken = default); + + + /// + /// Closes the session + /// + /// The session identifier. + /// + Task CloseAsync(string sessionId); + + + /// + /// Runs inference on the current ModelSession + /// + /// The session identifier. + /// The prompt. + /// The inference configuration, if null session default is used + /// The cancellation token. + /// Inference is already running for this session + IAsyncEnumerable InferAsync(string sessionId, string prompt, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default); + + /// + /// Runs inference on the current ModelSession + /// + /// The session identifier. + /// The prompt. + /// The inference configuration, if null session default is used + /// The cancellation token. + /// Streaming async result of + /// Inference is already running for this session + IAsyncEnumerable InferTextAsync(string sessionId, string prompt, InferenceOptions inferenceOptions = null, CancellationToken cancellationToken = default); + + + /// + /// Queues inference on the current ModelSession + /// + /// The session identifier. + /// The prompt. + /// The inference configuration, if null session default is used + /// The cancellation token. 
+ /// Completed inference result as string + /// Inference is already running for this session + Task InferTextCompleteAsync(string sessionId, string prompt, InferenceOptions inferenceOptions = null, CancellationToken cancellationToken = default); + + + /// + /// Cancels the current inference action. + /// + /// The session identifier. + /// Task CancelAsync(string sessionId); } - - } diff --git a/LLama.Web/Services/ModelLoaderService.cs b/LLama.Web/Services/ModelLoaderService.cs new file mode 100644 index 00000000..7545885d --- /dev/null +++ b/LLama.Web/Services/ModelLoaderService.cs @@ -0,0 +1,42 @@ +namespace LLama.Web.Services +{ + + /// + /// Service for managing loading/preloading of models at app startup + /// + /// Type used to identify contexts + /// + public class ModelLoaderService : IHostedService + { + private readonly IModelService _modelService; + + /// + /// Initializes a new instance of the class. + /// + /// The model service. + public ModelLoaderService(IModelService modelService) + { + _modelService = modelService; + } + + + /// + /// Triggered when the application host is ready to start the service. + /// + /// Indicates that the start process has been aborted. + public async Task StartAsync(CancellationToken cancellationToken) + { + await _modelService.LoadModels(); + } + + + /// + /// Triggered when the application host is performing a graceful shutdown. + /// + /// Indicates that the shutdown process should no longer be graceful. + public async Task StopAsync(CancellationToken cancellationToken) + { + await _modelService.UnloadModels(); + } + } +} diff --git a/LLama.Web/Services/ModelService.cs b/LLama.Web/Services/ModelService.cs index 16365a5d..2a3d4788 100644 --- a/LLama.Web/Services/ModelService.cs +++ b/LLama.Web/Services/ModelService.cs @@ -1,5 +1,6 @@ using LLama.Web.Async; using LLama.Web.Common; +using LLama.Web.Models; using System.Collections.Concurrent; namespace LLama.Web.Services diff --git a/LLama.Web/Services/ModelSessionService.cs b/LLama.Web/Services/ModelSessionService.cs new file mode 100644 index 00000000..e808e630 --- /dev/null +++ b/LLama.Web/Services/ModelSessionService.cs @@ -0,0 +1,216 @@ +using LLama.Web.Async; +using LLama.Web.Common; +using LLama.Web.Models; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Runtime.CompilerServices; + +namespace LLama.Web.Services +{ + /// + /// Example Service for handling a model session for a websockets connection lifetime + /// Each websocket connection will create its own unique session and context allowing you to use multiple tabs to compare prompts etc + /// + public class ModelSessionService : IModelSessionService + { + private readonly AsyncGuard _sessionGuard; + private readonly IModelService _modelService; + private readonly ConcurrentDictionary _modelSessions; + + + /// + /// Initializes a new instance of the class. + /// + /// The model service. + /// The model session state service. + public ModelSessionService(IModelService modelService) + { + _modelService = modelService; + _sessionGuard = new AsyncGuard(); + _modelSessions = new ConcurrentDictionary(); + } + + + /// + /// Gets the ModelSession with the specified Id. + /// + /// The session identifier. + /// The ModelSession if exists, otherwise null + public Task GetAsync(string sessionId) + { + return Task.FromResult(_modelSessions.TryGetValue(sessionId, out var session) ? 
session : null); + } + + + /// + /// Gets all ModelSessions + /// + /// A collection oa all Model instances + public Task> GetAllAsync() + { + return Task.FromResult>(_modelSessions.Values); + } + + + /// + /// Creates a new ModelSession + /// + /// The session identifier. + /// The session configuration. + /// The default inference configuration, will be used for all inference where no infer configuration is supplied. + /// The cancellation token. + /// + /// + /// Session with id {sessionId} already exists + /// or + /// Failed to create model session + /// + public async Task CreateAsync(string sessionId, Common.SessionOptions sessionConfig, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default) + { + if (_modelSessions.TryGetValue(sessionId, out _)) + throw new Exception($"Session with id {sessionId} already exists"); + + // Create context + var (model, context) = await _modelService.GetOrCreateModelAndContext(sessionConfig.Model, sessionId); + + // Create session + var modelSession = new ModelSession(model, context, sessionId, sessionConfig, inferenceConfig); + if (!_modelSessions.TryAdd(sessionId, modelSession)) + throw new Exception($"Failed to create model session"); + + // Run initial Prompt + await modelSession.InitializePrompt(inferenceConfig, cancellationToken); + return modelSession; + + } + + + /// + /// Closes the session + /// + /// The session identifier. + /// + public async Task CloseAsync(string sessionId) + { + if (_modelSessions.TryRemove(sessionId, out var modelSession)) + { + modelSession.CancelInfer(); + return await _modelService.RemoveContext(modelSession.ModelName, sessionId); + } + return false; + } + + + /// + /// Runs inference on the current ModelSession + /// + /// The session identifier. + /// The prompt. + /// The inference configuration, if null session default is used + /// The cancellation token. + /// Inference is already running for this session + public async IAsyncEnumerable InferAsync(string sessionId, string prompt, InferenceOptions inferenceConfig = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + if (!_sessionGuard.Guard(sessionId)) + throw new Exception($"Inference is already running for this session"); + + try + { + if (!_modelSessions.TryGetValue(sessionId, out var modelSession)) + yield break; + + // Send begin of response + var stopwatch = Stopwatch.GetTimestamp(); + yield return new TokenModel(default, default, TokenType.Begin); + + // Send content of response + await foreach (var token in modelSession.InferAsync(prompt, inferenceConfig, cancellationToken).ConfigureAwait(false)) + { + yield return new TokenModel(default, token); + } + + // Send end of response + var elapsedTime = GetElapsed(stopwatch); + var endTokenType = modelSession.IsInferCanceled() ? TokenType.Cancel : TokenType.End; + var signature = endTokenType == TokenType.Cancel + ? $"Inference cancelled after {elapsedTime / 1000:F0} seconds" + : $"Inference completed in {elapsedTime / 1000:F0} seconds"; + yield return new TokenModel(default, signature, endTokenType); + } + finally + { + _sessionGuard.Release(sessionId); + } + } + + + /// + /// Runs inference on the current ModelSession + /// + /// The session identifier. + /// The prompt. + /// The inference configuration, if null session default is used + /// The cancellation token. 
+ /// Streaming async result of + /// Inference is already running for this session + public IAsyncEnumerable InferTextAsync(string sessionId, string prompt, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default) + { + async IAsyncEnumerable InferTextInternal() + { + await foreach (var token in InferAsync(sessionId, prompt, inferenceConfig, cancellationToken).ConfigureAwait(false)) + { + if (token.TokenType == TokenType.Content) + yield return token.Content; + } + } + return InferTextInternal(); + } + + + /// + /// Runs inference on the current ModelSession + /// + /// The session identifier. + /// The prompt. + /// The inference configuration, if null session default is used + /// The cancellation token. + /// Completed inference result as string + /// Inference is already running for this session + public async Task InferTextCompleteAsync(string sessionId, string prompt, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default) + { + var inferResult = await InferAsync(sessionId, prompt, inferenceConfig, cancellationToken) + .Where(x => x.TokenType == TokenType.Content) + .Select(x => x.Content) + .ToListAsync(cancellationToken: cancellationToken); + + return string.Concat(inferResult); + } + + + /// + /// Cancels the current inference action. + /// + /// The session identifier. + /// + public Task CancelAsync(string sessionId) + { + if (_modelSessions.TryGetValue(sessionId, out var modelSession)) + { + modelSession.CancelInfer(); + return Task.FromResult(true); + } + return Task.FromResult(false); + } + + + /// + /// Gets the elapsed time in milliseconds. + /// + /// The timestamp. + /// + private static int GetElapsed(long timestamp) + { + return (int)Stopwatch.GetElapsedTime(timestamp).TotalMilliseconds; + } + } +} diff --git a/LLama.Web/appsettings.json b/LLama.Web/appsettings.json index 9f340a9c..6231b882 100644 --- a/LLama.Web/appsettings.json +++ b/LLama.Web/appsettings.json @@ -7,48 +7,34 @@ }, "AllowedHosts": "*", "LLamaOptions": { + "ModelLoadType": "Single", "Models": [ { "Name": "WizardLM-7B", - "MaxInstances": 2, + "MaxInstances": 20, "ModelPath": "D:\\Repositories\\AI\\Models\\wizardLM-7B.ggmlv3.q4_0.bin", - "ContextSize": 2048 - } - ], - "Parameters": [ - { - "Name": "Default", - "Temperature": 0.6 - } - ], - "Prompts": [ - { - "Name": "None", - "Prompt": "" - }, - { - "Name": "Alpaca", - "Path": "D:\\Repositories\\AI\\Prompts\\alpaca.txt", - "AntiPrompt": [ - "User:" - ], - "OutputFilter": [ - "Response:", - "User:" - ] - }, - { - "Name": "ChatWithBob", - "Path": "D:\\Repositories\\AI\\Prompts\\chat-with-bob.txt", - "AntiPrompt": [ - "User:" - ], - "OutputFilter": [ - "Bob:", - "User:" - ] + "ContextSize": 2048, + "BatchSize": 2048, + "Threads": 4, + "GpuLayerCount": 6, + "UseMemorymap": true, + "UseMemoryLock": false, + "MainGpu": 0, + "LowVram": false, + "Seed": 1686349486, + "UseFp16Memory": true, + "Perplexity": false, + "LoraAdapter": "", + "LoraBase": "", + "EmbeddingMode": false, + "TensorSplits": null, + "GroupedQueryAttention": 1, + "RmsNormEpsilon": 0.000005, + "RopeFrequencyBase": 10000.0, + "RopeFrequencyScale": 1.0, + "MulMatQ": false, + "Encoding": "UTF-8" } ] - } } diff --git a/LLama.Web/wwwroot/css/site.css b/LLama.Web/wwwroot/css/site.css index d10ef975..14685f45 100644 --- a/LLama.Web/wwwroot/css/site.css +++ b/LLama.Web/wwwroot/css/site.css @@ -22,13 +22,30 @@ footer { @media (min-width: 768px) { - html { - font-size: 16px; - } + html { + font-size: 16px; + } } .btn:focus, .btn:active:focus, 
.btn-link.nav-link:focus, .form-control:focus, .form-check-input:focus { - box-shadow: 0 0 0 0.1rem white, 0 0 0 0.25rem #258cfb; + box-shadow: 0 0 0 0.1rem white, 0 0 0 0.25rem #258cfb; +} + +#scroll-container { + flex: 1; + overflow-y: scroll; +} + +#output-container .content { + white-space: break-spaces; } +.slider-container > .slider { + width: 100%; +} + +.slider-container > label { + width: 50px; + text-align: center; +} diff --git a/LLama.Web/wwwroot/js/sessionConnectionChat.js b/LLama.Web/wwwroot/js/sessionConnectionChat.js index 472b5971..719c44ac 100644 --- a/LLama.Web/wwwroot/js/sessionConnectionChat.js +++ b/LLama.Web/wwwroot/js/sessionConnectionChat.js @@ -1,26 +1,26 @@ -const createConnectionSessionChat = (LLamaExecutorType) => { +const createConnectionSessionChat = () => { const outputErrorTemplate = $("#outputErrorTemplate").html(); const outputInfoTemplate = $("#outputInfoTemplate").html(); const outputUserTemplate = $("#outputUserTemplate").html(); const outputBotTemplate = $("#outputBotTemplate").html(); - const sessionDetailsTemplate = $("#sessionDetailsTemplate").html(); + const signatureTemplate = $("#signatureTemplate").html(); - let connectionId; + let inferenceSession; const connection = new signalR.HubConnectionBuilder().withUrl("/SessionConnectionHub").build(); const scrollContainer = $("#scroll-container"); const outputContainer = $("#output-container"); const chatInput = $("#input"); - const onStatus = (connection, status) => { - connectionId = connection; if (status == Enums.SessionConnectionStatus.Connected) { $("#socket").text("Connected").addClass("text-success"); } else if (status == Enums.SessionConnectionStatus.Loaded) { + loaderHide(); enableControls(); - $("#session-details").html(Mustache.render(sessionDetailsTemplate, { model: getSelectedModel(), prompt: getSelectedPrompt(), parameter: getSelectedParameter() })); + $("#load").hide(); + $("#unload").show(); onInfo(`New model session successfully started`) } } @@ -36,30 +36,31 @@ const createConnectionSessionChat = (LLamaExecutorType) => { let responseContent; let responseContainer; - let responseFirstFragment; + let responseFirstToken; const onResponse = (response) => { if (!response) return; - if (response.isFirst) { - outputContainer.append(Mustache.render(outputBotTemplate, response)); - responseContainer = $(`#${response.id}`); + if (response.tokenType == Enums.TokenType.Begin) { + const uniqueId = randomString(); + outputContainer.append(Mustache.render(outputBotTemplate, { id: uniqueId, ...response })); + responseContainer = $(`#${uniqueId}`); responseContent = responseContainer.find(".content"); - responseFirstFragment = true; + responseFirstToken = true; scrollToBottom(true); return; } - if (response.isLast) { + if (response.tokenType == Enums.TokenType.End || response.tokenType == Enums.TokenType.Cancel) { enableControls(); - responseContainer.find(".signature").append(response.content); + responseContainer.find(".signature").append(Mustache.render(signatureTemplate, response)); scrollToBottom(); } else { - if (responseFirstFragment) { + if (responseFirstToken) { responseContent.empty(); - responseFirstFragment = false; + responseFirstToken = false; responseContainer.find(".date").append(getDateTime()); } responseContent.append(response.content); @@ -67,45 +68,88 @@ const createConnectionSessionChat = (LLamaExecutorType) => { } } - const sendPrompt = async () => { const text = chatInput.val(); if (text) { + chatInput.val(null); disableControls(); 
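                // connection.stream(...) below returns a SignalR IStreamResult; subscribe() feeds each
                // streamed TokenModel into onResponse and returns an ISubscription whose dispose()
                // aborts the server-side stream; that subscription is what cancelPrompt() disposes.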
outputContainer.append(Mustache.render(outputUserTemplate, { text: text, date: getDateTime() })); - await connection.invoke('SendPrompt', text); - chatInput.val(null); + inferenceSession = await connection + .stream("SendPrompt", text, serializeFormToJson('SessionParameters')) + .subscribe({ + next: onResponse, + complete: onResponse, + error: onError, + }); scrollToBottom(true); } } const cancelPrompt = async () => { - await ajaxPostJsonAsync('?handler=Cancel', { connectionId: connectionId }); + if (inferenceSession) + inferenceSession.dispose(); } const loadModel = async () => { - const modelName = getSelectedModel(); - const promptName = getSelectedPrompt(); - const parameterName = getSelectedParameter(); - if (!modelName || !promptName || !parameterName) { - onError("Please select a valid Model, Parameter and Prompt"); - return; - } - + const sessionParams = serializeFormToJson('SessionParameters'); + loaderShow(); disableControls(); - await connection.invoke('LoadModel', LLamaExecutorType, modelName, promptName, parameterName); + disablePromptControls(); + $("#load").attr("disabled", "disabled"); + + // TODO: Split parameters sets + await connection.invoke('LoadModel', sessionParams, sessionParams); } + const unloadModel = async () => { + disableControls(); + enablePromptControls(); + $("#load").removeAttr("disabled"); + } + + const serializeFormToJson = (form) => { + const formDataJson = {}; + const formData = new FormData(document.getElementById(form)); + formData.forEach((value, key) => { + + if (key.includes(".")) + key = key.split(".")[1]; + + // Convert number strings to numbers + if (!isNaN(value) && value.trim() !== "") { + formDataJson[key] = parseFloat(value); + } + // Convert boolean strings to booleans + else if (value === "true" || value === "false") { + formDataJson[key] = (value === "true"); + } + else { + formDataJson[key] = value; + } + }); + return formDataJson; + } const enableControls = () => { $(".input-control").removeAttr("disabled"); } - const disableControls = () => { $(".input-control").attr("disabled", "disabled"); } + const enablePromptControls = () => { + $("#load").show(); + $("#unload").hide(); + $(".prompt-control").removeAttr("disabled"); + activatePromptTab(); + } + + const disablePromptControls = () => { + $(".prompt-control").attr("disabled", "disabled"); + activateParamsTab(); + } + const clearOutput = () => { outputContainer.empty(); } @@ -117,27 +161,14 @@ const createConnectionSessionChat = (LLamaExecutorType) => { customPrompt.text(selectedValue); } - - const getSelectedModel = () => { - return $("option:selected", "#Model").val(); - } - - - const getSelectedParameter = () => { - return $("option:selected", "#Parameter").val(); - } - - - const getSelectedPrompt = () => { - return $("option:selected", "#Prompt").val(); - } - - const getDateTime = () => { const dateTime = new Date(); return dateTime.toLocaleString(); } + const randomString = () => { + return Math.random().toString(36).slice(2); + } const scrollToBottom = (force) => { const scrollTop = scrollContainer.scrollTop(); @@ -151,10 +182,25 @@ const createConnectionSessionChat = (LLamaExecutorType) => { } } + const activatePromptTab = () => { + $("#nav-prompt-tab").trigger("click"); + } + const activateParamsTab = () => { + $("#nav-params-tab").trigger("click"); + } + + const loaderShow = () => { + $(".spinner").show(); + } + + const loaderHide = () => { + $(".spinner").hide(); + } // Map UI functions $("#load").on("click", loadModel); + $("#unload").on("click", unloadModel); 
$("#send").on("click", sendPrompt); $("#clear").on("click", clearOutput); $("#cancel").on("click", cancelPrompt); @@ -165,7 +211,10 @@ const createConnectionSessionChat = (LLamaExecutorType) => { sendPrompt(); } }); - + $(".slider").on("input", function (e) { + const slider = $(this); + slider.next().text(slider.val()); + }).trigger("input"); // Map signalr functions diff --git a/LLama.Web/wwwroot/js/site.js b/LLama.Web/wwwroot/js/site.js index 2f679669..6612c772 100644 --- a/LLama.Web/wwwroot/js/site.js +++ b/LLama.Web/wwwroot/js/site.js @@ -40,11 +40,17 @@ const Enums = { Loaded: 4, Connected: 10 }), - LLamaExecutorType: Object.freeze({ + ExecutorType: Object.freeze({ Interactive: 0, Instruct: 1, Stateless: 2 }), + TokenType: Object.freeze({ + Begin: 0, + Content: 2, + End: 4, + Cancel: 10 + }), GetName: (enumType, enumKey) => { return Object.keys(enumType)[enumKey] },