diff --git a/LLama.Web/Common/InferenceOptions.cs b/LLama.Web/Common/InferenceOptions.cs
index 89d94ade..c604dc0d 100644
--- a/LLama.Web/Common/InferenceOptions.cs
+++ b/LLama.Web/Common/InferenceOptions.cs
@@ -1,6 +1,9 @@
-using LLama.Common;
+#nullable enable
+
+using LLama.Common;
using LLama.Abstractions;
using LLama.Native;
+using LLama.Sampling;
namespace LLama.Web.Common
{
@@ -64,6 +67,9 @@ namespace LLama.Web.Common
/// <summary>
/// A grammar to constrain possible tokens
/// </summary>
- public SafeLLamaGrammarHandle Grammar { get; set; } = null;
+ public SafeLLamaGrammarHandle? Grammar { get; set; }
+
+ /// <inheritdoc />
+ public ISamplingPipeline? SamplingPipeline { get; set; }
}
}
diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs
index d87faf0e..e1e89414 100644
--- a/LLama/Abstractions/IInferenceParams.cs
+++ b/LLama/Abstractions/IInferenceParams.cs
@@ -1,6 +1,7 @@
using System.Collections.Generic;
using LLama.Common;
using LLama.Native;
+using LLama.Sampling;
namespace LLama.Abstractions
{
@@ -108,5 +109,10 @@ namespace LLama.Abstractions
/// Grammar to constrain possible tokens
/// </summary>
SafeLLamaGrammarHandle? Grammar { get; set; }
+
+ /// <summary>
+ /// Set a custom sampling pipeline to use. If this is set, all other sampling parameters are ignored!
+ /// </summary>
+ ISamplingPipeline? SamplingPipeline { get; set; }
}
}
\ No newline at end of file
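For reference, a minimal implementation of the new extension point might look like the following. This sketch is not part of the diff: it assumes `ISamplingPipeline` requires only the `Sample` and `Reset` members shown later in this patch, and `GreedySamplingPipeline` is a hypothetical name.

```csharp
using System;
using LLama.Native;
using LLama.Sampling;

// Hypothetical pipeline: ignores lastTokens and always picks the
// highest-logit token (greedy decoding).
public sealed class GreedySamplingPipeline : ISamplingPipeline
{
    public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        // The index into the logits span is the token id.
        var best = 0;
        for (var i = 1; i < logits.Length; i++)
        {
            if (logits[i] > logits[best])
                best = i;
        }
        return best;
    }

    public void Reset()
    {
        // Greedy decoding keeps no state between calls.
    }
}
```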
diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs
index d7bd19d9..c1f39550 100644
--- a/LLama/Common/InferenceParams.cs
+++ b/LLama/Common/InferenceParams.cs
@@ -2,6 +2,7 @@
using System;
using System.Collections.Generic;
using LLama.Native;
+using LLama.Sampling;
namespace LLama.Common
{
@@ -76,6 +77,9 @@ namespace LLama.Common
/// </summary>
public SafeLLamaGrammarHandle? Grammar { get; set; }
+
+ /// <inheritdoc />
+ public ISamplingPipeline? SamplingPipeline { get; set; }
}
/// <summary>
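Opting in is then a one-property change on `InferenceParams`; the executor diffs below show that every other sampling parameter is bypassed when a pipeline is set. A minimal sketch, reusing the hypothetical `GreedySamplingPipeline` from above:

```csharp
using LLama.Common;

var inferenceParams = new InferenceParams
{
    MaxTokens = 64,
    // When SamplingPipeline is non-null, the executors below skip
    // ApplyPenalty and Context.Sample entirely.
    SamplingPipeline = new GreedySamplingPipeline(), // hypothetical, from the sketch above
};
```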
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 3a3e51af..2902dc8f 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -10,6 +10,7 @@ using LLama.Common;
using System.Runtime.InteropServices;
using LLama.Extensions;
using LLama.Abstractions;
+using LLama.Sampling;
using Microsoft.Extensions.Logging;
namespace LLama
@@ -212,6 +213,17 @@ namespace LLama
}
}
+ /// <summary>
+ /// Sample a single token from this context, using the given sampling pipeline
+ /// </summary>
+ /// <param name="pipeline">The pipeline to use to process the logits and to select a token</param>
+ /// <param name="lastTokens">The tokens recently returned from the model</param>
+ /// <returns>The selected token</returns>
+ public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan<llama_token> lastTokens)
+ {
+ return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens);
+ }
+
/// <summary>
/// Perform the sampling. Please don't use it unless you fully know what it does.
/// </summary>
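The new `LLamaContext.Sample` overload also makes it possible to drive a pipeline by hand, outside the executors. A sketch, assuming `context` is an existing `LLamaContext` whose logits were populated by a previous evaluation:

```csharp
// Sketch: one manual sampling step outside the built-in executors.
var pipeline = new GreedySamplingPipeline(); // hypothetical, from the earlier sketch
var lastTokens = new int[] { /* recently generated token ids */ };
var next = context.Sample(pipeline, lastTokens);
```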
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index d81630aa..3ed66890 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -210,16 +210,24 @@ namespace LLama
SaveSessionFile(_pathSession);
}
- var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
- inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+ llama_token id;
+ if (inferenceParams.SamplingPipeline is not null)
+ {
+ id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+ }
+ else
+ {
+ var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+ inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
- var mu = MirostatMu;
- var id = Context.Sample(
- tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
- inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
- inferenceParams.MinP
- );
- MirostatMu = mu;
+ var mu = MirostatMu;
+ id = Context.Sample(
+ tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+ inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+ inferenceParams.MinP
+ );
+ MirostatMu = mu;
+ }
_last_n_tokens.Enqueue(id);
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 4d28274b..9cecf437 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -189,16 +189,24 @@ namespace LLama
SaveSessionFile(_pathSession);
}
- var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
- inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+ llama_token id;
+ if (inferenceParams.SamplingPipeline is not null)
+ {
+ id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
+ }
+ else
+ {
+ var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+ inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
- var mu = MirostatMu;
- var id = Context.Sample(
- tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
- inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
- inferenceParams.MinP
- );
- MirostatMu = mu;
+ var mu = MirostatMu;
+ id = Context.Sample(
+ tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+ inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+ inferenceParams.MinP
+ );
+ MirostatMu = mu;
+ }
_last_n_tokens.Enqueue(id);
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 9c41af7c..831aceb2 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -7,6 +7,7 @@ using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using LLama.Native;
+using LLama.Sampling;
using Microsoft.Extensions.Logging;
namespace LLama
@@ -85,16 +86,24 @@ namespace LLama
var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
{
- // Penalize the generated tokens by various penalties
- var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
- inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+ llama_token id;
+ if (inferenceParams.SamplingPipeline is not null)
+ {
+ id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+ }
+ else
+ {
+ // Penalize the generated tokens by various penalties
+ var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+ inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
- // Sample a single token
- var id = Context.Sample(
- tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
- inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
- inferenceParams.MinP
- );
+ // Sample a single token
+ id = Context.Sample(
+ tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+ inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+ inferenceParams.MinP
+ );
+ }
// Decode this token into text
decoder.Add(id);
diff --git a/LLama/Sampling/ISamplingPipeline.cs b/LLama/Sampling/ISamplingPipeline.cs
index 489f2c5a..4540e9fc 100644
--- a/LLama/Sampling/ISamplingPipeline.cs
+++ b/LLama/Sampling/ISamplingPipeline.cs
@@ -1,5 +1,7 @@
using System;
+using System.Buffers;
using System.Collections.Generic;
+using System.Runtime.InteropServices;
using LLama.Native;
using LLama.Sampling.Logits;
using LLama.Sampling.Selection;
@@ -16,9 +18,9 @@ public interface ISamplingPipeline
/// <summary>
/// Sample a single token from the given logits
/// </summary>
- /// <param name="ctx"></param>
- /// <param name="logits"></param>
- /// <param name="lastTokens"></param>
+ /// <param name="ctx">The context being sampled from</param>
+ /// <param name="logits">The logits produced by the model</param>
+ /// <param name="lastTokens">A span of tokens recently returned by the model</param>
/// <returns></returns>
int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);
@@ -28,10 +30,43 @@ public interface ISamplingPipeline
void Reset();
}
+/// <summary>
+/// Extension methods for ISamplingPipeline
+/// </summary>
+public static class ISamplingPipelineExtensions
+{
+ /// <summary>
+ /// Sample a single token from the given logits
+ /// </summary>
+ /// <param name="pipeline"></param>
+ /// <param name="ctx">The context being sampled from</param>
+ /// <param name="logits">The logits produced by the model</param>
+ /// <param name="lastTokens">A list of tokens recently returned by the model</param>
+ /// <returns></returns>
+ public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span<float> logits, List<int> lastTokens)
+ {
+#if NET5_0_OR_GREATER
+ var span = CollectionsMarshal.AsSpan(lastTokens);
+ return pipeline.Sample(ctx, logits, span);
+#else
+ var copy = ArrayPool<int>.Shared.Rent(lastTokens.Count);
+ try
+ {
+ lastTokens.CopyTo(copy);
+ return pipeline.Sample(ctx, logits, copy.AsSpan(0, lastTokens.Count));
+ }
+ finally
+ {
+ ArrayPool<int>.Shared.Return(copy);
+ }
+#endif
+ }
+}
+
/// <summary>
/// Simple implementation of `ISamplingPipeline`, applies processors in order every time
/// </summary>
-public sealed class BasicSamplingPipeline
+public sealed class ConfigurableSamplingPipeline
: ISamplingPipeline
{
/// <summary>