Swapped `StatelessExecutor` to use `llama_decode`!
- Added a `logits_i` argument to `Context.ApplyPenalty`
- Added a new exception type for non-zero `llama_decode` return codes
This commit is contained in:
parent
892e841da3
commit
a2e29d393c
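
In short: `StatelessExecutor` now evaluates prompts through batched `llama_decode` calls instead of `Eval`, `LLamaContext.Decode` returns a typed `DecodeResult` instead of a raw `int`, and failures can surface as the new `LLamaDecodeError`. A minimal sketch of how calling code might use the reworked API; the model path and parameter values are placeholders, and exact signatures follow the diff below rather than any guaranteed public contract:

```csharp
using LLama;
using LLama.Common;
using LLama.Exceptions;
using LLama.Native;

// Placeholder model path and sizes; adjust for a real model.
var parameters = new ModelParams("model.gguf") { ContextSize = 512, BatchSize = 64 };
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);

// Queue the prompt tokens, requesting logits only for the last one.
var tokens = context.Tokenize(" Question. cats or dogs?\nAnswer:");
var batch = new LLamaBatch(1);
for (var i = 0; i < tokens.Length; i++)
    batch.Add(tokens[i], i, LLamaSeqId.Zero, i == tokens.Length - 1);

// Decode now returns DecodeResult rather than a bare int.
var result = context.Decode(batch);
if (result != DecodeResult.Ok)
    throw new LLamaDecodeError(result);
```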
@@ -105,12 +105,7 @@ public class BatchedDecoding
             if (i_batch[i] < 0)
                 continue;

-            var n_vocab = model.VocabCount;
-            LLamaTokenDataArray candidates;
-            unsafe
-            {
-                candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab));
-            }
+            var candidates = LLamaTokenDataArray.Create(context.NativeHandle.GetLogitsIth(i_batch[i]));

             candidates.TopK(context.NativeHandle, top_k);
             candidates.TopP(context.NativeHandle, top_p);

@@ -19,6 +19,7 @@ namespace LLama.Unittest
         {
             ContextSize = 60,
             Seed = 1754,
+            BatchSize = 2,
         };
         _weights = LLamaWeights.LoadFromFile(_params);
     }

@@ -60,7 +61,7 @@ namespace LLama.Unittest
         {
             var executor = new StatelessExecutor(_weights, _params);

-            const string question = " Question. cats or dogs?\nAnswer: ";
+            const string question = " Question. cats or dogs?\nAnswer:";

             // The context size is set to 60. Generate more than that, forcing it to generate a coherent response
             // with a modified context

@@ -1,4 +1,5 @@
 using System;
+using LLama.Native;

 namespace LLama.Exceptions;

@@ -36,4 +37,23 @@ public class LoadWeightsFailedException
     {
         ModelPath = modelPath;
     }
 }
+
+/// <summary>
+/// `llama_decode` returned a non-zero status code
+/// </summary>
+public class LLamaDecodeError
+    : RuntimeError
+{
+    /// <summary>
+    /// The return status code
+    /// </summary>
+    public DecodeResult ReturnCode { get; }
+
+    /// <inheritdoc />
+    public LLamaDecodeError(DecodeResult returnCode)
+        : base($"llama_decode failed: '{returnCode}'")
+    {
+        ReturnCode = returnCode;
+    }
+}
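
For downstream code, the typed `ReturnCode` makes the recoverable `NoKvSlot` case easy to distinguish from hard failures. A hedged usage sketch; the `InferAsync` signature here is an assumption based on the public executor API, not part of this commit:

```csharp
using System;
using System.Threading.Tasks;
using LLama;
using LLama.Common;
using LLama.Exceptions;
using LLama.Native;

static async Task RunAsync(StatelessExecutor executor, string prompt)
{
    try
    {
        await foreach (var text in executor.InferAsync(prompt, new InferenceParams()))
            Console.Write(text);
    }
    catch (LLamaDecodeError ex) when (ex.ReturnCode == DecodeResult.NoKvSlot)
    {
        // Recoverable: retry with a smaller batch or a larger context.
        Console.Error.WriteLine($"Decode failed: {ex.ReturnCode}");
    }
}
```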
@@ -293,6 +293,7 @@ namespace LLama
         /// <summary>
         /// Apply the penalty for the tokens. Please don't use it unless you fully know what it does.
         /// </summary>
+        /// <param name="logits_i"></param>
         /// <param name="lastTokens"></param>
         /// <param name="logitBias"></param>
         /// <param name="repeatLastTokensCount"></param>
@@ -301,11 +302,11 @@ namespace LLama
         /// <param name="alphaPresence"></param>
         /// <param name="penalizeNL"></param>
         /// <returns></returns>
-        public LLamaTokenDataArray ApplyPenalty(IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
-            int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
-            bool penalizeNL = true)
+        public LLamaTokenDataArray ApplyPenalty(int logits_i, IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
+            int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
+            bool penalizeNL = true)
         {
-            var logits = NativeHandle.GetLogits();
+            var logits = NativeHandle.GetLogitsIth(logits_i);

             // Apply params.logit_bias map
             if (logitBias is not null)

@@ -348,28 +349,23 @@ namespace LLama
         /// <summary>
         /// </summary>
         /// <param name="batch"></param>
-        /// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
-        /// - 0: success<br />
-        /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
-        /// - < 0: error<br />
-        /// </returns>
-        public int Decode(LLamaBatch batch)
+        public DecodeResult Decode(LLamaBatch batch)
         {
-            return NativeHandle.Decode(batch);
+            if (batch.TokenCount == 0)
+                return 0;
+            if (batch.TokenCount > Params.BatchSize)
+                throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));
+
+            return (DecodeResult)NativeHandle.Decode(batch);
         }

         /// <summary>
         /// </summary>
         /// <param name="batch"></param>
         /// <param name="cancellationToken"></param>
-        /// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
-        /// - 0: success<br />
-        /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
-        /// - < 0: error<br />
-        /// </returns>
-        public Task<int> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
+        public Task<DecodeResult> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
         {
-            return Task.Run(() => NativeHandle.Decode(batch), cancellationToken);
+            return Task.Run(() => Decode(batch), cancellationToken);
         }

         /// <summary>

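`DecodeAsync` now routes through `Decode`, so the empty-batch shortcut and the batch-size check apply on both paths. A hedged sketch of the async path, reusing the `context` and `batch` from the synchronous example near the top:

```csharp
using System.Threading;
using System.Threading.Tasks;
using LLama;
using LLama.Exceptions;
using LLama.Native;

// Hypothetical helper: same typed result on the async path.
static async Task<bool> TryDecodeAsync(LLamaContext context, LLamaBatch batch, CancellationToken cancellationToken)
{
    var result = await context.DecodeAsync(batch, cancellationToken);
    if (result == DecodeResult.NoKvSlot)
        return false; // recoverable: shrink the batch or enlarge the context
    if (result != DecodeResult.Ok)
        throw new LLamaDecodeError(result);
    return true;
}
```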
@@ -216,7 +216,7 @@ namespace LLama
             }
             else
             {
-                var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
                     inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

                 var mu = MirostatMu;

@@ -195,7 +195,7 @@ namespace LLama
             }
             else
             {
-                var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
                     inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

                 var mu = MirostatMu;

@@ -5,7 +5,7 @@ using System.Collections.Generic;
 using System.Linq;
 using System.Runtime.CompilerServices;
 using System.Threading;
 using System.Threading.Tasks;
+using LLama.Exceptions;
 using LLama.Native;
 using LLama.Sampling;
 using Microsoft.Extensions.Logging;
@@ -22,6 +22,7 @@ namespace LLama
         private readonly LLamaWeights _weights;
         private readonly IContextParams _params;
         private readonly ILogger? _logger;
+        private readonly LLamaBatch _batch;

         /// <summary>
         /// The context used by the executor when running the inference.
@@ -39,6 +40,7 @@ namespace LLama
             _weights = weights;
             _params = @params;
             _logger = logger;
+            _batch = new LLamaBatch(1);

             Context = _weights.CreateContext(_params, logger);
             Context.Dispose();
@@ -71,16 +73,29 @@ namespace LLama
             var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount < 0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
             var lastTokens = new List<LLamaToken>(repeat_last_n);
             for (var i = 0; i < repeat_last_n; i++)
-                lastTokens.Add((LLamaToken)0);
+                lastTokens.Add(0);

             // Tokenize the prompt
             var tokens = Context.Tokenize(prompt).ToList();
             lastTokens.AddRange(tokens);
-            var n_past = 1 + tokens.Count;

-            // Evaluate the prompt
-            await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken)
-                .ConfigureAwait(false);
+            // Evaluate the prompt, in chunks smaller than the max batch size
+            var n_past = 0;
+            var batchSize = (int)Context.Params.BatchSize;
+            for (var i = 0; i < tokens.Count; i += batchSize)
+            {
+                var n_eval = tokens.Count - i;
+                if (n_eval > batchSize)
+                    n_eval = batchSize;
+
+                _batch.Clear();
+                for (var j = 0; j < n_eval; j++)
+                    _batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, (i + j) == tokens.Count - 1);
+
+                var returnCode = await Context.DecodeAsync(_batch, cancellationToken);
+                if (returnCode != 0)
+                    throw new LLamaDecodeError(returnCode);
+            }

             // Begin loop, evaluating one token at a time
             var mu = (float?)null;

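Because each `llama_decode` call can process at most `BatchSize` tokens, the loop above splits the prompt into chunks and requests logits only for the very last prompt token. A standalone sketch of the same chunking arithmetic, not LLamaSharp code:

```csharp
using System;

// With 5 prompt tokens and a batch size of 2 this prints chunks [0,1], [2,3], [4];
// only token 4 (the last prompt token) would have logits requested.
const int tokenCount = 5, batchSize = 2;
for (var i = 0; i < tokenCount; i += batchSize)
{
    var n_eval = Math.Min(batchSize, tokenCount - i);
    Console.WriteLine($"decode tokens {i}..{i + n_eval - 1}");
}
```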
@@ -90,12 +105,12 @@ namespace LLama
                 LLamaToken id;
                 if (inferenceParams.SamplingPipeline is not null)
                 {
-                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+                    id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(_batch.TokenCount - 1), lastTokens);
                 }
                 else
                 {
                     // Penalize the generated tokens by various penalties
-                    var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+                    var tokenDataArray = Context.ApplyPenalty(_batch.TokenCount - 1, lastTokens, inferenceParams.LogitBias, repeat_last_n,
                         inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

                     // Sample a single token
@@ -136,9 +151,12 @@ namespace LLama
                     n_past -= n_discard;
                 }

-                // ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
-                n_past = await Task.Run(() => Context.Eval(tokens, n_past), cancellationToken)
-                    .ConfigureAwait(false);
+                // Evaluate with this new token
+                _batch.Clear();
+                _batch.Add(id, n_past++, LLamaSeqId.Zero, true);
+                var returnCode = await context.DecodeAsync(_batch, cancellationToken);
+                if (returnCode != 0)
+                    throw new LLamaDecodeError(returnCode);
             }
         }
     }

@@ -0,0 +1,22 @@
+namespace LLama.Native;
+
+/// <summary>
+/// Return codes from llama_decode
+/// </summary>
+public enum DecodeResult
+{
+    /// <summary>
+    /// An unspecified error
+    /// </summary>
+    Error = -1,
+
+    /// <summary>
+    /// Ok.
+    /// </summary>
+    Ok = 0,
+
+    /// <summary>
+    /// Could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+    /// </summary>
+    NoKvSlot = 1,
+}
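
The enum mirrors the native `llama_decode` convention (negative = error, 0 = ok, 1 = no KV slot), which is why `LLamaContext.Decode` can cast the native return value directly. A hedged sketch of that mapping as a hypothetical helper, not code from this commit:

```csharp
using LLama.Native;

// Hypothetical normalizer mirroring the cast in LLamaContext.Decode;
// any unrecognized negative code collapses to the unspecified Error value.
static DecodeResult ToDecodeResult(int nativeReturnCode) => nativeReturnCode switch
{
    0 => DecodeResult.Ok,
    1 => DecodeResult.NoKvSlot,
    _ => DecodeResult.Error,
};
```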