- Removed the object wrappers and the configurable pipeline; they can be better written in code.
- Added `BaseSamplingPipeline`, which provides a base implementation of `ISamplingPipeline`
- Added `DefaultSamplingPipeline`, which mimics normal llama.cpp sampling
This commit is contained in:
parent 3afc007499
commit 835958398c
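The practical effect shows up in the test diff below: a roughly 30-line `ConfigurableSamplingPipeline` object initializer collapses to a single line, because behaviour that previously required configuring processor objects can now simply be written as code:

    var pipeline = new DefaultSamplingPipeline();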
@@ -1,5 +1,4 @@
using System.Text;
using LLama.Exceptions;
using LLama.Exceptions;
using LLama.Native;
using LLama.Grammars;
@@ -1,9 +1,6 @@
 using System.Diagnostics;
 using LLama.Common;
 using LLama.Sampling;
-using LLama.Sampling.Logits;
-using LLama.Sampling.Selection;
-using LLama.Sampling.Tokens;
 using Xunit.Abstractions;

 namespace LLama.Unittest
@@ -35,40 +32,12 @@ namespace LLama.Unittest
         public async Task Stateless()
         {
-            // Create a custom pipeline that mimics the default pipeline
-            var pipeline = new ConfigurableSamplingPipeline()
-            {
-                ProtectedLogits =
-                {
-                    _weights.NewlineToken,
-                    _weights.BeginningOfSentenceToken,
-                    _weights.EndOfSentenceToken
-                },
-                LogitProcessors =
-                {
-                    new LogitBias
-                    {
-                        Biases =
-                        {
-                            { _weights.NewlineToken, 1000 }, // This is an insane bias, but because newline is a protected logit it will do nothing!
-                            { 42, 0f },
-                        }
-                    }
-                },
-                TokenDataProcessors =
-                {
-                    new TailFreeSampling { Z = 1 },
-                    new LocallyTypicalSampling { P = 1 },
-                    new TopPSampling { P = 0.95f },
-                    new MinPSampling { P = 0.05f },
-                    new TemperatureSampling { Temperature = 0.8f },
-                },
-                Selector = new StandardSelection(),
-            };
+            var pipeline = new DefaultSamplingPipeline();

             var executor = new StatelessExecutor(_weights, _params);

             const string question = "Question. what is a cat?\nAnswer: ";
-            var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline};
+            var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline };

             var timer = new Stopwatch();
             timer.Start();
@@ -46,14 +46,41 @@ namespace LLama.Native
             return new LLamaTokenDataArray(candidates);
         }

+        /// <summary>
+        /// Overwrite the logit values for all given tokens
+        /// </summary>
+        /// <param name="values">tuples of token and logit value to overwrite</param>
+        public void OverwriteLogits(ReadOnlySpan<(llama_token token, float logit)> values)
+        {
+            if (values.Length == 0)
+                return;
+
+            var dataSpan = data.Span;
+            foreach (var (token, value) in values)
+            {
+                for (var i = 0; i < data.Length; i++)
+                {
+                    if (dataSpan[i].id == token)
+                    {
+                        dataSpan[i].logit = value;
+                        break;
+                    }
+                }
+            }
+            sorted = false;
+        }
+
         #region sampling
         /// <summary>
         /// Apply grammar rules to candidate tokens
         /// </summary>
         /// <param name="ctx"></param>
         /// <param name="grammar"></param>
-        public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar)
+        public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle? grammar)
         {
+            if (grammar == null)
+                return;
+
             using (LLamaTokenDataArrayNative.Create(this, out var st))
             {
                 NativeApi.llama_sample_grammar(ctx, ref st, grammar);
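For illustration, a short usage sketch of the new `OverwriteLogits` method (the token id and logit value here are made up for the example):

// Assuming `candidates` is an existing LLamaTokenDataArray:
var overwrite = new (int token, float logit)[] { (13, 0f) }; // hypothetical token id 13
candidates.OverwriteLogits(overwrite); // overwrites that token's logit and marks the array as unsorted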
@@ -0,0 +1,128 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using LLama.Native;

namespace LLama.Sampling;

/// <summary>
/// Base class for implementing custom sampling pipelines. This provides a helpful framework for implementing `ISamplingPipeline`.
/// </summary>
public abstract class BaseSamplingPipeline
    : ISamplingPipeline
{
    private int _savedLogitsCount;
    private (int index, float logit)[]? _savedLogits;

    /// <inheritdoc/>
    public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        var protectedLogits = GetProtectedTokens(ctx);
        _savedLogitsCount = protectedLogits.Count;
        _savedLogits = ArrayPool<(int, float)>.Shared.Rent(_savedLogitsCount);
        try
        {
            // Save the values of protected logits
            for (var i = 0; i < protectedLogits.Count; i++)
            {
                var index = protectedLogits[i];
                var value = logits[index];
                _savedLogits[i] = (index, value);
            }

            // Process raw logits
            ProcessLogits(ctx, logits, lastTokens);

            // Automatically restore saved logit values after processing
            RestoreProtectedTokens(logits);

            // Convert logits into token candidates
            var candidates = LLamaTokenDataArray.Create(logits);

            // Process token data array
            ProcessTokenDataArray(ctx, candidates, lastTokens);

            // Choose the final value
            return ChooseToken(ctx, candidates);
        }
        finally
        {
            ArrayPool<(int, float)>.Shared.Return(_savedLogits);
            _savedLogits = null;
            _savedLogitsCount = 0;
        }
    }

    #region protected tokens
    /// <summary>
    /// Get all of the "protected" tokens that cannot be changed by ProcessLogits
    /// </summary>
    /// <returns></returns>
    protected abstract IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx);

    /// <summary>
    /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
    /// </summary>
    /// <param name="logits"></param>
    protected void RestoreProtectedTokens(Span<float> logits)
    {
        if (_savedLogits == null)
            return;

        // The rented array may be bigger than necessary, get a span over just the valid part
        var saved = _savedLogits.AsSpan(0, _savedLogitsCount);

        // Restore the values of protected logits
        for (var i = 0; i < saved.Length; i++)
            logits[saved[i].index] = saved[i].logit;
    }

    /// <summary>
    /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits
    /// </summary>
    /// <param name="candidates"></param>
    protected void RestoreProtectedTokens(LLamaTokenDataArray candidates)
    {
        if (_savedLogits == null || _savedLogits.Length == 0)
            return;

        candidates.OverwriteLogits(_savedLogits.AsSpan(0, _savedLogitsCount));
    }
    #endregion

    /// <summary>
    /// Process the raw logit values
    /// </summary>
    /// <param name="ctx">The context being sampled from</param>
    /// <param name="logits">The logits produced by the model</param>
    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
    protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);

    /// <summary>
    /// Process the LLamaTokenDataArray and select a single token
    /// </summary>
    /// <param name="ctx">The context being sampled from</param>
    /// <param name="candidates">The LLamaTokenDataArray data produced by the model</param>
    /// <param name="lastTokens">A list of tokens recently returned by the model</param>
    /// <returns></returns>
    protected abstract int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens);

    /// <summary>
    /// Choose the final token from the candidates
    /// </summary>
    /// <param name="ctx"></param>
    /// <param name="candidates"></param>
    /// <returns></returns>
    protected abstract int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates);

    /// <inheritdoc/>
    public virtual void Reset()
    {
    }

    /// <inheritdoc/>
    public virtual void Dispose()
    {
        GC.SuppressFinalize(this);
    }
}
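To show how the abstract members fit together, here is a hypothetical minimal subclass (the class name and its behaviour are illustrative only, not part of this commit; it leans on the same `LLamaTokenDataArray` extension methods used by `DefaultSamplingPipeline` below):

using System;
using System.Collections.Generic;
using LLama.Native;
using LLama.Sampling;

/// <summary>
/// Hypothetical example: applies only temperature, and protects the newline
/// token in the same way DefaultSamplingPipeline (below) does.
/// </summary>
public sealed class SimpleTemperaturePipeline
    : BaseSamplingPipeline
{
    private readonly int[] _protectedTokens = new int[1];

    /// <summary>
    /// Temperature to apply before sampling
    /// </summary>
    public float Temperature { get; set; } = 0.8f;

    protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
    {
        // The base class saves these logits before ProcessLogits and restores them afterwards
        _protectedTokens[0] = NativeApi.llama_token_nl(ctx.ModelHandle);
        return _protectedTokens;
    }

    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        // No raw logit processing in this sketch
    }

    protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
    {
        // Shape the candidate distribution; note the base class ignores this
        // return value and asks ChooseToken for the final selection
        candidates.Temperature(ctx, Temperature);
        return candidates.SampleToken(ctx);
    }

    protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
    {
        return candidates.SampleToken(ctx);
    }
}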
@@ -0,0 +1,149 @@
using System;
using System.Collections.Generic;
using LLama.Extensions;
using LLama.Native;

namespace LLama.Sampling;

/// <summary>
/// An implementation of ISamplePipeline which mimics the default llama.cpp sampling
/// </summary>
public sealed class DefaultSamplingPipeline
    : BaseSamplingPipeline
{
    /// <summary>
    /// Bias values to add to certain logits
    /// </summary>
    public Dictionary<int, float> LogitBias { get; } = new();

    /// <summary>
    /// Grammar to constrain valid tokens
    /// </summary>
    public SafeLLamaGrammarHandle? Grammar { get; set; }

    /// <summary>
    /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858
    /// </summary>
    public float RepeatPenalty { get; set; } = 1.1f;

    /// <summary>
    /// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
    /// so far, decreasing the model's likelihood to repeat the same line verbatim.
    /// </summary>
    public float AlphaFrequency
    {
        get => _alphaFreq;
        set
        {
            if (value < -2)
                throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2");
            if (value > 2)
                throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2");
            _alphaFreq = value;
        }
    }
    private float _alphaFreq = 0.1f;

    /// <summary>
    /// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
    /// text so far, increasing the model's likelihood to talk about new topics.
    /// </summary>
    public float AlphaPresence
    {
        get => _alphaPresence;
        set
        {
            if (value < -2)
                throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be greater than -2");
            if (value > 2)
                throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be less than 2");
            _alphaPresence = value;
        }
    }
    private float _alphaPresence = 0.1f;

    /// <summary>
    /// Temperature to apply (higher temperature is more "creative")
    /// </summary>
    public float Temperature { get; set; } = 0.75f;

    /// <summary>
    /// Number of tokens to keep in TopK sampling
    /// </summary>
    public int TopK { get; set; }

    /// <summary>
    /// Z value for tail free sampling
    /// </summary>
    public float TailFreeZ { get; set; }

    /// <summary>
    /// P value for locally typical sampling
    /// </summary>
    public float TypicalP { get; set; }

    /// <summary>
    /// P value for TopP sampling
    /// </summary>
    public float TopP { get; set; } = 1f;

    /// <summary>
    /// P value for MinP sampling
    /// </summary>
    public float MinP { get; set; }

    /// <summary>
    /// Whether the newline token should be penalized like any other token; when false (the default) it is protected from logit bias and repeat penalty
    /// </summary>
    public bool PenalizeNewline { get; set; } = false;

    private readonly int[] _newlineToken = new int[1];

    /// <inheritdoc />
    protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
    {
        if (PenalizeNewline)
            return Array.Empty<int>();

        _newlineToken[0] = NativeApi.llama_token_nl(ctx.ModelHandle);
        return _newlineToken;
    }

    /// <inheritdoc />
    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        foreach (var (key, value) in LogitBias)
            logits[key] += value;
    }

    /// <inheritdoc />
    protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
    {
        // Apply penalties to candidates
        candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);

        // Restore protected tokens, so they are not affected by repetition penalties
        RestoreProtectedTokens(candidates);

        // Apply the normal llama.cpp pipeline
        candidates.ApplyGrammar(ctx, Grammar);
        candidates.TopK(ctx, TopK);
        candidates.TailFree(ctx, TailFreeZ);
        candidates.LocallyTypical(ctx, TypicalP);
        candidates.TopP(ctx, TopP);
        candidates.MinP(ctx, MinP);
        candidates.Temperature(ctx, Temperature);
        var id = candidates.SampleToken(ctx);

        Grammar?.AcceptToken(ctx, id);
        return id;
    }

    /// <inheritdoc />
    protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
    {
        return candidates.SampleToken(ctx);
    }
}
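A minimal configuration sketch for the new default pipeline (all values and the token id are illustrative; `weights` and `modelParams` stand in for an already-loaded model and its parameters, as in the unit test above):

var pipeline = new DefaultSamplingPipeline
{
    Temperature = 0.8f,
    TopP = 0.95f,
    MinP = 0.05f,
    RepeatPenalty = 1.1f,
};
pipeline.LogitBias[42] = -5f; // hypothetical: discourage token id 42

var executor = new StatelessExecutor(weights, modelParams);
var inferenceParams = new InferenceParams { MaxTokens = 32, SamplingPipeline = pipeline };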
@@ -3,14 +3,11 @@ using System.Buffers;
 using System.Collections.Generic;
 using System.Runtime.InteropServices;
 using LLama.Native;
-using LLama.Sampling.Logits;
-using LLama.Sampling.Selection;
-using LLama.Sampling.Tokens;

 namespace LLama.Sampling;

 /// <summary>
-/// Convert a span of logits into a single sampled token
+/// Convert a span of logits into a single sampled token. This interface can be implemented to completely customise the sampling process.
 /// </summary>
 public interface ISamplingPipeline
     : IDisposable
@@ -61,101 +58,4 @@ public static class ISamplingPipelineExtensions
         }
 #endif
     }
 }
-
-/// <summary>
-/// Simple implementation of `ISamplingPipeline`, applies processors in order every time
-/// </summary>
-public sealed class ConfigurableSamplingPipeline
-    : ISamplingPipeline
-{
-    /// <summary>
-    /// Logit processors to apply in this pipeline
-    /// </summary>
-    public IList<ILogitProcessor> LogitProcessors { get; } = new List<ILogitProcessor>();
-
-    /// <summary>
-    /// Logits values which will not be changed by the logit processors
-    /// </summary>
-    public IList<int> ProtectedLogits { get; } = new List<int>();
-
-    /// <summary>
-    /// Token data processors to apply in this pipeline
-    /// </summary>
-    public IList<ITokenDataProcessor> TokenDataProcessors { get; } = new List<ITokenDataProcessor>();
-
-    /// <summary>
-    /// The selector to choose the final token
-    /// </summary>
-    public ITokenSelector Selector { get; set; } = new StandardSelection();
-
-    /// <inheritdoc />
-    public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
-    {
-        var savedLogitsCount = ProtectedLogits.Count;
-        var savedLogitValues = ArrayPool<float>.Shared.Rent(savedLogitsCount);
-        var savedLogitIndices = ArrayPool<int>.Shared.Rent(savedLogitsCount);
-        try
-        {
-            // Save the values of protected logits
-            for (var i = 0; i < ProtectedLogits.Count; i++)
-            {
-                savedLogitValues[i] = logits[ProtectedLogits[i]];
-                savedLogitIndices[i] = ProtectedLogits[i];
-            }
-
-            // Modify raw logits
-            foreach (var logitProcessor in LogitProcessors)
-                logitProcessor.ProcessLogits(ctx, logits, lastTokens);
-
-            // Restore the values of protected logits
-            for (var i = 0; i < savedLogitsCount; i++)
-                logits[savedLogitIndices[i]] = savedLogitValues[i];
-        }
-        finally
-        {
-            ArrayPool<float>.Shared.Return(savedLogitValues);
-            ArrayPool<int>.Shared.Return(savedLogitIndices);
-        }
-
-        // Convert logits into token candidates
-        var candidates_p = LLamaTokenDataArray.Create(logits);
-
-        // Process token candidates
-        foreach (var tokenDataProcessor in TokenDataProcessors)
-            tokenDataProcessor.ProcessTokens(ctx, candidates_p, lastTokens);
-
-        // Select a token
-        var token = Selector.Select(ctx, candidates_p, lastTokens);
-
-        // Tell processors what was selected
-        foreach (var logitProcessor in LogitProcessors)
-            logitProcessor.AcceptToken(ctx, token);
-        foreach (var tokenDataProcessor in TokenDataProcessors)
-            tokenDataProcessor.AcceptToken(ctx, token);
-
-        return token;
-    }
-
-    /// <inheritdoc />
-    public void Reset()
-    {
-        foreach (var logitProcessor in LogitProcessors)
-            logitProcessor.Reset();
-        foreach (var tokenDataProcessor in TokenDataProcessors)
-            tokenDataProcessor.Reset();
-
-        Selector.Reset();
-    }
-
-    /// <inheritdoc />
-    public void Dispose()
-    {
-        foreach (var logitProcessor in LogitProcessors)
-            logitProcessor.Dispose();
-        foreach (var tokenDataProcessor in TokenDataProcessors)
-            tokenDataProcessor.Dispose();
-
-        Selector.Dispose();
-    }
-}
@@ -1,34 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Logits;

using llama_token = Int32;

/// <summary>
/// Processes raw logits before sampling, applying penalties to certain tokens
/// </summary>
public interface ILogitProcessor
    : IDisposable
{
    /// <summary>
    /// Process raw logits, indexed by llama_token
    /// </summary>
    /// <param name="ctx">The context this is operating in</param>
    /// <param name="logits">The token data array to process</param>
    /// <param name="lastTokens">The most recent tokens output</param>
    /// <returns>LLamaTokenDataArray, created from logits</returns>
    void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<llama_token> lastTokens);

    /// <summary>
    /// Inform this process when a token is accepted by the model
    /// </summary>
    /// <param name="ctx"></param>
    /// <param name="token"></param>
    void AcceptToken(SafeLLamaContextHandle ctx, int token);

    /// <summary>
    /// Reset all internal sampling state
    /// </summary>
    void Reset();
}
@@ -1,39 +0,0 @@
using System;
using System.Collections.Generic;
using LLama.Native;

namespace LLama.Sampling.Logits;

/// <summary>
/// Add a bias directly to logit values
/// </summary>
public sealed class LogitBias
    : ILogitProcessor
{
    /// <summary>
    /// Biases to apply, token -> bias
    /// </summary>
    public IDictionary<int, float> Biases { get; } = new Dictionary<int, float>();

    /// <inheritdoc />
    public void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        foreach (var kvp in Biases)
            logits[kvp.Key] += kvp.Value;
    }

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
    }

    /// <inheritdoc />
    public void Reset()
    {
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}
@@ -1,27 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Selection;

/// <summary>
/// Select the most likely token
/// </summary>
public sealed class GreedySelection
    : ITokenSelector
{
    /// <inheritdoc />
    public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
    {
        return candidates.SampleTokenGreedy(ctx);
    }

    /// <inheritdoc />
    public void Reset()
    {
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}
@@ -1,25 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Selection;

/// <summary>
/// Select a single token from a set of possibilities
/// </summary>
public interface ITokenSelector
    : IDisposable
{
    /// <summary>
    /// Select a single token
    /// </summary>
    /// <param name="ctx"></param>
    /// <param name="candidates"></param>
    /// <param name="lastTokens"></param>
    /// <returns></returns>
    int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens);

    /// <summary>
    /// Reset the state
    /// </summary>
    void Reset();
}
@@ -1,65 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Selection;

/// <summary>
/// Select a token using Mirostat sampling.
/// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966.
/// </summary>
public sealed class Mirostat2Selection
    : ITokenSelector
{
    private float _mu;

    /// <summary>
    /// Current value of Mu, updated based on the difference between target surprise and actual surprise
    /// </summary>
    public float Mu
    {
        get => _mu;
        set => _mu = value;
    }

    /// <summary>
    /// The target cross-entropy (or surprise) value you want to achieve for the generated text.
    /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text
    /// </summary>
    public float Tau { get; set; }

    /// <summary>
    /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word.
    /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
    /// </summary>
    public float Eta { get; set; }

    /// <summary>
    /// Create a new Mirostat 2.0 sampler
    /// </summary>
    /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text.
    /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text</param>
    /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word.
    /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
    public Mirostat2Selection(float tau, float eta)
    {
        Tau = tau;
        Eta = eta;
    }

    /// <inheritdoc />
    public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
    {
        return candidates.SampleTokenMirostat2(ctx, Tau, Eta, ref _mu);
    }

    /// <inheritdoc />
    public void Reset()
    {
        _mu = 2 * Tau;
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}
@@ -1,76 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Selection;

/// <summary>
/// Select a token using Mirostat sampling.
/// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966.
/// </summary>
public sealed class MirostatSelection
    : ITokenSelector
{
    private float _mu;

    /// <summary>
    /// Current value of Mu, updated based on the difference between target surprise and actual surprise
    /// </summary>
    public float Mu
    {
        get => _mu;
        set => _mu = value;
    }

    /// <summary>
    /// The target cross-entropy (or surprise) value you want to achieve for the generated text.
    /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text
    /// </summary>
    public float Tau { get; set; }

    /// <summary>
    /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word.
    /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
    /// </summary>
    public float Eta { get; set; }

    /// <summary>
    /// The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn
    /// helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects
    /// the performance of the algorithm.
    /// </summary>
    public int M { get; set; }

    /// <summary>
    /// Create a new Mirostat 1.0 sampler
    /// </summary>
    /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text.
    /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text</param>
    /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word.
    /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
    /// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn
    /// helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects
    /// the performance of the algorithm.</param>
    public MirostatSelection(float tau, float eta, int m = 100)
    {
        Tau = tau;
        Eta = eta;
        M = m;
    }

    /// <inheritdoc />
    public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
    {
        return candidates.SampleTokenMirostat(ctx, Tau, Eta, M, ref _mu);
    }

    /// <inheritdoc />
    public void Reset()
    {
        _mu = 2 * Tau;
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}
@@ -1,27 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Selection;

/// <summary>
/// Select from all possible tokens according to their probability
/// </summary>
public sealed class StandardSelection
    : ITokenSelector
{
    /// <inheritdoc />
    public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
    {
        return candidates.SampleToken(ctx);
    }

    /// <inheritdoc />
    public void Reset()
    {
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}
@@ -1,59 +0,0 @@
using System;
using LLama.Grammars;
using LLama.Native;

namespace LLama.Sampling.Tokens;

/// <summary>
/// Apply a grammar to prevent sampling tokens which do not match the grammar
/// </summary>
public sealed class GrammarSampling
    : ITokenDataProcessor
{
    private SafeLLamaGrammarHandle? _handle;

    /// <summary>
    /// Grammar to use for sampling
    /// </summary>
    public Grammar? Grammar { get; set; }

    /// <summary>
    /// Create a new grammar sampler
    /// </summary>
    /// <param name="grammar"></param>
    public GrammarSampling(Grammar grammar)
    {
        Grammar = grammar;
    }

    /// <inheritdoc />
    public void Reset()
    {
        _handle?.Dispose();
        _handle = null;
    }

    /// <inheritdoc />
    public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
    {
        // Create a new grammar instance if necessary
        _handle ??= Grammar?.CreateInstance();

        // Apply it
        if (_handle != null)
            tokens.ApplyGrammar(ctx, _handle);
    }

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
        _handle?.AcceptToken(ctx, token);
    }

    /// <inheritdoc />
    public void Dispose()
    {
        _handle?.Dispose();
        _handle = null;
    }
}
@@ -1,34 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Tokens;

using llama_token = Int32;

/// <summary>
/// Processes token logits before sampling, applying penalties to certain tokens
/// </summary>
public interface ITokenDataProcessor
    : IDisposable
{
    /// <summary>
    /// Process token logits in a LLamaTokenDataArray
    /// </summary>
    /// <param name="ctx">The context this is operating in</param>
    /// <param name="tokens">The token data array to process</param>
    /// <param name="lastTokens">The most recent tokens output</param>
    /// <returns>LLamaTokenDataArray, created from logits</returns>
    void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<llama_token> lastTokens);

    /// <summary>
    /// Inform this process when a token is accepted by the model
    /// </summary>
    /// <param name="ctx"></param>
    /// <param name="token"></param>
    void AcceptToken(SafeLLamaContextHandle ctx, int token);

    /// <summary>
    /// Reset all internal sampling state
    /// </summary>
    void Reset();
}
@@ -1,42 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Tokens;

/// <summary>
/// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
/// </summary>
public sealed class LocallyTypicalSampling
    : ITokenDataProcessor
{
    /// <summary>
    /// P value for locally typical sampling
    /// </summary>
    public float P { get; set; }

    /// <summary>
    /// Minimum number of tokens to keep
    /// </summary>
    public ulong MinKeep { get; set; } = 1;

    /// <inheritdoc />
    public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
    {
        tokens.LocallyTypical(ctx, P, MinKeep);
    }

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
    }

    /// <inheritdoc />
    public void Reset()
    {
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}
@@ -1,42 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Tokens;

/// <summary>
/// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
/// </summary>
public sealed class MinPSampling
    : ITokenDataProcessor
{
    /// <summary>
    /// All tokens with probability greater than this will be kept
    /// </summary>
    public float P { get; set; }

    /// <summary>
    /// Minimum number of tokens to keep
    /// </summary>
    public ulong MinKeep { get; set; } = 1;

    /// <inheritdoc />
    public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
    {
        tokens.MinP(ctx, P, MinKeep);
    }

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
    }

    /// <inheritdoc />
    public void Reset()
    {
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}
@@ -1,77 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Tokens;

/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
/// </summary>
public sealed class RepetitionPenalty
    : ITokenDataProcessor
{
    private float _alphaFreq;
    private float _alphaPresence;

    /// <summary>
    /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858
    /// </summary>
    public float RepeatPenalty { get; set; } = 1.1f;

    /// <summary>
    /// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
    /// so far, decreasing the model's likelihood to repeat the same line verbatim.
    /// </summary>
    public float AlphaFrequency
    {
        get => _alphaFreq;
        set
        {
            if (value < -2)
                throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2");
            if (value > 2)
                throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2");
            _alphaFreq = value;
        }
    }

    /// <summary>
    /// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
    /// text so far, increasing the model's likelihood to talk about new topics.
    /// </summary>
    public float AlphaPresence
    {
        get => _alphaPresence;
        set
        {
            if (value < -2)
                throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be greater than -2");
            if (value > 2)
                throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be less than 2");
            _alphaPresence = value;
        }
    }

    /// <inheritdoc />
    public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
    {
        tokens.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);
    }

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
    }

    /// <inheritdoc />
    public void Reset()
    {
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}
@@ -1,42 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Tokens;

/// <summary>
/// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
/// </summary>
public sealed class TailFreeSampling
    : ITokenDataProcessor
{
    /// <summary>
    /// Z value for tail free sampling
    /// </summary>
    public float Z { get; set; }

    /// <summary>
    /// Minimum number of tokens to keep
    /// </summary>
    public ulong MinKeep { get; set; } = 1;

    /// <inheritdoc />
    public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
    {
        tokens.TailFree(ctx, Z, MinKeep);
    }

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
    }

    /// <inheritdoc />
    public void Reset()
    {
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}
@@ -1,38 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Tokens;

/// <summary>
/// Sample with temperature.
/// As temperature increases, the prediction becomes more diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual
/// </summary>
public sealed class TemperatureSampling
    : ITokenDataProcessor
{
    /// <summary>
    /// Temperature value to apply
    /// </summary>
    public float Temperature { get; set; } = 0.5f;

    /// <inheritdoc />
    public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
    {
        tokens.Temperature(ctx, Temperature);
    }

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
    }

    /// <inheritdoc />
    public void Reset()
    {
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}
@@ -1,38 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Tokens;

/// <summary>
/// Sample with TopK, removing all but the K most likely tokens.
/// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
/// </summary>
public sealed class TopKSampling
    : ITokenDataProcessor
{
    /// <summary>
    /// Number of tokens to keep
    /// </summary>
    public int Count { get; set; }

    /// <inheritdoc />
    public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
    {
        tokens.TopK(ctx, Count);
    }

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
    }

    /// <inheritdoc />
    public void Reset()
    {
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}
@@ -1,42 +0,0 @@
using System;
using LLama.Native;

namespace LLama.Sampling.Tokens;

/// <summary>
/// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
/// </summary>
public sealed class TopPSampling
    : ITokenDataProcessor
{
    /// <summary>
    /// P value for TopP
    /// </summary>
    public float P { get; set; }

    /// <summary>
    /// Minimum number of tokens to keep
    /// </summary>
    public ulong MinKeep { get; set; } = 1;

    /// <inheritdoc />
    public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
    {
        tokens.TopP(ctx, P, MinKeep);
    }

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
    }

    /// <inheritdoc />
    public void Reset()
    {
    }

    /// <inheritdoc />
    public void Dispose()
    {
    }
}