From bac9cba01a9a2bb065deefd534ec0f2c599e3bfd Mon Sep 17 00:00:00 2001
From: sa_ddam213
Date: Sun, 6 Aug 2023 11:03:45 +1200
Subject: [PATCH] InferenceParams abstractions

---
 LLama.Web/Common/ParameterOptions.cs   |  96 +++++++++++++++++++-
 LLama/Abstractions/IInferenceParams.cs | 117 +++++++++++++++++++++++++
 LLama/Abstractions/ILLamaExecutor.cs   |   4 +-
 LLama/ChatSession.cs                   |  12 +--
 LLama/Common/InferenceParams.cs        |   5 +-
 LLama/LLamaExecutorBase.cs             |   8 +-
 LLama/LLamaInstructExecutor.cs         |   7 +-
 LLama/LLamaInteractExecutor.cs         |   5 +-
 LLama/LLamaStatelessExecutor.cs        |   4 +-
 9 files changed, 234 insertions(+), 24 deletions(-)
 create mode 100644 LLama/Abstractions/IInferenceParams.cs

diff --git a/LLama.Web/Common/ParameterOptions.cs b/LLama.Web/Common/ParameterOptions.cs
index 3cdd3701..7677f04a 100644
--- a/LLama.Web/Common/ParameterOptions.cs
+++ b/LLama.Web/Common/ParameterOptions.cs
@@ -1,9 +1,99 @@
 using LLama.Common;
+using LLama.Abstractions;
 
 namespace LLama.Web.Common
 {
-    public class ParameterOptions : InferenceParams
-    {
+    public class ParameterOptions : IInferenceParams
+    {
         public string Name { get; set; }
-    }
+
+
+
+        /// <summary>
+        /// number of tokens to keep from initial prompt
+        /// </summary>
+        public int TokensKeep { get; set; } = 0;
+        /// <summary>
+        /// how many new tokens to predict (n_predict); set to -1 to infinitely generate a response
+        /// until it completes.
+        /// </summary>
+        public int MaxTokens { get; set; } = -1;
+        /// <summary>
+        /// logit bias for specific tokens
+        /// </summary>
+        public Dictionary<int, float>? LogitBias { get; set; } = null;
+
+        /// <summary>
+        /// Sequences where the model will stop generating further tokens.
+        /// </summary>
+        public IEnumerable<string> AntiPrompts { get; set; } = Array.Empty<string>();
+        /// <summary>
+        /// path to file for saving/loading model eval state
+        /// </summary>
+        public string PathSession { get; set; } = string.Empty;
+        /// <summary>
+        /// string to suffix user inputs with
+        /// </summary>
+        public string InputSuffix { get; set; } = string.Empty;
+        /// <summary>
+        /// string to prefix user inputs with
+        /// </summary>
+        public string InputPrefix { get; set; } = string.Empty;
+        /// <summary>
+        /// 0 or lower to use vocab size
+        /// </summary>
+        public int TopK { get; set; } = 40;
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float TopP { get; set; } = 0.95f;
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float TfsZ { get; set; } = 1.0f;
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float TypicalP { get; set; } = 1.0f;
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float Temperature { get; set; } = 0.8f;
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float RepeatPenalty { get; set; } = 1.1f;
+        /// <summary>
+        /// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n)
+        /// </summary>
+        public int RepeatLastTokensCount { get; set; } = 64;
+        /// <summary>
+        /// frequency penalty coefficient
+        /// 0.0 = disabled
+        /// </summary>
+        public float FrequencyPenalty { get; set; } = .0f;
+        /// <summary>
+        /// presence penalty coefficient
+        /// 0.0 = disabled
+        /// </summary>
+        public float PresencePenalty { get; set; } = .0f;
+        /// <summary>
+        /// Mirostat uses tokens instead of words.
+        /// algorithm described in the paper https://arxiv.org/abs/2007.14966.
+        /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+        /// </summary>
+        public MirostatType Mirostat { get; set; } = MirostatType.Disable;
+        /// <summary>
+        /// target entropy
+        /// </summary>
+        public float MirostatTau { get; set; } = 5.0f;
+        /// <summary>
+        /// learning rate
+        /// </summary>
+        public float MirostatEta { get; set; } = 0.1f;
+        /// <summary>
+        /// consider newlines as a repeatable token (penalize_nl)
+        /// </summary>
+        public bool PenalizeNL { get; set; } = true;
+    }
 }
diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs
new file mode 100644
index 00000000..73cbbfd2
--- /dev/null
+++ b/LLama/Abstractions/IInferenceParams.cs
@@ -0,0 +1,117 @@
+using System.Collections.Generic;
+using LLama.Common;
+
+namespace LLama.Abstractions
+{
+    /// <summary>
+    /// The parameters used for inference.
+    /// </summary>
+    public interface IInferenceParams
+    {
+        /// <summary>
+        /// number of tokens to keep from initial prompt
+        /// </summary>
+        public int TokensKeep { get; set; }
+
+        /// <summary>
+        /// how many new tokens to predict (n_predict); set to -1 to infinitely generate a response
+        /// until it completes.
+        /// </summary>
+        public int MaxTokens { get; set; }
+
+        /// <summary>
+        /// logit bias for specific tokens
+        /// </summary>
+        public Dictionary<int, float>? LogitBias { get; set; }
+
+
+        /// <summary>
+        /// Sequences where the model will stop generating further tokens.
+        /// </summary>
+        public IEnumerable<string> AntiPrompts { get; set; }
+
+        /// <summary>
+        /// path to file for saving/loading model eval state
+        /// </summary>
+        public string PathSession { get; set; }
+
+        /// <summary>
+        /// string to suffix user inputs with
+        /// </summary>
+        public string InputSuffix { get; set; }
+
+        /// <summary>
+        /// string to prefix user inputs with
+        /// </summary>
+        public string InputPrefix { get; set; }
+
+        /// <summary>
+        /// 0 or lower to use vocab size
+        /// </summary>
+        public int TopK { get; set; }
+
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float TopP { get; set; }
+
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float TfsZ { get; set; }
+
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float TypicalP { get; set; }
+
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float Temperature { get; set; }
+
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float RepeatPenalty { get; set; }
+
+        /// <summary>
+        /// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n)
+        /// </summary>
+        public int RepeatLastTokensCount { get; set; }
+
+        /// <summary>
+        /// frequency penalty coefficient
+        /// 0.0 = disabled
+        /// </summary>
+        public float FrequencyPenalty { get; set; }
+
+        /// <summary>
+        /// presence penalty coefficient
+        /// 0.0 = disabled
+        /// </summary>
+        public float PresencePenalty { get; set; }
+
+        /// <summary>
+        /// Mirostat uses tokens instead of words.
+        /// algorithm described in the paper https://arxiv.org/abs/2007.14966.
+        /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+        /// </summary>
+        public MirostatType Mirostat { get; set; }
+
+        /// <summary>
+        /// target entropy
+        /// </summary>
+        public float MirostatTau { get; set; }
+
+        /// <summary>
+        /// learning rate
+        /// </summary>
+        public float MirostatEta { get; set; }
+
+        /// <summary>
+        /// consider newlines as a repeatable token (penalize_nl)
+        /// </summary>
+        public bool PenalizeNL { get; set; }
+    }
+}
\ No newline at end of file
diff --git a/LLama/Abstractions/ILLamaExecutor.cs b/LLama/Abstractions/ILLamaExecutor.cs
index d35e075e..6a750895 100644
--- a/LLama/Abstractions/ILLamaExecutor.cs
+++ b/LLama/Abstractions/ILLamaExecutor.cs
@@ -23,7 +23,7 @@ namespace LLama.Abstractions
         /// <param name="inferenceParams">Any additional parameters</param>
         /// <param name="token">A cancellation token.</param>
         /// <returns></returns>
-        IEnumerable<string> Infer(string text, InferenceParams? inferenceParams = null, CancellationToken token = default);
+        IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
         /// <summary>
         /// Asynchronously infers a response from the model.
         /// </summary>
@@ -32,6 +32,6 @@ namespace LLama.Abstractions
         /// <param name="inferenceParams">Any additional parameters</param>
         /// <param name="token">A cancellation token.</param>
         /// <returns></returns>
-        IAsyncEnumerable<string> InferAsync(string text, InferenceParams? inferenceParams = null, CancellationToken token = default);
+        IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
     }
 }
diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index b87e8984..4a4544b0 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -138,7 +138,7 @@ namespace LLama
         /// <param name="inferenceParams"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public IEnumerable<string> Chat(ChatHistory history, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+        public IEnumerable<string> Chat(ChatHistory history, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
         {
             var prompt = HistoryTransform.HistoryToText(history);
             History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
@@ -159,7 +159,7 @@ namespace LLama
         /// <param name="inferenceParams"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public IEnumerable<string> Chat(string prompt, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+        public IEnumerable<string> Chat(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
         {
             foreach(var inputTransform in InputTransformPipeline)
             {
@@ -182,7 +182,7 @@ namespace LLama
         /// <param name="inferenceParams"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             var prompt = HistoryTransform.HistoryToText(history);
             History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
@@ -202,7 +202,7 @@ namespace LLama
         /// <param name="inferenceParams"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public async IAsyncEnumerable<string> ChatAsync(string prompt, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             foreach (var inputTransform in InputTransformPipeline)
             {
@@ -218,13 +218,13 @@ namespace LLama
             History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
         }
 
-        private IEnumerable<string> ChatInternal(string prompt, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+        private IEnumerable<string> ChatInternal(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
         {
             var results = _executor.Infer(prompt, inferenceParams, cancellationToken);
             return OutputTransform.Transform(results);
         }
 
-        private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
             await foreach (var item in OutputTransform.TransformAsync(results))
diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs
index 77af7eaf..001a8f8e 100644
--- a/LLama/Common/InferenceParams.cs
+++ b/LLama/Common/InferenceParams.cs
@@ -1,4 +1,5 @@
-using System;
+using LLama.Abstractions;
+using System;
 using System.Collections.Generic;
 
 namespace LLama.Common
@@ -7,7 +8,7 @@ namespace LLama.Common
     /// <summary>
     /// The parameters used for inference.
     /// </summary>
-    public class InferenceParams
+    public class InferenceParams : IInferenceParams
     {
         /// <summary>
         /// number of tokens to keep from initial prompt
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index afbc0f25..dbd1b593 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -231,13 +231,13 @@ namespace LLama
         /// <param name="inferenceParams"></param>
         /// <param name="args"></param>
         /// <returns></returns>
-        protected abstract bool PostProcess(InferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs);
+        protected abstract bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs);
         /// <summary>
         /// The core inference logic.
         /// </summary>
         /// <param name="inferenceParams"></param>
         /// <param name="args"></param>
-        protected abstract void InferInternal(InferenceParams inferenceParams, InferStateArgs args);
+        protected abstract void InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
         /// <summary>
         /// Save the current state to a file.
         /// </summary>
@@ -267,7 +267,7 @@ namespace LLama
         /// <param name="inferenceParams"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public virtual IEnumerable<string> Infer(string text, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+        public virtual IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
         {
             cancellationToken.ThrowIfCancellationRequested();
             if (inferenceParams is null)
@@ -324,7 +324,7 @@ namespace LLama
         /// <param name="inferenceParams"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public virtual async IAsyncEnumerable<string> InferAsync(string text, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             foreach (var result in Infer(text, inferenceParams, cancellationToken))
             {
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 89fbac59..e055c147 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -1,4 +1,5 @@
-using LLama.Common;
+using LLama.Abstractions;
+using LLama.Common;
 using LLama.Native;
 using System;
 using System.Collections.Generic;
@@ -136,7 +137,7 @@ namespace LLama
             }
         }
         /// <inheritdoc />
-        protected override bool PostProcess(InferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
+        protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
         {
             extraOutputs = null;
             if (_embed_inps.Count <= _consumedTokensCount)
@@ -179,7 +180,7 @@ namespace LLama
             return false;
         }
         /// <inheritdoc />
-        protected override void InferInternal(InferenceParams inferenceParams, InferStateArgs args)
+        protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
         {
             if (_embeds.Count > 0)
             {
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index bc3a242e..f5c1583e 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -1,5 +1,6 @@
 using LLama.Common;
 using LLama.Native;
+using LLama.Abstractions;
 using System;
 using System.Collections.Generic;
 using System.IO;
@@ -122,7 +123,7 @@ namespace LLama
         /// <param name="inferenceParams"></param>
         /// <param name="args"></param>
         /// <returns></returns>
-        protected override bool PostProcess(InferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
+        protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
         {
             extraOutputs = null;
             if (_embed_inps.Count <= _consumedTokensCount)
@@ -166,7 +167,7 @@ namespace LLama
         }
 
         /// <inheritdoc />
-        protected override void InferInternal(InferenceParams inferenceParams, InferStateArgs args)
+        protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
         {
             if (_embeds.Count > 0)
             {
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 88fa1695..dd0497c9 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -36,7 +36,7 @@ namespace LLama
         }
 
         /// <inheritdoc />
-        public IEnumerable<string> Infer(string text, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+        public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
         {
             cancellationToken.ThrowIfCancellationRequested();
             int n_past = 1;
@@ -123,7 +123,7 @@ namespace LLama
         }
 
         /// <inheritdoc />
-        public async IAsyncEnumerable<string> InferAsync(string text, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             foreach (var result in Infer(text, inferenceParams, cancellationToken))
             {
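
Usage sketch: with this patch, every executor entry point accepts the IInferenceParams
interface instead of the concrete InferenceParams class, so callers can pass their own
option types directly, as LLama.Web's ParameterOptions does above. A minimal example of
what that enables; the ApiOptions type and the Run helper below are illustrative
assumptions, not code from this diff:

    using System.Collections.Generic;
    using LLama.Abstractions;
    using LLama.Common;

    // Caller-owned options type (hypothetical). InferenceParams now implements
    // IInferenceParams, so subclassing it is the shortest route; implementing
    // the interface directly also works.
    class ApiOptions : InferenceParams
    {
        public string Name { get; set; } = string.Empty;
    }

    static class Demo
    {
        // Works against any ILLamaExecutor (instruct, interactive, or stateless).
        static IEnumerable<string> Run(ILLamaExecutor executor)
        {
            var options = new ApiOptions { Name = "chat", Temperature = 0.6f, MaxTokens = 128 };
            return executor.Infer("Hello!", options);
        }
    }

Passing null still works: as the "if (inferenceParams is null)" context in
LLamaExecutorBase suggests, executors fall back to default parameters, so existing
call sites are unaffected.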