Swapped `StatelessExecutor` to use `llama_decode`!

- Added `logits_i` argument to `Context.ApplyPenalty`
 - Added a new exception type for `llama_decode` return code
This commit is contained in:
Martin Evans 2024-01-20 21:18:35 +00:00
parent 892e841da3
commit a2e29d393c
8 changed files with 90 additions and 38 deletions

View File

@ -105,12 +105,7 @@ public class BatchedDecoding
if (i_batch[i] < 0)
continue;
var n_vocab = model.VocabCount;
LLamaTokenDataArray candidates;
unsafe
{
candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab));
}
var candidates = LLamaTokenDataArray.Create(context.NativeHandle.GetLogitsIth(i_batch[i]));
candidates.TopK(context.NativeHandle, top_k);
candidates.TopP(context.NativeHandle, top_p);

View File

@ -19,6 +19,7 @@ namespace LLama.Unittest
{
ContextSize = 60,
Seed = 1754,
BatchSize = 2,
};
_weights = LLamaWeights.LoadFromFile(_params);
}
@ -60,7 +61,7 @@ namespace LLama.Unittest
{
var executor = new StatelessExecutor(_weights, _params);
const string question = " Question. cats or dogs?\nAnswer: ";
const string question = " Question. cats or dogs?\nAnswer:";
// The context size is set to 60. Generate more than that, forcing it to generate a coherent response
// with a modified context

View File

@ -1,4 +1,5 @@
using System;
using LLama.Native;
namespace LLama.Exceptions;
@ -36,4 +37,23 @@ public class LoadWeightsFailedException
{
ModelPath = modelPath;
}
}
/// <summary>
/// `llama_decode` return a non-zero status code
/// </summary>
public class LLamaDecodeError
: RuntimeError
{
/// <summary>
/// The return status code
/// </summary>
public DecodeResult ReturnCode { get; }
/// <inheritdoc />
public LLamaDecodeError(DecodeResult returnCode)
: base($"llama_decode failed: '{returnCode}'")
{
ReturnCode = returnCode;
}
}

View File

@ -293,6 +293,7 @@ namespace LLama
/// <summary>
/// Apply the penalty for the tokens. Please don't use it unless you fully know what it does.
/// </summary>
/// <param name="logits_i"></param>
/// <param name="lastTokens"></param>
/// <param name="logitBias"></param>
/// <param name="repeatLastTokensCount"></param>
@ -301,11 +302,11 @@ namespace LLama
/// <param name="alphaPresence"></param>
/// <param name="penalizeNL"></param>
/// <returns></returns>
public LLamaTokenDataArray ApplyPenalty(IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
bool penalizeNL = true)
public LLamaTokenDataArray ApplyPenalty(int logits_i, IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
bool penalizeNL = true)
{
var logits = NativeHandle.GetLogits();
var logits = NativeHandle.GetLogitsIth(logits_i);
// Apply params.logit_bias map
if (logitBias is not null)
@ -348,28 +349,23 @@ namespace LLama
/// <summary>
/// </summary>
/// <param name="batch"></param>
/// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
/// - 0: success<br />
/// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
/// - &lt; 0: error<br />
/// </returns>
public int Decode(LLamaBatch batch)
public DecodeResult Decode(LLamaBatch batch)
{
return NativeHandle.Decode(batch);
if (batch.TokenCount == 0)
return 0;
if (batch.TokenCount > Params.BatchSize)
throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));
return (DecodeResult)NativeHandle.Decode(batch);
}
/// <summary>
/// </summary>
/// <param name="batch"></param>
/// <param name="cancellationToken"></param>
/// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
/// - 0: success<br />
/// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
/// - &lt; 0: error<br />
/// </returns>
public Task<int> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
public Task<DecodeResult> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
{
return Task.Run(() => NativeHandle.Decode(batch), cancellationToken);
return Task.Run(() => Decode(batch), cancellationToken);
}
/// <summary>

View File

@ -216,7 +216,7 @@ namespace LLama
}
else
{
var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
var mu = MirostatMu;

View File

@ -195,7 +195,7 @@ namespace LLama
}
else
{
var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
var mu = MirostatMu;

View File

@ -5,7 +5,7 @@ using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using LLama.Exceptions;
using LLama.Native;
using LLama.Sampling;
using Microsoft.Extensions.Logging;
@ -22,6 +22,7 @@ namespace LLama
private readonly LLamaWeights _weights;
private readonly IContextParams _params;
private readonly ILogger? _logger;
private readonly LLamaBatch _batch;
/// <summary>
/// The context used by the executor when running the inference.
@ -39,6 +40,7 @@ namespace LLama
_weights = weights;
_params = @params;
_logger = logger;
_batch = new LLamaBatch(1);
Context = _weights.CreateContext(_params, logger);
Context.Dispose();
@ -71,16 +73,29 @@ namespace LLama
var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount <0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
var lastTokens = new List<LLamaToken>(repeat_last_n);
for (var i = 0; i < repeat_last_n; i++)
lastTokens.Add((LLamaToken)0);
lastTokens.Add(0);
// Tokenize the prompt
var tokens = Context.Tokenize(prompt).ToList();
lastTokens.AddRange(tokens);
var n_past = 1 + tokens.Count;
// Evaluate the prompt
await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken)
.ConfigureAwait(false);
// Evaluate the prompt, in chunks smaller than the max batch size
var n_past = 0;
var batchSize = (int)Context.Params.BatchSize;
for (var i = 0; i < tokens.Count; i += batchSize)
{
var n_eval = tokens.Count - i;
if (n_eval > batchSize)
n_eval = batchSize;
_batch.Clear();
for (var j = 0; j < n_eval; j++)
_batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, (i + j) == tokens.Count - 1);
var returnCode = await Context.DecodeAsync(_batch, cancellationToken);
if (returnCode != 0)
throw new LLamaDecodeError(returnCode);
}
// Begin loop, evaluating one token at a time
var mu = (float?)null;
@ -90,12 +105,12 @@ namespace LLama
LLamaToken id;
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(_batch.TokenCount - 1), lastTokens);
}
else
{
// Penalize the generated tokens by various penalties
var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
var tokenDataArray = Context.ApplyPenalty(_batch.TokenCount - 1, lastTokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
// Sample a single token
@ -136,9 +151,12 @@ namespace LLama
n_past -= n_discard;
}
// ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
n_past = await Task.Run(() => Context.Eval(tokens, n_past), cancellationToken)
.ConfigureAwait(false);
// Evaluate with this new token
_batch.Clear();
_batch.Add(id, n_past++, LLamaSeqId.Zero, true);
var returnCode = await context.DecodeAsync(_batch, cancellationToken);
if (returnCode != 0)
throw new LLamaDecodeError(returnCode);
}
}
}

View File

@ -0,0 +1,22 @@
namespace LLama.Native;
/// <summary>
/// Return codes from llama_decode
/// </summary>
public enum DecodeResult
{
/// <summary>
/// An unspecified error
/// </summary>
Error = -1,
/// <summary>
/// Ok.
/// </summary>
Ok = 0,
/// <summary>
/// Could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
/// </summary>
NoKvSlot = 1,
}