Merge pull request #277 from martindevans/feature/min_p

MinP Sampler
This commit is contained in:
Martin Evans 2023-11-13 02:15:52 +00:00 committed by GitHub
commit e3468d04f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 108 additions and 131 deletions

View File

@ -4,93 +4,61 @@ using LLama.Native;
namespace LLama.Web.Common
{
public class InferenceOptions : IInferenceParams
public class InferenceOptions
: IInferenceParams
{
/// <summary>
/// number of tokens to keep from initial prompt
/// </summary>
/// <inheritdoc />
public int TokensKeep { get; set; } = 0;
/// <summary>
/// how many new tokens to predict (n_predict), set to -1 to inifinitely generate response
/// until it complete.
/// </summary>
/// <inheritdoc />
public int MaxTokens { get; set; } = -1;
/// <summary>
/// logit bias for specific tokens
/// </summary>
/// <inheritdoc />
public Dictionary<int, float>? LogitBias { get; set; } = null;
/// <summary>
/// Sequences where the model will stop generating further tokens.
/// </summary>
/// <inheritdoc />
public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();
/// <summary>
/// path to file for saving/loading model eval state
/// </summary>
public string PathSession { get; set; } = string.Empty;
/// <summary>
/// string to suffix user inputs with
/// </summary>
public string InputSuffix { get; set; } = string.Empty;
/// <summary>
/// string to prefix user inputs with
/// </summary>
public string InputPrefix { get; set; } = string.Empty;
/// <summary>
/// 0 or lower to use vocab size
/// </summary>
/// <inheritdoc />
public int TopK { get; set; } = 40;
/// <summary>
/// 1.0 = disabled
/// </summary>
/// <inheritdoc />
public float TopP { get; set; } = 0.95f;
/// <summary>
/// 1.0 = disabled
/// </summary>
/// <inheritdoc />
public float MinP { get; set; } = 0.05f;
/// <inheritdoc />
public float TfsZ { get; set; } = 1.0f;
/// <summary>
/// 1.0 = disabled
/// </summary>
/// <inheritdoc />
public float TypicalP { get; set; } = 1.0f;
/// <summary>
/// 1.0 = disabled
/// </summary>
/// <inheritdoc />
public float Temperature { get; set; } = 0.8f;
/// <summary>
/// 1.0 = disabled
/// </summary>
/// <inheritdoc />
public float RepeatPenalty { get; set; } = 1.1f;
/// <summary>
/// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n)
/// </summary>
/// <inheritdoc />
public int RepeatLastTokensCount { get; set; } = 64;
/// <summary>
/// frequency penalty coefficient
/// 0.0 = disabled
/// </summary>
/// <inheritdoc />
public float FrequencyPenalty { get; set; } = .0f;
/// <summary>
/// presence penalty coefficient
/// 0.0 = disabled
/// </summary>
/// <inheritdoc />
public float PresencePenalty { get; set; } = .0f;
/// <summary>
/// Mirostat uses tokens instead of words.
/// algorithm described in the paper https://arxiv.org/abs/2007.14966.
/// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
/// </summary>
/// <inheritdoc />
public MirostatType Mirostat { get; set; } = MirostatType.Disable;
/// <summary>
/// target entropy
/// </summary>
/// <inheritdoc />
public float MirostatTau { get; set; } = 5.0f;
/// <summary>
/// learning rate
/// </summary>
/// <inheritdoc />
public float MirostatEta { get; set; } = 0.1f;
/// <summary>
/// consider newlines as a repeatable token (penalize_nl)
/// </summary>
/// <inheritdoc />
public bool PenalizeNL { get; set; } = true;
/// <summary>

View File

@ -25,7 +25,6 @@ namespace LLama.Abstractions
/// </summary>
public Dictionary<int, float>? LogitBias { get; set; }
/// <summary>
/// Sequences where the model will stop generating further tokens.
/// </summary>
@ -41,10 +40,15 @@ namespace LLama.Abstractions
/// </summary>
public float TopP { get; set; }
/// <summary>
/// 1.0 = disabled
/// </summary>
public float TfsZ { get; set; }
/// <summary>llama_eval
/// 0.0 = disabled
/// </summary>
public float MinP { get; set; }
/// <summary>
/// 1.0 = disabled
/// </summary>
public float TfsZ { get; set; }
/// <summary>
/// 1.0 = disabled

View File

@ -6,10 +6,12 @@ using LLama.Native;
namespace LLama.Common
{
using llama_token = Int32;
/// <summary>
/// The paramters used for inference.
/// </summary>
public record InferenceParams : IInferenceParams
public record InferenceParams
: IInferenceParams
{
/// <summary>
/// number of tokens to keep from initial prompt
@ -30,66 +32,49 @@ namespace LLama.Common
/// </summary>
public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();
/// <summary>
/// 0 or lower to use vocab size
/// </summary>
/// <inheritdoc />
public int TopK { get; set; } = 40;
/// <summary>
/// 1.0 = disabled
/// </summary>
/// <inheritdoc />
public float TopP { get; set; } = 0.95f;
/// <summary>
/// 1.0 = disabled
/// </summary>
/// <inheritdoc />
public float MinP { get; set; } = 0.05f;
/// <inheritdoc />
public float TfsZ { get; set; } = 1.0f;
/// <summary>
/// 1.0 = disabled
/// </summary>
/// <inheritdoc />
public float TypicalP { get; set; } = 1.0f;
/// <summary>
/// 1.0 = disabled
/// </summary>
/// <inheritdoc />
public float Temperature { get; set; } = 0.8f;
/// <summary>
/// 1.0 = disabled
/// </summary>
/// <inheritdoc />
public float RepeatPenalty { get; set; } = 1.1f;
/// <summary>
/// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n)
/// </summary>
/// <inheritdoc />
public int RepeatLastTokensCount { get; set; } = 64;
/// <summary>
/// frequency penalty coefficient
/// 0.0 = disabled
/// </summary>
/// <inheritdoc />
public float FrequencyPenalty { get; set; } = .0f;
/// <summary>
/// presence penalty coefficient
/// 0.0 = disabled
/// </summary>
/// <inheritdoc />
public float PresencePenalty { get; set; } = .0f;
/// <summary>
/// Mirostat uses tokens instead of words.
/// algorithm described in the paper https://arxiv.org/abs/2007.14966.
/// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
/// </summary>
/// <inheritdoc />
public MirostatType Mirostat { get; set; } = MirostatType.Disable;
/// <summary>
/// target entropy
/// </summary>
/// <inheritdoc />
public float MirostatTau { get; set; } = 5.0f;
/// <summary>
/// learning rate
/// </summary>
/// <inheritdoc />
public float MirostatEta { get; set; } = 0.1f;
/// <summary>
/// consider newlines as a repeatable token (penalize_nl)
/// </summary>
/// <inheritdoc />
public bool PenalizeNL { get; set; } = true;
/// <summary>
/// A grammar to constrain the possible tokens
/// </summary>
/// <inheritdoc />
public SafeLLamaGrammarHandle? Grammar { get; set; }
}

View File

@ -226,10 +226,11 @@ namespace LLama
/// <param name="tfsZ"></param>
/// <param name="typicalP"></param>
/// <param name="grammar"></param>
/// <param name="minP"></param>
/// <returns></returns>
public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f,
SafeLLamaGrammarHandle? grammar = null)
public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature, MirostatType mirostat,
float mirostatTau, float mirostatEta, int topK, float topP, float tfsZ, float typicalP,
SafeLLamaGrammarHandle? grammar, float minP)
{
llama_token id;
@ -264,6 +265,7 @@ namespace LLama
candidates.TailFree(NativeHandle, tfsZ);
candidates.LocallyTypical(NativeHandle, typicalP);
candidates.TopP(NativeHandle, topP);
candidates.MinP(NativeHandle, minP);
candidates.Temperature(NativeHandle, temperature);
id = candidates.SampleToken(NativeHandle);
}

View File

@ -216,8 +216,8 @@ namespace LLama
var mu = MirostatMu;
var id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP,
inferenceParams.Grammar
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);
MirostatMu = mu;

View File

@ -194,9 +194,9 @@ namespace LLama
var mu = MirostatMu;
var id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP,
inferenceParams.Grammar
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);
MirostatMu = mu;

View File

@ -90,8 +90,11 @@ namespace LLama
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
// Sample a single token
var id = Context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar);
var id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);
// Decode this token into text
decoder.Add(id);

View File

@ -91,6 +91,21 @@ namespace LLama.Native
}
}
/// <summary>
/// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
/// </summary>
/// <param name="context"></param>
/// <param name="p">All tokens with probability greater than this will be kept</param>
/// <param name="minKeep"></param>
public void MinP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_min_p(context, ref st, p, minKeep);
sorted = st.sorted;
}
}
/// <summary>
/// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
/// </summary>