65 lines
2.2 KiB
C#
65 lines
2.2 KiB
C#
using System;
|
|
using LLama.Native;
|
|
|
|
namespace LLama.Sampling;
|
|
|
|
/// <summary>
/// Skeleton for custom <see cref="ISamplingPipeline"/> implementations.
/// <para>
/// <see cref="Sample"/> runs two derived-class hooks in a fixed order:
/// <see cref="ProcessLogits"/> first adjusts the raw logit values. Then a candidate array is built from those logits,
/// the current <see cref="Grammar"/> is applied to it, and <see cref="ProcessTokenDataArray"/> picks the final token.
/// </para>
/// </summary>
public abstract class BaseSamplingPipeline : ISamplingPipeline
{
    /// <summary>
    /// Optional grammar used to constrain which tokens can be sampled.
    /// It is handed to <c>ApplyGrammar</c> on every <see cref="Sample"/> call and advanced
    /// via <c>AcceptToken</c> in <see cref="Accept"/>. It may be <c>null</c>.
    /// This base class never disposes it.
    /// </summary>
    public SafeLLamaGrammarHandle? Grammar { get; set; }

    /// <inheritdoc/>
    public LLamaToken Sample(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
    {
        // Stage 1: derived pipeline adjusts the raw logit values.
        ProcessLogits(ctx, logits, lastTokens);

        // Stage 2: wrap the (possibly adjusted) logits as candidates and apply the grammar.
        // ApplyGrammar is called even when Grammar is null.
        // NOTE(review): presumably null means "no constraint" there; confirm in LLamaTokenDataArray.
        var tokenData = LLamaTokenDataArray.Create(logits);
        tokenData.ApplyGrammar(ctx, Grammar);

        // Stage 3: derived pipeline selects the token from the constrained candidates.
        var chosen = ProcessTokenDataArray(ctx, tokenData, lastTokens);
        return chosen;
    }

    /// <inheritdoc />
    /// <remarks>The base implementation only advances <see cref="Grammar"/>, if one is set.</remarks>
    public virtual void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
        => Grammar?.AcceptToken(ctx, token);

    /// <summary>
    /// Hook invoked first by <see cref="Sample"/> to transform the raw logit values.
    /// </summary>
    /// <param name="ctx">Context the logits were produced by.</param>
    /// <param name="logits">Raw logits from the model; the same span is later used to build the candidate array.</param>
    /// <param name="lastTokens">Tokens the model returned recently.</param>
    protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens);

    /// <summary>
    /// Hook invoked last by <see cref="Sample"/> to pick one token from the grammar-processed candidates.
    /// </summary>
    /// <param name="ctx">Context the candidates were produced by.</param>
    /// <param name="candidates">Candidate array built from the processed logits, after <c>ApplyGrammar</c>.</param>
    /// <param name="lastTokens">Tokens the model returned recently.</param>
    /// <returns>The token returned to the caller of <see cref="Sample"/>.</returns>
    protected abstract LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens);

    /// <inheritdoc/>
    /// <remarks>No-op in the base class; derived pipelines may override it.</remarks>
    public virtual void Reset()
    {
    }

    /// <inheritdoc />
    public abstract ISamplingPipeline Clone();

    /// <inheritdoc/>
    /// <remarks>
    /// Only suppresses finalization. <see cref="Grammar"/> is left untouched.
    /// </remarks>
    public virtual void Dispose()
        => GC.SuppressFinalize(this);
}