// LLamaSharp/LLama/Sampling/MirostatSamplingPipeline.cs
//
// 71 lines
// 1.6 KiB
// C#

using System;
using LLama.Native;
namespace LLama.Sampling;
/// <summary>
/// Sampling pipeline that selects each token with the mirostat (v1) sampler.
/// </summary>
public class MirostateSamplingPipeline
    : BaseSamplingPipeline
{
    // Fixed value passed as the fourth argument of SampleTokenMirostat on every call.
    private const int MirostatM = 100;

    // Tau in effect until a caller assigns a different one.
    private const float DefaultTau = 5;

    // Backing fields; mu always starts out at twice the tau it was derived from.
    private float _tau = DefaultTau;
    private float _mu = 2 * DefaultTau;

    /// <summary>
    /// The mu value as last updated by the sampler
    /// </summary>
    public float Mu
    {
        get { return _mu; }
    }

    /// <summary>
    /// Target entropy. Assigning a new value also restarts <see cref="Mu"/> at twice that value.
    /// </summary>
    public float Tau
    {
        get { return _tau; }
        set
        {
            _tau = value;
            _mu = 2 * value;
        }
    }

    /// <summary>
    /// Learning rate
    /// </summary>
    public float Eta { get; set; } = 0.1f;

    /// <inheritdoc />
    protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
    {
        // Deliberately empty: this pipeline performs no work at the logit stage.
    }

    /// <inheritdoc />
    protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
        // _mu is passed by reference so the sampler can update it between calls.
        => candidates.SampleTokenMirostat(ctx, _tau, Eta, MirostatM, ref _mu);

    /// <inheritdoc />
    public override void Reset()
    {
        base.Reset();

        // Return mu to its starting point for the current tau.
        _mu = 2 * _tau;
    }

    /// <inheritdoc />
    public override ISamplingPipeline Clone()
    {
        // The backing field _tau is written directly (not through Tau) because the
        // Tau setter would overwrite the copied _mu with twice the tau value.
        var copy = new MirostateSamplingPipeline
        {
            Grammar = Grammar?.Clone(),
            Eta = Eta,
            _tau = _tau,
            _mu = _mu
        };

        return copy;
    }
}