From b34f72a883a8851cafb6fb6e3ebca9fa2c0e3a29 Mon Sep 17 00:00:00 2001
From: Martin Evans
Date: Fri, 8 Dec 2023 01:02:27 +0000
Subject: [PATCH] - Added `SamplingPipeline` to inference params which
 overrides all other options with an entirely custom pipeline.
 - Added a `Sample` method to `LLamaContext` which uses a custom pipeline
 - Modified all executors to use the custom pipeline if it exists
---
 LLama.Web/Common/InferenceOptions.cs   | 10 ++++--
 LLama/Abstractions/IInferenceParams.cs |  6 ++++
 LLama/Common/InferenceParams.cs        |  4 +++
 LLama/LLamaContext.cs                  | 12 +++++++
 LLama/LLamaInstructExecutor.cs         | 26 ++++++++++------
 LLama/LLamaInteractExecutor.cs         | 26 ++++++++++------
 LLama/LLamaStatelessExecutor.cs        | 27 ++++++++++------
 LLama/Sampling/ISamplingPipeline.cs    | 43 +++++++++++++++++++++++---
 8 files changed, 121 insertions(+), 33 deletions(-)

diff --git a/LLama.Web/Common/InferenceOptions.cs b/LLama.Web/Common/InferenceOptions.cs
index 89d94ade..c604dc0d 100644
--- a/LLama.Web/Common/InferenceOptions.cs
+++ b/LLama.Web/Common/InferenceOptions.cs
@@ -1,6 +1,9 @@
-using LLama.Common;
+#nullable enable
+
+using LLama.Common;
 using LLama.Abstractions;
 using LLama.Native;
+using LLama.Sampling;
 
 namespace LLama.Web.Common
 {
@@ -64,6 +67,9 @@ namespace LLama.Web.Common
         /// <summary>
         /// A grammar to constrain possible tokens
         /// </summary>
-        public SafeLLamaGrammarHandle Grammar { get; set; } = null;
+        public SafeLLamaGrammarHandle? Grammar { get; set; }
+
+        /// <inheritdoc />
+        public ISamplingPipeline? SamplingPipeline { get; set; }
     }
 }
diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs
index d87faf0e..e1e89414 100644
--- a/LLama/Abstractions/IInferenceParams.cs
+++ b/LLama/Abstractions/IInferenceParams.cs
@@ -1,6 +1,7 @@
 using System.Collections.Generic;
 using LLama.Common;
 using LLama.Native;
+using LLama.Sampling;
 
 namespace LLama.Abstractions
 {
@@ -108,5 +109,10 @@ namespace LLama.Abstractions
         /// <summary>
         /// Grammar to constrain possible tokens
         /// </summary>
         SafeLLamaGrammarHandle? Grammar { get; set; }
+
+        /// <summary>
+        /// Set a custom sampling pipeline to use. If this is set, all other sampling parameters are ignored!
+        /// </summary>
+        ISamplingPipeline? SamplingPipeline { get; set; }
     }
 }
\ No newline at end of file
diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs
index d7bd19d9..c1f39550 100644
--- a/LLama/Common/InferenceParams.cs
+++ b/LLama/Common/InferenceParams.cs
@@ -2,6 +2,7 @@
 using System;
 using System.Collections.Generic;
 using LLama.Native;
+using LLama.Sampling;
 
 namespace LLama.Common
 {
@@ -76,6 +77,9 @@ namespace LLama.Common
 
         /// <inheritdoc />
         public SafeLLamaGrammarHandle? Grammar { get; set; }
+
+        /// <inheritdoc />
+        public ISamplingPipeline? SamplingPipeline { get; set; }
     }
 
     /// <summary>
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 3a3e51af..2902dc8f 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -10,6 +10,7 @@ using LLama.Common;
 using System.Runtime.InteropServices;
 using LLama.Extensions;
 using LLama.Abstractions;
+using LLama.Sampling;
 using Microsoft.Extensions.Logging;
 
 namespace LLama
@@ -212,6 +213,17 @@ namespace LLama
             }
         }
 
+        /// <summary>
+        /// Sample a single token from this context, using the given sampling pipeline
+        /// </summary>
+        /// <param name="pipeline">The pipeline to use to process the logits and to select a token</param>
+        /// <param name="lastTokens">The tokens recently returned from the model</param>
+        /// <returns>The selected token</returns>
+        public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan<llama_token> lastTokens)
+        {
+            return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens);
+        }
+
         /// <summary>
         /// Perform the sampling. Please don't use it unless you fully know what it does.
         /// </summary>
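Note: to make the new LLamaContext.Sample hook concrete, here is a minimal sketch of a custom pipeline written against the ISamplingPipeline interface (its diff appears further down). The GreedySamplingPipeline name is illustrative and not part of this patch; it simply returns the highest-logit token.

    using System;
    using LLama.Native;
    using LLama.Sampling;

    // Illustrative only: a trivial greedy pipeline implementing ISamplingPipeline
    public sealed class GreedySamplingPipeline
        : ISamplingPipeline
    {
        public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
        {
            // Ignore lastTokens and pick the argmax of the raw logits.
            // A real pipeline would apply penalties, temperature etc. first.
            var best = 0;
            for (var i = 1; i < logits.Length; i++)
            {
                if (logits[i] > logits[best])
                    best = i;
            }
            return best;
        }

        public void Reset()
        {
            // No internal state to reset in this trivial example
        }
    }

With such a pipeline in hand, `context.Sample(new GreedySamplingPipeline(), lastTokens)` selects the next token in a single call.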
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index d81630aa..3ed66890 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -210,16 +210,24 @@ namespace LLama
                     SaveSessionFile(_pathSession);
                 }
 
-                var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
-                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+                llama_token id;
+                if (inferenceParams.SamplingPipeline is not null)
+                {
+                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+                }
+                else
+                {
+                    var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
-                var mu = MirostatMu;
-                var id = Context.Sample(
-                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-                    inferenceParams.MinP
-                );
-                MirostatMu = mu;
+                    var mu = MirostatMu;
+                    id = Context.Sample(
+                        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                        inferenceParams.MinP
+                    );
+                    MirostatMu = mu;
+                }
 
                 _last_n_tokens.Enqueue(id);
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 4d28274b..9cecf437 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -189,16 +189,24 @@ namespace LLama
                     SaveSessionFile(_pathSession);
                 }
 
-                var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
-                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+                llama_token id;
+                if (inferenceParams.SamplingPipeline is not null)
+                {
+                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+                }
+                else
+                {
+                    var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
-                var mu = MirostatMu;
-                var id = Context.Sample(
-                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-                    inferenceParams.MinP
-                );
-                MirostatMu = mu;
+                    var mu = MirostatMu;
+                    id = Context.Sample(
+                        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                        inferenceParams.MinP
+                    );
+                    MirostatMu = mu;
+                }
 
                 _last_n_tokens.Enqueue(id);
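All three executors follow the same pattern: when `IInferenceParams.SamplingPipeline` is non-null it takes over token selection completely, otherwise the existing ApplyPenalty + Context.Sample path runs unchanged. A caller would opt in roughly like this (a sketch assuming the illustrative GreedySamplingPipeline above and an already-constructed executor and prompt):

    var inferenceParams = new InferenceParams
    {
        // When set, the executors skip every built-in sampling parameter
        SamplingPipeline = new GreedySamplingPipeline(),
    };

    await foreach (var text in executor.InferAsync(prompt, inferenceParams))
        Console.Write(text);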
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 9c41af7c..831aceb2 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -7,6 +7,7 @@ using System.Runtime.CompilerServices;
 using System.Threading;
 using System.Threading.Tasks;
 using LLama.Native;
+using LLama.Sampling;
 using Microsoft.Extensions.Logging;
 
 namespace LLama
@@ -85,16 +86,24 @@ namespace LLama
             var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
             for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
             {
-                // Penalize the generated tokens by various penalties
-                var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
-                    inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+                llama_token id;
+                if (inferenceParams.SamplingPipeline is not null)
+                {
+                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+                }
+                else
+                {
+                    // Penalize the generated tokens by various penalties
+                    var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+                        inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
-                // Sample a single token
-                var id = Context.Sample(
-                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
-                    inferenceParams.MinP
-                );
+                    // Sample a single token
+                    id = Context.Sample(
+                        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                        inferenceParams.MinP
+                    );
+                }
 
                 // Decode this token into text
                 decoder.Add(id);
diff --git a/LLama/Sampling/ISamplingPipeline.cs b/LLama/Sampling/ISamplingPipeline.cs
index 489f2c5a..4540e9fc 100644
--- a/LLama/Sampling/ISamplingPipeline.cs
+++ b/LLama/Sampling/ISamplingPipeline.cs
@@ -1,5 +1,7 @@
 using System;
+using System.Buffers;
 using System.Collections.Generic;
+using System.Runtime.InteropServices;
 using LLama.Native;
 using LLama.Sampling.Logits;
 using LLama.Sampling.Selection;
@@ -16,9 +18,9 @@ public interface ISamplingPipeline
     /// <summary>
     /// Sample a single token from the given logits
     /// </summary>
-    /// <param name="ctx"></param>
-    /// <param name="logits"></param>
-    /// <param name="lastTokens"></param>
+    /// <param name="ctx">The context being sampled from</param>
+    /// <param name="logits">The logits produced by the model</param>
+    /// <param name="lastTokens">A span of tokens recently returned by the model</param>
     /// <returns></returns>
     int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);
 
@@ -28,10 +30,43 @@ public interface ISamplingPipeline
     void Reset();
 }
 
+/// <summary>
+/// Extension methods for ISamplingPipeline
+/// </summary>
+public static class ISamplingPipelineExtensions
+{
+    /// <summary>
+    /// Sample a single token from the given logits
+    /// </summary>
+    /// <param name="pipeline"></param>
+    /// <param name="ctx">The context being sampled from</param>
+    /// <param name="logits">The logits produced by the model</param>
+    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
+    /// <returns></returns>
+    public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<int> lastTokens)
+    {
+#if NET5_0_OR_GREATER
+        var span = CollectionsMarshal.AsSpan(lastTokens);
+        return pipeline.Sample(ctx, logits, span);
+#else
+        var copy = ArrayPool<int>.Shared.Rent(lastTokens.Count);
+        try
+        {
+            lastTokens.CopyTo(copy);
+            return pipeline.Sample(ctx, logits, copy.AsSpan(0, lastTokens.Count));
+        }
+        finally
+        {
+            ArrayPool<int>.Shared.Return(copy);
+        }
+#endif
+    }
+}
+
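The List<int> extension above exists because the stateless executor tracks its recent tokens in a List. On .NET 5+, CollectionsMarshal.AsSpan views the list's backing array directly, so no copy is made (the list must not grow while the span is in use); on older targets the tokens are copied into a rented ArrayPool buffer, sliced to the list's count, and the buffer is returned afterwards. Callers just pass the list (the names below are stand-ins for whatever pipeline, context, and token history are in scope):

    // lastTokens is a List<int>; the extension bridges it to ReadOnlySpan<int>
    var id = pipeline.Sample(ctx, logits, lastTokens);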
 /// <summary>
 /// Simple implementation of `ISamplingPipeline`, applies processors in order every time
 /// </summary>
-public sealed class BasicSamplingPipeline
+public sealed class ConfigurableSamplingPipeline
     : ISamplingPipeline
 {
     ///
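For completeness, a hedged end-to-end sketch of how the pieces might be wired together. It assumes LLamaSharp's ModelParams, LLamaWeights, and StatelessExecutor APIs as they existed around this commit, plus the illustrative GreedySamplingPipeline from the earlier note; the model path and prompt are placeholders. The renamed ConfigurableSamplingPipeline can be dropped in the same way once its processors are configured, but its configuration surface is outside this excerpt.

    using LLama;
    using LLama.Common;

    // Load a model and build a stateless executor (paths are placeholders)
    var parameters = new ModelParams("path/to/model.gguf");
    using var weights = LLamaWeights.LoadFromFile(parameters);
    var executor = new StatelessExecutor(weights, parameters);

    var inferenceParams = new InferenceParams
    {
        MaxTokens = 64,
        // The custom pipeline overrides every other sampling option
        SamplingPipeline = new GreedySamplingPipeline(),
    };

    await foreach (var piece in executor.InferAsync("Question: what is a llama?\nAnswer:", inferenceParams))
        Console.Write(piece);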