Initial pass at a new sampling pipeline
This commit is contained in:
parent
ac3cc7c039
commit
33358124db
|
@ -145,15 +145,17 @@ namespace LLama.Native
|
|||
/// <param name="penalty_repeat"></param>
|
||||
/// <param name="penalty_freq"></param>
|
||||
/// <param name="penalty_present"></param>
|
||||
public void RepetitionPenalty(SafeLLamaContextHandle context, Memory<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
|
||||
public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
|
||||
{
|
||||
unsafe
|
||||
{
|
||||
using (LLamaTokenDataArrayNative.Create(this, out var st))
|
||||
using (var last_tokens_handle = last_tokens.Pin())
|
||||
{
|
||||
NativeApi.llama_sample_repetition_penalties(context, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
|
||||
sorted = st.sorted;
|
||||
fixed (int* last_tokens_handle = last_tokens)
|
||||
{
|
||||
NativeApi.llama_sample_repetition_penalties(context, ref st, last_tokens_handle, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
|
||||
sorted = st.sorted;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using LLama.Native;
|
||||
using LLama.Sampling.Logits;
|
||||
using LLama.Sampling.Selection;
|
||||
using LLama.Sampling.Tokens;
|
||||
|
||||
namespace LLama.Sampling;
|
||||
|
||||
/// <summary>
/// Convert a span of logits into a single sampled token.
/// A pipeline owns any internal processors it uses and disposes them when it is disposed.
/// </summary>
public interface ISamplingPipeline
    : IDisposable
{
    /// <summary>
    /// Sample a single token from the given logits
    /// </summary>
    /// <param name="ctx">The context being sampled from</param>
    /// <param name="logits">The raw logit values produced by the model, indexed by token id</param>
    /// <param name="lastTokens">The most recent tokens emitted by the model</param>
    /// <returns>The id of the selected token</returns>
    int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens);

    /// <summary>
    /// Reset all internal state of the sampling pipeline
    /// </summary>
    void Reset();
}
|
||||
|
||||
/// <summary>
/// Simple implementation of `ISamplingPipeline`, applies processors in order every time
/// </summary>
public sealed class BasicSamplingPipeline
    : ISamplingPipeline
{
    /// <summary>
    /// Logit processors to apply in this pipeline
    /// </summary>
    public IList<ILogitProcessor> LogitProcessors { get; } = new List<ILogitProcessor>();

    /// <summary>
    /// Token data processors to apply in this pipeline
    /// </summary>
    public IList<ITokenDataProcessor> TokenDataProcessors { get; } = new List<ITokenDataProcessor>();

    /// <summary>
    /// The selector to choose the final token
    /// </summary>
    public ITokenSelector Selector { get; set; } = new StandardSelection();

    /// <inheritdoc />
    public int Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        // Give every logit processor a chance to modify the raw logits
        foreach (var processor in LogitProcessors)
            processor.ProcessLogits(ctx, logits, lastTokens);

        // Turn the (possibly modified) logits into a candidate token array
        var candidates = LLamaTokenDataArray.Create(logits);

        // Run the candidates through every token data processor
        foreach (var processor in TokenDataProcessors)
            processor.ProcessTokens(ctx, candidates, lastTokens);

        // Pick the final token
        var selected = Selector.Select(ctx, candidates, lastTokens);

        // Notify every processor which token was chosen
        foreach (var processor in LogitProcessors)
            processor.AcceptToken(ctx, selected);
        foreach (var processor in TokenDataProcessors)
            processor.AcceptToken(ctx, selected);

        return selected;
    }

    /// <inheritdoc />
    public void Reset()
    {
        // Reset every processor, then the selector
        foreach (var processor in LogitProcessors)
            processor.Reset();
        foreach (var processor in TokenDataProcessors)
            processor.Reset();

        Selector.Reset();
    }

    /// <inheritdoc />
    public void Dispose()
    {
        // Dispose every processor, then the selector
        foreach (var processor in LogitProcessors)
            processor.Dispose();
        foreach (var processor in TokenDataProcessors)
            processor.Dispose();

        Selector.Dispose();
    }
}
|
|
@ -0,0 +1,34 @@
|
|||
using System;
|
||||
using LLama.Native;
|
||||
|
||||
namespace LLama.Sampling.Logits;
|
||||
|
||||
using llama_token = Int32;
|
||||
|
||||
/// <summary>
/// Processes raw logits before sampling, applying penalties to certain tokens
/// </summary>
public interface ILogitProcessor
    : IDisposable
{
    /// <summary>
    /// Process raw logits, indexed by llama_token
    /// </summary>
    /// <param name="ctx">The context this is operating in</param>
    /// <param name="logits">The raw logit values to modify in place, indexed by token id</param>
    /// <param name="lastTokens">The most recent tokens output</param>
    void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<llama_token> lastTokens);

    /// <summary>
    /// Inform this process when a token is accepted by the model
    /// </summary>
    /// <param name="ctx">The context this is operating in</param>
    /// <param name="token">The token which was selected</param>
    void AcceptToken(SafeLLamaContextHandle ctx, int token);

    /// <summary>
    /// Reset all internal sampling state
    /// </summary>
    void Reset();
}
|
|
@ -0,0 +1,39 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using LLama.Native;
|
||||
|
||||
namespace LLama.Sampling.Logits;
|
||||
|
||||
/// <summary>
/// Add a bias directly to logit values
/// </summary>
public sealed class LogitBias
    : ILogitProcessor
{
    /// <summary>
    /// Biases to apply, token -> bias
    /// </summary>
    public IDictionary<int, float> Biases { get; } = new Dictionary<int, float>();

    /// <inheritdoc />
    public void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        // Add each configured bias onto the corresponding token's logit
        foreach (var (token, bias) in Biases)
            logits[token] += bias;
    }

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
        // Stateless: nothing to track per-token
    }

    /// <inheritdoc />
    public void Reset()
    {
        // No internal state to reset (Biases is configuration, not state)
    }

    /// <inheritdoc />
    public void Dispose()
    {
        // No resources to release
    }
}
|
|
@ -0,0 +1,100 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using LLama.Native;
|
||||
|
||||
namespace LLama.Sampling.Logits;
|
||||
|
||||
/// <summary>
/// Save certain logit values
/// </summary>
public sealed class SaveLogitValues
    : ILogitProcessor
{
    private readonly Dictionary<int, float> _saved = new();

    /// <summary>
    /// Logits to save
    /// </summary>
    public ISet<int> Logits { get; } = new HashSet<int>();

    /// <summary>
    /// Saved logit values
    /// </summary>
    public IReadOnlyDictionary<int, float> Values => _saved;

    /// <inheritdoc />
    public void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        // Capture a fresh snapshot of the configured logits on every call
        _saved.Clear();
        foreach (var token in Logits)
            _saved[token] = logits[token];
    }

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
        // Nothing to do when a token is accepted
    }

    /// <inheritdoc />
    public void Reset()
    {
        // Discard the current snapshot
        _saved.Clear();
    }

    /// <inheritdoc />
    public void Dispose()
    {
        // No resources to release
    }

    /// <summary>
    /// Get a logit processor that overwrite the logit values with the values saved here
    /// </summary>
    /// <returns>A new processor which writes the saved values back into a logit span.
    /// Note: it shares this instance's dictionary, so it always reflects the latest snapshot.</returns>
    public ILogitProcessor GetWriter() => new LoadLogitValues(_saved);
}
|
||||
|
||||
/// <summary>
/// Overwrite certain logit values
/// </summary>
public sealed class LoadLogitValues
    : ILogitProcessor
{
    /// <summary>
    /// Logits to overwrite, token -> logit
    /// </summary>
    public IDictionary<int, float> Values { get; }

    /// <summary>
    /// Create a new LoadLogitValues
    /// </summary>
    /// <param name="values">Source for values to overwrite (an empty dictionary is created if null)</param>
    public LoadLogitValues(Dictionary<int, float>? values = null)
    {
        Values = values ?? new Dictionary<int, float>();
    }

    /// <inheritdoc />
    public void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
    {
        // Stamp each stored value over the corresponding token's logit
        foreach (var (token, value) in Values)
            logits[token] = value;
    }

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
        // Nothing to do when a token is accepted
    }

    /// <inheritdoc />
    public void Reset()
    {
        // No internal state to reset (Values is configuration, not state)
    }

    /// <inheritdoc />
    public void Dispose()
    {
        // No resources to release
    }
}
|
|
@ -0,0 +1,27 @@
|
|||
using System;
|
||||
using LLama.Native;
|
||||
|
||||
namespace LLama.Sampling.Selection;
|
||||
|
||||
/// <summary>
/// Select the most likely token
/// </summary>
public sealed class GreedySelection
    : ITokenSelector
{
    /// <inheritdoc />
    public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
        => candidates.SampleTokenGreedy(ctx);

    /// <inheritdoc />
    public void Reset()
    {
        // Greedy selection is stateless
    }

    /// <inheritdoc />
    public void Dispose()
    {
        // No resources to release
    }
}
|
|
@ -0,0 +1,25 @@
|
|||
using System;
|
||||
using LLama.Native;
|
||||
|
||||
namespace LLama.Sampling.Selection;
|
||||
|
||||
/// <summary>
/// Select a single token from a set of possibilities
/// </summary>
public interface ITokenSelector
    : IDisposable
{
    /// <summary>
    /// Select a single token
    /// </summary>
    /// <param name="ctx">The context being sampled from</param>
    /// <param name="candidates">The candidate tokens (and associated data) to choose from</param>
    /// <param name="lastTokens">The most recent tokens output by the model</param>
    /// <returns>The id of the selected token</returns>
    int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens);

    /// <summary>
    /// Reset the state
    /// </summary>
    void Reset();
}
|
|
@ -0,0 +1,65 @@
|
|||
using System;
|
||||
using LLama.Native;
|
||||
|
||||
namespace LLama.Sampling.Selection;
|
||||
|
||||
/// <summary>
/// Select a token using Mirostat sampling.
/// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966.
/// </summary>
public sealed class Mirostat2Selection
    : ITokenSelector
{
    private float _mu;

    /// <summary>
    /// Current value of Mu, updated based on the difference between target surprise and actual surprise
    /// </summary>
    public float Mu
    {
        get => _mu;
        set => _mu = value;
    }

    /// <summary>
    /// The target cross-entropy (or surprise) value you want to achieve for the generated text.
    /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text
    /// </summary>
    public float Tau { get; set; }

    /// <summary>
    /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word.
    /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
    /// </summary>
    public float Eta { get; set; }

    /// <summary>
    /// Create a new Mirostat 2.0 sampler
    /// </summary>
    /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text.
    /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text</param>
    /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word.
    /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
    public Mirostat2Selection(float tau, float eta)
    {
        Tau = tau;
        Eta = eta;

        // Initialise mu to 2 * tau (matching the reference llama.cpp implementation) so the
        // first call to Select does not run with mu == 0 if Reset() was never called.
        _mu = 2 * tau;
    }

    /// <inheritdoc />
    public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
    {
        return candidates.SampleTokenMirostat2(ctx, Tau, Eta, ref _mu);
    }

    /// <inheritdoc />
    public void Reset()
    {
        // Restore mu to its initial value of 2 * tau
        _mu = 2 * Tau;
    }

    /// <inheritdoc />
    public void Dispose()
    {
        // No resources to release
    }
}
|
|
@ -0,0 +1,76 @@
|
|||
using System;
|
||||
using LLama.Native;
|
||||
|
||||
namespace LLama.Sampling.Selection;
|
||||
|
||||
/// <summary>
/// Select a token using Mirostat sampling.
/// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966.
/// </summary>
public sealed class MirostatSelection
    : ITokenSelector
{
    private float _mu;

    /// <summary>
    /// Current value of Mu, updated based on the difference between target surprise and actual surprise
    /// </summary>
    public float Mu
    {
        get => _mu;
        set => _mu = value;
    }

    /// <summary>
    /// The target cross-entropy (or surprise) value you want to achieve for the generated text.
    /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text
    /// </summary>
    public float Tau { get; set; }

    /// <summary>
    /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word.
    /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
    /// </summary>
    public float Eta { get; set; }

    /// <summary>
    /// The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn
    /// helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects
    /// the performance of the algorithm.
    /// </summary>
    public int M { get; set; }

    /// <summary>
    /// Create a new Mirostat 1.0 sampler
    /// </summary>
    /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text.
    /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text</param>
    /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word.
    /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
    /// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn
    /// helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects
    /// the performance of the algorithm.</param>
    public MirostatSelection(float tau, float eta, int m = 100)
    {
        Tau = tau;
        Eta = eta;
        M = m;

        // Initialise mu to 2 * tau (matching the reference llama.cpp implementation) so the
        // first call to Select does not run with mu == 0 if Reset() was never called.
        _mu = 2 * tau;
    }

    /// <inheritdoc />
    public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
    {
        return candidates.SampleTokenMirostat(ctx, Tau, Eta, M, ref _mu);
    }

    /// <inheritdoc />
    public void Reset()
    {
        // Restore mu to its initial value of 2 * tau
        _mu = 2 * Tau;
    }

    /// <inheritdoc />
    public void Dispose()
    {
        // No resources to release
    }
}
|
|
@ -0,0 +1,27 @@
|
|||
using System;
|
||||
using LLama.Native;
|
||||
|
||||
namespace LLama.Sampling.Selection;
|
||||
|
||||
/// <summary>
/// Select from all possible tokens according to their probability
/// </summary>
public sealed class StandardSelection
    : ITokenSelector
{
    /// <inheritdoc />
    public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
        => candidates.SampleToken(ctx);

    /// <inheritdoc />
    public void Reset()
    {
        // Standard selection is stateless
    }

    /// <inheritdoc />
    public void Dispose()
    {
        // No resources to release
    }
}
|
|
@ -0,0 +1,59 @@
|
|||
using System;
|
||||
using LLama.Grammars;
|
||||
using LLama.Native;
|
||||
|
||||
namespace LLama.Sampling.Tokens;
|
||||
|
||||
/// <summary>
/// Apply a grammar to prevent sampling tokens which do not match the grammar
/// </summary>
public sealed class GrammarSampling
    : ITokenDataProcessor
{
    // Native grammar instance, created lazily from `Grammar` on first use
    private SafeLLamaGrammarHandle? _handle;

    /// <summary>
    /// Grammar to use for sampling
    /// </summary>
    public Grammar? Grammar { get; set; }

    /// <summary>
    /// Create a new GrammarSampling processor which constrains sampling to the given grammar
    /// </summary>
    /// <param name="grammar">The grammar to apply</param>
    public GrammarSampling(Grammar grammar)
    {
        Grammar = grammar;
    }

    /// <inheritdoc />
    public void Reset()
    {
        // Drop the native instance; a fresh one is created on the next ProcessTokens call
        _handle?.Dispose();
        _handle = null;
    }

    /// <inheritdoc />
    public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
    {
        // Create a new grammar instance if necessary
        _handle ??= Grammar?.CreateInstance();

        // Apply it (no-op if no grammar is configured)
        if (_handle != null)
            tokens.ApplyGrammar(ctx, _handle);
    }

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
        // Advance the grammar state with the accepted token
        _handle?.AcceptToken(ctx, token);
    }

    /// <inheritdoc />
    public void Dispose()
    {
        _handle?.Dispose();
        _handle = null;
    }
}
|
|
@ -0,0 +1,34 @@
|
|||
using System;
|
||||
using LLama.Native;
|
||||
|
||||
namespace LLama.Sampling.Tokens;
|
||||
|
||||
using llama_token = Int32;
|
||||
|
||||
/// <summary>
/// Processes token logits before sampling, applying penalties to certain tokens
/// </summary>
public interface ITokenDataProcessor
    : IDisposable
{
    /// <summary>
    /// Process token logits in a LLamaTokenDataArray
    /// </summary>
    /// <param name="ctx">The context this is operating in</param>
    /// <param name="tokens">The token data array to process (modified in place)</param>
    /// <param name="lastTokens">The most recent tokens output</param>
    void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<llama_token> lastTokens);

    /// <summary>
    /// Inform this process when a token is accepted by the model
    /// </summary>
    /// <param name="ctx">The context this is operating in</param>
    /// <param name="token">The token which was selected</param>
    void AcceptToken(SafeLLamaContextHandle ctx, int token);

    /// <summary>
    /// Reset all internal sampling state
    /// </summary>
    void Reset();
}
|
|
@ -0,0 +1,42 @@
|
|||
using System;
|
||||
using LLama.Native;
|
||||
|
||||
namespace LLama.Sampling.Tokens;
|
||||
|
||||
/// <summary>
/// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
/// </summary>
public sealed class LocallyTypicalSampling
    : ITokenDataProcessor
{
    /// <summary>
    /// P value for locally typical sampling
    /// </summary>
    public float P { get; set; }

    /// <summary>
    /// Minimum number of tokens to keep
    /// </summary>
    public ulong MinKeep { get; set; } = 1;

    /// <inheritdoc />
    public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
        => tokens.LocallyTypical(ctx, P, MinKeep);

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
        // Stateless sampler, nothing to track
    }

    /// <inheritdoc />
    public void Reset()
    {
        // No internal state to reset
    }

    /// <inheritdoc />
    public void Dispose()
    {
        // No resources to release
    }
}
|
|
@ -0,0 +1,42 @@
|
|||
using System;
|
||||
using LLama.Native;
|
||||
|
||||
namespace LLama.Sampling.Tokens;
|
||||
|
||||
/// <summary>
/// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
/// </summary>
public sealed class MinPSampling
    : ITokenDataProcessor
{
    /// <summary>
    /// All tokens with probability greater than this will be kept
    /// </summary>
    public float P { get; set; }

    /// <summary>
    /// Minimum number of tokens to keep
    /// </summary>
    public ulong MinKeep { get; set; } = 1;

    /// <inheritdoc />
    public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
        => tokens.MinP(ctx, P, MinKeep);

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
        // Stateless sampler, nothing to track
    }

    /// <inheritdoc />
    public void Reset()
    {
        // No internal state to reset
    }

    /// <inheritdoc />
    public void Dispose()
    {
        // No resources to release
    }
}
|
|
@ -0,0 +1,77 @@
|
|||
using System;
|
||||
using LLama.Native;
|
||||
|
||||
namespace LLama.Sampling.Tokens;
|
||||
|
||||
/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
/// </summary>
public sealed class RepetitionPenalty
    : ITokenDataProcessor
{
    private float _alphaFreq;
    private float _alphaPresence;

    /// <summary>
    /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858
    /// </summary>
    public float RepeatPenalty { get; set; } = 1.1f;

    /// <summary>
    /// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
    /// so far, decreasing the model's likelihood to repeat the same line verbatim.
    /// </summary>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if the value is outside the range [-2, 2]</exception>
    public float AlphaFrequency
    {
        get => _alphaFreq;
        set
        {
            if (value < -2)
                throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2");
            if (value > 2)
                throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2");
            _alphaFreq = value;
        }
    }

    /// <summary>
    /// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create<br />
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
    /// text so far, increasing the model's likelihood to talk about new topics.
    /// </summary>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if the value is outside the range [-2, 2]</exception>
    public float AlphaPresence
    {
        get => _alphaPresence;
        set
        {
            // Fixed: these messages previously said "AlphaFrequency" (copy-paste error)
            if (value < -2)
                throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be greater than -2");
            if (value > 2)
                throw new ArgumentOutOfRangeException(nameof(value), "AlphaPresence must be less than 2");
            _alphaPresence = value;
        }
    }

    /// <inheritdoc />
    public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
    {
        tokens.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence);
    }

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
        // Stateless: the penalty is derived entirely from `lastTokens`
    }

    /// <inheritdoc />
    public void Reset()
    {
        // No internal state to reset
    }

    /// <inheritdoc />
    public void Dispose()
    {
        // No resources to release
    }
}
|
|
@ -0,0 +1,42 @@
|
|||
using System;
|
||||
using LLama.Native;
|
||||
|
||||
namespace LLama.Sampling.Tokens;
|
||||
|
||||
/// <summary>
/// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
/// </summary>
public sealed class TailFreeSampling
    : ITokenDataProcessor
{
    /// <summary>
    /// Z value for tail free sampling
    /// </summary>
    public float Z { get; set; }

    /// <summary>
    /// Minimum number of tokens to keep
    /// </summary>
    public ulong MinKeep { get; set; } = 1;

    /// <inheritdoc />
    public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
        => tokens.TailFree(ctx, Z, MinKeep);

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
        // Stateless sampler, nothing to track
    }

    /// <inheritdoc />
    public void Reset()
    {
        // No internal state to reset
    }

    /// <inheritdoc />
    public void Dispose()
    {
        // No resources to release
    }
}
|
|
@ -0,0 +1,38 @@
|
|||
using System;
|
||||
using LLama.Native;
|
||||
|
||||
namespace LLama.Sampling.Tokens;
|
||||
|
||||
/// <summary>
/// Sample with temperature.
/// As temperature increases, the prediction becomes more diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual
/// </summary>
public sealed class TemperatureSampling
    : ITokenDataProcessor
{
    /// <summary>
    /// Temperature value to apply
    /// </summary>
    public float Temperature { get; set; } = 0.5f;

    /// <inheritdoc />
    public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
        => tokens.Temperature(ctx, Temperature);

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
        // Stateless sampler, nothing to track
    }

    /// <inheritdoc />
    public void Reset()
    {
        // No internal state to reset
    }

    /// <inheritdoc />
    public void Dispose()
    {
        // No resources to release
    }
}
|
|
@ -0,0 +1,38 @@
|
|||
using System;
|
||||
using LLama.Native;
|
||||
|
||||
namespace LLama.Sampling.Tokens;
|
||||
|
||||
/// <summary>
/// Sample with TopK, removing all but the K most likely tokens.
/// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
/// </summary>
public sealed class TopKSampling
    : ITokenDataProcessor
{
    /// <summary>
    /// Number of tokens to keep
    /// </summary>
    public int Count { get; set; }

    /// <inheritdoc />
    public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
        => tokens.TopK(ctx, Count);

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
        // Stateless sampler, nothing to track
    }

    /// <inheritdoc />
    public void Reset()
    {
        // No internal state to reset
    }

    /// <inheritdoc />
    public void Dispose()
    {
        // No resources to release
    }
}
|
|
@ -0,0 +1,42 @@
|
|||
using System;
|
||||
using LLama.Native;
|
||||
|
||||
namespace LLama.Sampling.Tokens;
|
||||
|
||||
/// <summary>
/// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
/// </summary>
public sealed class TopPSampling
    : ITokenDataProcessor
{
    /// <summary>
    /// P value for TopP
    /// </summary>
    public float P { get; set; }

    /// <summary>
    /// Minimum number of tokens to keep
    /// </summary>
    public ulong MinKeep { get; set; } = 1;

    /// <inheritdoc />
    public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan<int> lastTokens)
        => tokens.TopP(ctx, P, MinKeep);

    /// <inheritdoc />
    public void AcceptToken(SafeLLamaContextHandle ctx, int token)
    {
        // Stateless sampler, nothing to track
    }

    /// <inheritdoc />
    public void Reset()
    {
        // No internal state to reset
    }

    /// <inheritdoc />
    public void Dispose()
    {
        // No resources to release
    }
}
|
Loading…
Reference in New Issue