- Removed the object wrappers and configurable pipeline; they can be better written directly in code.

- Added `BaseSamplingPipeline`, which provides a base implementation of `ISamplingPipeline`
- Added `DefaultSamplingPipeline`, which mimics the normal llama.cpp sampling
This commit is contained in:
Martin Evans 2023-12-08 16:25:13 +00:00
parent 3afc007499
commit 835958398c
22 changed files with 309 additions and 844 deletions
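
For orientation (this example is not part of the diff): a minimal sketch of how the new `DefaultSamplingPipeline` is intended to be wired up, based on the test change below. `weights` and `modelParams` stand in for an already-loaded model and its parameters, the property values are illustrative only, and the executor call assumes the usual LLamaSharp `InferAsync` API.

using System;
using LLama;
using LLama.Common;
using LLama.Sampling;

// Assumes `weights` (LLamaWeights) and `modelParams` (the model/context parameters) already exist.
var pipeline = new DefaultSamplingPipeline
{
    Temperature = 0.8f, // illustrative values, not recommended defaults
    TopP = 0.95f,
};

var executor = new StatelessExecutor(weights, modelParams);
var inferenceParams = new InferenceParams
{
    MaxTokens = 32,
    AntiPrompts = new[] { "." },
    SamplingPipeline = pipeline,
};

await foreach (var text in executor.InferAsync("Question. what is a cat?\nAnswer: ", inferenceParams))
    Console.Write(text);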

View File

@ -1,5 +1,4 @@
using System.Text;
using LLama.Exceptions;
using LLama.Exceptions;
using LLama.Native;
using LLama.Grammars;

View File

@ -1,9 +1,6 @@
using System.Diagnostics;
using LLama.Common;
using LLama.Sampling;
using LLama.Sampling.Logits;
using LLama.Sampling.Selection;
using LLama.Sampling.Tokens;
using Xunit.Abstractions;
namespace LLama.Unittest
@ -35,40 +32,12 @@ namespace LLama.Unittest
public async Task Stateless()
{
// Create a custom pipeline that mimics the default pipeline
var pipeline = new ConfigurableSamplingPipeline()
{
ProtectedLogits =
{
_weights.NewlineToken,
_weights.BeginningOfSentenceToken,
_weights.EndOfSentenceToken
},
LogitProcessors =
{
new LogitBias
{
Biases =
{
{ _weights.NewlineToken, 1000 }, // This is an insane bias, but because newline is a protected logit it will do nothing!
{ 42, 0f },
}
}
},
TokenDataProcessors =
{
new TailFreeSampling { Z = 1 },
new LocallyTypicalSampling { P = 1 },
new TopPSampling { P = 0.95f },
new MinPSampling { P = 0.05f },
new TemperatureSampling { Temperature = 0.8f },
},
Selector = new StandardSelection(),
};
var pipeline = new DefaultSamplingPipeline();
var executor = new StatelessExecutor(_weights, _params);
const string question = "Question. what is a cat?\nAnswer: ";
var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline};
var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline };
var timer = new Stopwatch();
timer.Start();

View File

@ -46,14 +46,41 @@ namespace LLama.Native
return new LLamaTokenDataArray(candidates);
}
/// <summary>
/// Overwrite the logit values for all given tokens
/// </summary>
/// <param name="values">tuples of token and logit value to overwrite</param>
public void OverwriteLogits(ReadOnlySpan<(llama_token token, float logit)> values)
{
if (values.Length == 0)
return;
var dataSpan = data.Span;
foreach (var (token, value) in values)
{
for (var i = 0; i < data.Length; i++)
{
if (dataSpan[i].id == token)
{
dataSpan[i].logit = value;
break;
}
}
}
sorted = false;
}
#region sampling
/// <summary>
/// Apply grammar rules to candidate tokens
/// </summary>
/// <param name="ctx"></param>
/// <param name="grammar"></param>
public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar)
public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle? grammar)
{
if (grammar == null)
return;
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_grammar(ctx, ref st, grammar);
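
A hedged usage sketch for the new OverwriteLogits helper (not part of the diff; the token ids and saved logit values are hypothetical placeholders):

// Build candidates from raw logits, then restore two previously saved logit values.
var candidates = LLamaTokenDataArray.Create(logits);

var saved = new (int token, float logit)[]
{
    (newlineToken, savedNewlineLogit), // hypothetical (token, logit) pairs captured earlier
    (eosToken, savedEosLogit),
};
candidates.OverwriteLogits(saved);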

View File

@ -0,0 +1,128 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using LLama.Native;
namespace LLama.Sampling;
/// <summary>
/// Base class for implementing custom sampling pipelines. This provides a helpful framework for implementing `ISamplingPipeline`.
/// </summary>
public abstract class BaseSamplingPipeline
: ISamplingPipeline
{
private int _savedLogitsCount;
private (int index, float logit)[]? _savedLogits;
/// <inheritdoc/>
public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
{
var protectedLogits = GetProtectedTokens(ctx);
_savedLogitsCount = protectedLogits.Count;
_savedLogits = ArrayPool<(int, float)>.Shared.Rent(_savedLogitsCount);
try
{
// Save the values of protected logits
for (var i = 0; i < protectedLogits.Count; i++)
{
var index = protectedLogits[i];
var value = logits[index];
_savedLogits[i] = (index, value);
}
// Process raw logits
ProcessLogits(ctx, logits, lastTokens);
// Automatically restore saved logit values after processing
RestoreProtectedTokens(logits);
// Convert logits into token candidates
var candidates = LLamaTokenDataArray.Create(logits);
// Process token data array
ProcessTokenDataArray(ctx, candidates, lastTokens);
// Choose the final value
return ChooseToken(ctx, candidates);
}
finally
{
ArrayPool<(int, float)>.Shared.Return(_savedLogits);
_savedLogits = null;
_savedLogitsCount = 0;
}
}
#region protected tokens
/// <summary>
/// Get all of the "protected" tokens that cannot be changed by ProcessLogits
/// </summary>
/// <returns></returns>
protected abstract IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx);
/// <summary>
/// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
/// </summary>
/// <param name="logits"></param>
protected void RestoreProtectedTokens(Span<float> logits)
{
if (_savedLogits == null)
return;
// The array may be bigger than necessary; take a span of just the valid part
var saved = _savedLogits.AsSpan(0, _savedLogitsCount);
// Restore the values of protected logits
for (var i = 0; i < saved.Length; i++)
logits[saved[i].index] = saved[i].logit;
}
/// <summary>
/// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
/// </summary>
/// <param name="candidates"></param>
protected void RestoreProtectedTokens(LLamaTokenDataArray candidates)
{
if (_savedLogits == null || _savedLogits.Length == 0)
return;
candidates.OverwriteLogits(_savedLogits.AsSpan(0, _savedLogitsCount));
}
#endregion
/// <summary>
/// Process the raw logit values
/// </summary>
/// <param name="ctx">The context being sampled from</param>
/// <param name="logits">The logits produced by the model</param>
/// <param name="lastTokens">A list of tokens recently returned by the model</param>
protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);
/// <summary>
/// Process the LLamaTokenDataArray and select a single token
/// </summary>
/// <param name="ctx">The context being sampled from</param>
/// <param name="candidates">The LLamaTokenDataArray data produced by the model</param>
/// <param name="lastTokens">A list of tokens recently returned by the model</param>
/// <returns></returns>
protected abstract int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens);
/// <summary>
/// Choose the final token from the candidates
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates"></param>
/// <returns></returns>
protected abstract int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates);
/// <inheritdoc/>
public virtual void Reset()
{
}
/// <inheritdoc/>
public virtual void Dispose()
{
GC.SuppressFinalize(this);
}
}
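
To illustrate the intended extension point (not part of the diff): a minimal sketch of a custom pipeline built on BaseSamplingPipeline. It protects no tokens, applies only temperature, and then picks greedily; the class name and behaviour are purely illustrative.

using System;
using System.Collections.Generic;
using LLama.Native;
using LLama.Sampling;

public sealed class GreedyTemperaturePipeline
    : BaseSamplingPipeline
{
    public float Temperature { get; set; } = 0.5f;

    // No tokens are protected from ProcessLogits in this example.
    protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
        => Array.Empty<int>();

    // Leave the raw logits untouched.
    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
    }

    // Apply temperature, then defer the choice to ChooseToken (greedy, so this is deterministic).
    protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
    {
        candidates.Temperature(ctx, Temperature);
        return ChooseToken(ctx, candidates);
    }

    // Always take the most likely token.
    protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
    {
        return candidates.SampleTokenGreedy(ctx);
    }
}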

View File

@ -0,0 +1,149 @@
using System;
using System.Collections.Generic;
using LLama.Extensions;
using LLama.Native;
namespace LLama.Sampling;
/// <summary>
/// An implementation of ISamplingPipeline which mimics the default llama.cpp sampling
/// </summary>
public sealed class DefaultSamplingPipeline
: BaseSamplingPipeline
{
/// <summary>
/// Bias values to add to certain logits
/// </summary>
public Dictionary<int, float> LogitBias { get; } = new();
/// <summary>
/// Grammar to constrain valid tokens
/// </summary>
public SafeLLamaGrammarHandle? Grammar { get; set; }
/// <summary>
/// Repetition penalty, as described in https://arxiv.org/abs/1909.05858
/// </summary>
public float RepeatPenalty { get; set; } = 1.1f;
/// <summary>
/// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
/// so far, decreasing the model's likelihood to repeat the same line verbatim.
/// </summary>
public float AlphaFrequency
{
get => _alphaFreq;
set
{
if (value < -2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2");
if (value > 2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2");
_alphaFreq = value;
}
}
private float _alphaFreq = 0.1f;
/// <summary>
/// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
/// text so far, increasing the model's likelihood to talk about new topics.
/// </summary>
public float AlphaPresence
{
get => _alphaPresence;
set
{
if (value < -2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be greater than -2");
if (value > 2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be less than 2");
_alphaPresence = value;
}
}
private float _alphaPresence = 0.1f;
/// <summary>
/// Temperature to apply (higher temperature is more "creative")
/// </summary>
public float Temperature { get; set; } = 0.75f;
/// <summary>
/// Number of tokens to keep in TopK sampling
/// </summary>
public int TopK { get; set; }
/// <summary>
/// Z value for tail free sampling
/// </summary>
public float TailFreeZ { get; set; }
/// <summary>
/// P value for locally typical sampling
/// </summary>
public float TypicalP { get; set; }
/// <summary>
/// P value for TopP sampling
/// </summary>
public float TopP { get; set; } = 1f;
/// <summary>
/// P value for MinP sampling
/// </summary>
public float MinP { get; set; }
/// <summary>
/// Whether the newline value should be protected from being modified by logit bias and repeat penalty
/// </summary>
public bool PenalizeNewline { get; set; } = false;
private readonly int[] _newlineToken = new int[1];
/// <inheritdoc />
protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
{
if (PenalizeNewline)
return Array.Empty<int>();
_newlineToken[0] = NativeApi.llama_token_nl(ctx.ModelHandle);
return _newlineToken;
}
/// <inheritdoc />
protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
{
foreach (var (key, value) in LogitBias)
logits[key] += value;
}
/// <inheritdoc />
protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
{
// Apply penalties to candidates
candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);
// Restore protected tokens, so they are not affected by repetition penalties
RestoreProtectedTokens(candidates);
// Apply the normal llama.cpp pipeline
candidates.ApplyGrammar(ctx, Grammar);
candidates.TopK(ctx, TopK);
candidates.TailFree(ctx, TailFreeZ);
candidates.LocallyTypical(ctx, TypicalP);
candidates.TopP(ctx, TopP);
candidates.MinP(ctx, MinP);
candidates.Temperature(ctx, Temperature);
var id = candidates.SampleToken(ctx);
Grammar?.AcceptToken(ctx, id);
return id;
}
/// <inheritdoc />
protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
return candidates.SampleToken(ctx);
}
}
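
For reference (not part of the diff): a configuration sketch for DefaultSamplingPipeline; the values shown are illustrative rather than recommended defaults.

using LLama.Sampling;

var pipeline = new DefaultSamplingPipeline
{
    RepeatPenalty = 1.1f,
    AlphaFrequency = 0.1f,  // frequency penalty, must stay within [-2, 2]
    AlphaPresence = 0.1f,   // presence penalty, must stay within [-2, 2]
    Temperature = 0.7f,
    TopK = 40,
    TopP = 0.95f,
    MinP = 0.05f,
    TailFreeZ = 1f,
    TypicalP = 1f,
    PenalizeNewline = false, // newline logit stays protected from bias/penalties
};

// Bias individual token ids directly.
pipeline.LogitBias[42] = -10f;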

View File

@ -3,14 +3,11 @@ using System.Buffers;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using LLama.Native;
using LLama.Sampling.Logits;
using LLama.Sampling.Selection;
using LLama.Sampling.Tokens;
namespace LLama.Sampling;
/// <summary>
/// Convert a span of logits into a single sampled token
/// Convert a span of logits into a single sampled token. This interface can be implemented to completely customise the sampling process.
/// </summary>
public interface ISamplingPipeline
: IDisposable
@ -61,101 +58,4 @@ public static class ISamplingPipelineExtensions
}
#endif
}
}
/// <summary>
/// Simple implementation of `ISamplingPipeline`, applies processors in order every time
/// </summary>
public sealed class ConfigurableSamplingPipeline
: ISamplingPipeline
{
/// <summary>
/// Logit processors to apply in this pipeline
/// </summary>
public IList<ILogitProcessor> LogitProcessors { get; } = new List<ILogitProcessor>();
/// <summary>
/// Logits values which will not be changed by the logit processors
/// </summary>
public IList<int> ProtectedLogits { get; } = new List<int>();
/// <summary>
/// Token data processors to apply in this pipeline
/// </summary>
public IList<ITokenDataProcessor> TokenDataProcessors { get; } = new List<ITokenDataProcessor>();
/// <summary>
/// The selector to choose the final token
/// </summary>
public ITokenSelector Selector { get; set; } = new StandardSelection();
/// <inheritdoc />
public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
{
var savedLogitsCount = ProtectedLogits.Count;
var savedLogitValues = ArrayPool<float>.Shared.Rent(savedLogitsCount);
var savedLogitIndices = ArrayPool<int>.Shared.Rent(savedLogitsCount);
try
{
// Save the values of protected logits
for (var i = 0; i < ProtectedLogits.Count; i++)
{
savedLogitValues[i] = logits[ProtectedLogits[i]];
savedLogitIndices[i] = ProtectedLogits[i];
}
// Modify raw logits
foreach (var logitProcessor in LogitProcessors)
logitProcessor.ProcessLogits(ctx, logits, lastTokens);
// Restore the values of protected logits
for (var i = 0; i < savedLogitsCount; i++)
logits[savedLogitIndices[i]] = savedLogitValues[i];
}
finally
{
ArrayPool<float>.Shared.Return(savedLogitValues);
ArrayPool<int>.Shared.Return(savedLogitIndices);
}
// Convert logits into token candidates
var candidates_p = LLamaTokenDataArray.Create(logits);
// Process token candidates
foreach (var tokenDataProcessor in TokenDataProcessors)
tokenDataProcessor.ProcessTokens(ctx, candidates_p, lastTokens);
// Select a token
var token = Selector.Select(ctx, candidates_p, lastTokens);
// Tell processors what was selected
foreach (var logitProcessor in LogitProcessors)
logitProcessor.AcceptToken(ctx, token);
foreach (var tokenDataProcessor in TokenDataProcessors)
tokenDataProcessor.AcceptToken(ctx, token);
return token;
}
/// <inheritdoc />
public void Reset()
{
foreach (var logitProcessor in LogitProcessors)
logitProcessor.Reset();
foreach (var tokenDataProcessor in TokenDataProcessors)
tokenDataProcessor.Reset();
Selector.Reset();
}
/// <inheritdoc />
public void Dispose()
{
foreach (var logitProcessor in LogitProcessors)
logitProcessor.Dispose();
foreach (var tokenDataProcessor in TokenDataProcessors)
tokenDataProcessor.Dispose();
Selector.Dispose();
}
}

View File

@ -1,34 +0,0 @@
using System;
using LLama.Native;
namespace LLama.Sampling.Logits;
using llama_token = Int32;
/// <summary>
/// Processes raw logits before sampling, applying penalties to certain tokens
/// </summary>
public interface ILogitProcessor
: IDisposable
{
/// <summary>
/// Process raw logits, indexed by llama_token
/// </summary>
/// <param name="ctx">The context this is operating in</param>
/// <param name="logits">The token data array to process</param>
/// <param name="lastTokens">The most recent tokens output</param>
/// <returns>LLamaTokenDataArray, created from logits</returns>
void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<llama_token> lastTokens);
/// <summary>
/// Inform this process when a token is accepted by the model
/// </summary>
/// <param name="ctx"></param>
/// <param name="token"></param>
void AcceptToken(SafeLLamaContextHandle ctx, int token);
/// <summary>
/// Reset all internal sampling state
/// </summary>
void Reset();
}

View File

@ -1,39 +0,0 @@
using System;
using System.Collections.Generic;
using LLama.Native;
namespace LLama.Sampling.Logits;
/// <summary>
/// Add a bias directly to logit values
/// </summary>
public sealed class LogitBias
: ILogitProcessor
{
/// <summary>
/// Biases to apply, token -> bias
/// </summary>
public IDictionary<int, float> Biases { get; } = new Dictionary<int, float>();
/// <inheritdoc />
public void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
{
foreach (var kvp in Biases)
logits[kvp.Key] += kvp.Value;
}
/// <inheritdoc />
public void AcceptToken(SafeLLamaContextHandle ctx, int token)
{
}
/// <inheritdoc />
public void Reset()
{
}
/// <inheritdoc />
public void Dispose()
{
}
}

View File

@ -1,27 +0,0 @@
using System;
using LLama.Native;
namespace LLama.Sampling.Selection;
/// <summary>
/// Select the most likely token
/// </summary>
public sealed class GreedySelection
: ITokenSelector
{
/// <inheritdoc />
public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
{
return candidates.SampleTokenGreedy(ctx);
}
/// <inheritdoc />
public void Reset()
{
}
/// <inheritdoc />
public void Dispose()
{
}
}

View File

@ -1,25 +0,0 @@
using System;
using LLama.Native;
namespace LLama.Sampling.Selection;
/// <summary>
/// Select a single token from a set of possibilities
/// </summary>
public interface ITokenSelector
: IDisposable
{
/// <summary>
/// Select a single token
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates"></param>
/// <param name="lastTokens"></param>
/// <returns></returns>
int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens);
/// <summary>
/// Reset the state
/// </summary>
void Reset();
}

View File

@ -1,65 +0,0 @@
using System;
using LLama.Native;
namespace LLama.Sampling.Selection;
/// <summary>
/// Select a token using Mirostat sampling.
/// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966.
/// </summary>
public sealed class Mirostat2Selection
: ITokenSelector
{
private float _mu;
/// <summary>
/// Current value of Mu, updated based on the difference between target surprise and actual surprise
/// </summary>
public float Mu
{
get => _mu;
set => _mu = value;
}
/// <summary>
/// The target cross-entropy (or surprise) value you want to achieve for the generated text.
/// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text
/// </summary>
public float Tau { get; set; }
/// <summary>
/// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word.
/// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// </summary>
public float Eta { get; set; }
/// <summary>
/// Create a new Mirostat 2.0 sampler
/// </summary>
/// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text.
/// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text</param>
/// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word.
/// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
public Mirostat2Selection(float tau, float eta)
{
Tau = tau;
Eta = eta;
}
/// <inheritdoc />
public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
{
return candidates.SampleTokenMirostat2(ctx, Tau, Eta, ref _mu);
}
/// <inheritdoc />
public void Reset()
{
_mu = 2 * Tau;
}
/// <inheritdoc />
public void Dispose()
{
}
}

View File

@ -1,76 +0,0 @@
using System;
using LLama.Native;
namespace LLama.Sampling.Selection;
/// <summary>
/// Select a token using Mirostat sampling.
/// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966.
/// </summary>
public sealed class MirostatSelection
: ITokenSelector
{
private float _mu;
/// <summary>
/// Current value of Mu, updated based on the difference between target surprise and actual surprise
/// </summary>
public float Mu
{
get => _mu;
set => _mu = value;
}
/// <summary>
/// The target cross-entropy (or surprise) value you want to achieve for the generated text.
/// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text
/// </summary>
public float Tau { get; set; }
/// <summary>
/// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word.
/// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// </summary>
public float Eta { get; set; }
/// <summary>
/// The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn
/// helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects
/// the performance of the algorithm.
/// </summary>
public int M { get; set; }
/// <summary>
/// Create a new Mirostat 2.0 sampler
/// </summary>
/// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text.
/// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text</param>
/// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word.
/// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
/// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn
/// helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects
/// the performance of the algorithm.</param>
public MirostatSelection(float tau, float eta, int m = 100)
{
Tau = tau;
Eta = eta;
M = m;
}
/// <inheritdoc />
public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
{
return candidates.SampleTokenMirostat(ctx, Tau, Eta, M, ref _mu);
}
/// <inheritdoc />
public void Reset()
{
_mu = 2 * Tau;
}
/// <inheritdoc />
public void Dispose()
{
}
}

View File

@ -1,27 +0,0 @@
using System;
using LLama.Native;
namespace LLama.Sampling.Selection;
/// <summary>
/// Select from all possible tokens according to their probability
/// </summary>
public sealed class StandardSelection
: ITokenSelector
{
/// <inheritdoc />
public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
{
return candidates.SampleToken(ctx);
}
/// <inheritdoc />
public void Reset()
{
}
/// <inheritdoc />
public void Dispose()
{
}
}

View File

@ -1,59 +0,0 @@
using System;
using LLama.Grammars;
using LLama.Native;
namespace LLama.Sampling.Tokens;
/// <summary>
/// Apply a grammar to prevent sampling tokens which do not match the grammar
/// </summary>
public sealed class GrammarSampling
: ITokenDataProcessor
{
private SafeLLamaGrammarHandle? _handle;
/// <summary>
/// Grammar to use for sampling
/// </summary>
public Grammar? Grammar { get; set; }
/// <summary>
/// Create a new
/// </summary>
/// <param name="grammar"></param>
public GrammarSampling(Grammar grammar)
{
Grammar = grammar;
}
/// <inheritdoc />
public void Reset()
{
_handle?.Dispose();
_handle = null;
}
/// <inheritdoc />
public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
{
// Create a new grammar instance if necessary
_handle ??= Grammar?.CreateInstance();
// Apply it
if (_handle != null)
tokens.ApplyGrammar(ctx, _handle);
}
/// <inheritdoc />
public void AcceptToken(SafeLLamaContextHandle ctx, int token)
{
_handle?.AcceptToken(ctx, token);
}
/// <inheritdoc />
public void Dispose()
{
_handle?.Dispose();
_handle = null;
}
}

View File

@ -1,34 +0,0 @@
using System;
using LLama.Native;
namespace LLama.Sampling.Tokens;
using llama_token = Int32;
/// <summary>
/// Processes token logits before sampling, applying penalties to certain tokens
/// </summary>
public interface ITokenDataProcessor
: IDisposable
{
/// <summary>
/// Process token logits in a LLamaTokenDataArray
/// </summary>
/// <param name="ctx">The context this is operating in</param>
/// <param name="tokens">The token data array to process</param>
/// <param name="lastTokens">The most recent tokens output</param>
/// <returns>LLamaTokenDataArray, created from logits</returns>
void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<llama_token> lastTokens);
/// <summary>
/// Inform this process when a token is accepted by the model
/// </summary>
/// <param name="ctx"></param>
/// <param name="token"></param>
void AcceptToken(SafeLLamaContextHandle ctx, int token);
/// <summary>
/// Reset all internal sampling state
/// </summary>
void Reset();
}

View File

@ -1,42 +0,0 @@
using System;
using LLama.Native;
namespace LLama.Sampling.Tokens;
/// <summary>
/// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
/// </summary>
public sealed class LocallyTypicalSampling
: ITokenDataProcessor
{
/// <summary>
/// P value for locally typical sampling
/// </summary>
public float P { get; set; }
/// <summary>
/// Minimum number of tokens to keep
/// </summary>
public ulong MinKeep { get; set; } = 1;
/// <inheritdoc />
public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
{
tokens.LocallyTypical(ctx, P, MinKeep);
}
/// <inheritdoc />
public void AcceptToken(SafeLLamaContextHandle ctx, int token)
{
}
/// <inheritdoc />
public void Reset()
{
}
/// <inheritdoc />
public void Dispose()
{
}
}

View File

@ -1,42 +0,0 @@
using System;
using LLama.Native;
namespace LLama.Sampling.Tokens;
/// <summary>
/// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
/// </summary>
public sealed class MinPSampling
: ITokenDataProcessor
{
/// <summary>
/// All tokens with probability greater than this will be kept
/// </summary>
public float P { get; set; }
/// <summary>
/// Minimum number of tokens to keep
/// </summary>
public ulong MinKeep { get; set; } = 1;
/// <inheritdoc />
public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
{
tokens.MinP(ctx, P, MinKeep);
}
/// <inheritdoc />
public void AcceptToken(SafeLLamaContextHandle ctx, int token)
{
}
/// <inheritdoc />
public void Reset()
{
}
/// <inheritdoc />
public void Dispose()
{
}
}

View File

@ -1,77 +0,0 @@
using System;
using LLama.Native;
namespace LLama.Sampling.Tokens;
/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
/// </summary>
public sealed class RepetitionPenalty
: ITokenDataProcessor
{
private float _alphaFreq;
private float _alphaPresence;
/// <summary>
/// Repetition penalty, as described in https://arxiv.org/abs/1909.05858
/// </summary>
public float RepeatPenalty { get; set; } = 1.1f;
/// <summary>
/// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
/// so far, decreasing the model's likelihood to repeat the same line verbatim.
/// </summary>
public float AlphaFrequency
{
get => _alphaFreq;
set
{
if (value < -2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2");
if (value > 2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2");
_alphaFreq = value;
}
}
/// <summary>
/// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
/// text so far, increasing the model's likelihood to talk about new topics.
/// </summary>
public float AlphaPresence
{
get => _alphaPresence;
set
{
if (value < -2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2");
if (value > 2)
throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2");
_alphaPresence = value;
}
}
/// <inheritdoc />
public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
{
tokens.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);
}
/// <inheritdoc />
public void AcceptToken(SafeLLamaContextHandle ctx, int token)
{
}
/// <inheritdoc />
public void Reset()
{
}
/// <inheritdoc />
public void Dispose()
{
}
}

View File

@ -1,42 +0,0 @@
using System;
using LLama.Native;
namespace LLama.Sampling.Tokens;
/// <summary>
/// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
/// </summary>
public sealed class TailFreeSampling
: ITokenDataProcessor
{
/// <summary>
/// Z value for tail free sampling
/// </summary>
public float Z { get; set; }
/// <summary>
/// Minimum number of tokens to keep
/// </summary>
public ulong MinKeep { get; set; } = 1;
/// <inheritdoc />
public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
{
tokens.TailFree(ctx, Z, MinKeep);
}
/// <inheritdoc />
public void AcceptToken(SafeLLamaContextHandle ctx, int token)
{
}
/// <inheritdoc />
public void Reset()
{
}
/// <inheritdoc />
public void Dispose()
{
}
}

View File

@ -1,38 +0,0 @@
using System;
using LLama.Native;
namespace LLama.Sampling.Tokens;
/// <summary>
/// Sample with temperature.
/// As temperature increases, the prediction becomes more diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual
/// </summary>
public sealed class TemperatureSampling
: ITokenDataProcessor
{
/// <summary>
/// Temperature value to apply
/// </summary>
public float Temperature { get; set; } = 0.5f;
/// <inheritdoc />
public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
{
tokens.Temperature(ctx, Temperature);
}
/// <inheritdoc />
public void AcceptToken(SafeLLamaContextHandle ctx, int token)
{
}
/// <inheritdoc />
public void Reset()
{
}
/// <inheritdoc />
public void Dispose()
{
}
}

View File

@ -1,38 +0,0 @@
using System;
using LLama.Native;
namespace LLama.Sampling.Tokens;
/// <summary>
/// Sample with TopK, removing all by the K most likely tokens.
/// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
/// </summary>
public sealed class TopKSampling
: ITokenDataProcessor
{
/// <summary>
/// Number of tokens to keep
/// </summary>
public int Count { get; set; }
/// <inheritdoc />
public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
{
tokens.TopK(ctx, Count);
}
/// <inheritdoc />
public void AcceptToken(SafeLLamaContextHandle ctx, int token)
{
}
/// <inheritdoc />
public void Reset()
{
}
/// <inheritdoc />
public void Dispose()
{
}
}

View File

@ -1,42 +0,0 @@
using System;
using LLama.Native;
namespace LLama.Sampling.Tokens;
/// <summary>
/// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
/// </summary>
public sealed class TopPSampling
: ITokenDataProcessor
{
/// <summary>
/// P valies for TopP
/// </summary>
public float P { get; set; }
/// <summary>
/// Minimum number of tokens to keep
/// </summary>
public ulong MinKeep { get; set; } = 1;
/// <inheritdoc />
public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
{
tokens.TopP(ctx, P, MinKeep);
}
/// <inheritdoc />
public void AcceptToken(SafeLLamaContextHandle ctx, int token)
{
}
/// <inheritdoc />
public void Reset()
{
}
/// <inheritdoc />
public void Dispose()
{
}
}