Removed `LLamaBatchSafeHandle` (which used unmanaged memory allocated by llama.cpp) and replaced it with a fully managed `LLamaBatch`. Modified the `BatchedDecoding` example to use the new managed batch.
This commit is contained in:
parent
4b11feddef
commit
36a9335588
|
@ -52,18 +52,11 @@ public class BatchedDecoding
|
|||
return;
|
||||
}
|
||||
|
||||
using var batch = LLamaBatchSafeHandle.Create(Math.Max(prompt_tokens.Length, n_parallel), 0, 1);
|
||||
var batch = new LLamaBatch(Math.Max(prompt_tokens.Length, n_parallel), 1);
|
||||
|
||||
// evaluate the initial prompt
|
||||
for (var i = 0; i < prompt_tokens.Length; i++)
|
||||
batch.LLamaBatchAdd(prompt_tokens[i], i, new[] { (LLamaSeqId)0 }, false);
|
||||
Debug.Assert(batch.NativeBatch.n_tokens == prompt_tokens.Length);
|
||||
|
||||
// llama_decode will output logits only for the last token of the prompt
|
||||
unsafe
|
||||
{
|
||||
batch.NativeBatch.logits[batch.NativeBatch.n_tokens - 1] = 1;
|
||||
}
|
||||
batch.LLamaBatchAdd(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);
|
||||
|
||||
if (context.NativeHandle.Decode(batch) != 0)
|
||||
{
|
||||
|
@ -75,7 +68,7 @@ public class BatchedDecoding
|
|||
// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
|
||||
for (var i = 1; i < n_parallel; ++i)
|
||||
{
|
||||
NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.NativeBatch.n_tokens);
|
||||
NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.TokenCount);
|
||||
}
|
||||
|
||||
if (n_parallel > 1)
|
||||
|
@ -88,9 +81,9 @@ public class BatchedDecoding
|
|||
// we need this to determine which logits to sample from
|
||||
List<int> i_batch = new();
|
||||
for (var i = 0; i < n_parallel; i++)
|
||||
i_batch.Add(batch.NativeBatch.n_tokens - 1);
|
||||
i_batch.Add(batch.TokenCount - 1);
|
||||
|
||||
var n_cur = batch.NativeBatch.n_tokens;
|
||||
var n_cur = batch.TokenCount;
|
||||
var n_decode = 0;
|
||||
|
||||
var streams = new List<LLamaToken>[n_parallel];
|
||||
|
@ -133,7 +126,7 @@ public class BatchedDecoding
|
|||
|
||||
streams[i].Add(new_token_id);
|
||||
|
||||
i_batch[i] = batch.NativeBatch.n_tokens;
|
||||
i_batch[i] = batch.TokenCount;
|
||||
|
||||
// push this new token for next evaluation
|
||||
batch.LLamaBatchAdd(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);
|
||||
|
@ -142,7 +135,7 @@ public class BatchedDecoding
|
|||
}
|
||||
|
||||
// all streams are finished
|
||||
if (batch.NativeBatch.n_tokens == 0)
|
||||
if (batch.TokenCount == 0)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
using System;
|
||||
|
||||
namespace LLama.Native;
|
||||
|
||||
/// <summary>
/// A batch allows submitting multiple tokens to multiple sequences simultaneously
/// </summary>
public class LLamaBatch
{
    // One entry per token slot: non-zero if logits should be output for that token
    private readonly byte[] _logits;

    // One entry per token slot: the token and the position it is placed at
    private readonly LLamaToken[] _tokens;
    private readonly LLamaPos[] _positions;

    // One entry per token slot: how many sequence ids are in use, the ids themselves,
    // and (filled in by ToNativeBatch) a pinned pointer to each array of ids
    private readonly int[] _sequenceIdCount;
    private readonly LLamaSeqId[][] _sequenceIds;
    private readonly IntPtr[] _sequenceIdsPtrs;

    /// <summary>
    /// The number of tokens in this batch
    /// </summary>
    public int TokenCount { get; private set; }

    /// <summary>
    /// Create a new batch for submitting inputs to llama.cpp
    /// </summary>
    /// <param name="n_tokens">Maximum number of tokens this batch can hold</param>
    /// <param name="n_seq_max">Maximum number of sequences a single token can be added to</param>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if either argument is negative</exception>
    public LLamaBatch(int n_tokens, int n_seq_max)
    {
        if (n_tokens < 0)
        {
            throw new ArgumentOutOfRangeException(nameof(n_tokens), n_tokens, "Batch capacity must not be negative");
        }

        if (n_seq_max < 0)
        {
            throw new ArgumentOutOfRangeException(nameof(n_seq_max), n_seq_max, "Maximum sequence count must not be negative");
        }

        _logits = new byte[n_tokens];
        _tokens = new LLamaToken[n_tokens];
        _positions = new LLamaPos[n_tokens];

        _sequenceIdCount = new int[n_tokens];
        _sequenceIdsPtrs = new IntPtr[_sequenceIdCount.Length];

        _sequenceIds = new LLamaSeqId[n_tokens][];
        for (var i = 0; i < _sequenceIds.Length; i++)
        {
            _sequenceIds[i] = new LLamaSeqId[n_seq_max];
        }
    }

    /// <summary>
    /// Pin the managed arrays of this batch and build a native batch struct pointing at them
    /// </summary>
    /// <param name="batch">The native batch, valid only until the returned group is disposed</param>
    /// <returns>A disposable holding all of the memory pins; dispose it to release them</returns>
    internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
    {
        // This group holds all of the memory pins
        var group = new GroupDisposable();

        try
        {
            unsafe
            {
                batch = new LLamaNativeBatch
                {
                    n_tokens = TokenCount,
                    logits = (byte*)group.Add(_logits.AsMemory().Pin()).Pointer,

                    n_seq_id = (int*)group.Add(_sequenceIdCount.AsMemory().Pin()).Pointer,
                    pos = (LLamaPos*)group.Add(_positions.AsMemory().Pin()).Pointer,
                    seq_id = (LLamaSeqId**)group.Add(_sequenceIdsPtrs.AsMemory().Pin()).Pointer,

                    // embd is not currently supported, so this is always null!
                    embd = null,

                    // Note that if embd is **not null** then this will be null!
                    tokens = (LLamaToken*)group.Add(_tokens.AsMemory().Pin()).Pointer,
                };

                // Create pointers to each of the arrays in turn. Only the first TokenCount
                // entries are described by n_tokens, so only those need to be pinned.
                for (var i = 0; i < TokenCount; i++)
                {
                    _sequenceIdsPtrs[i] = (IntPtr)group.Add(_sequenceIds[i].AsMemory().Pin()).Pointer;
                }
            }

            return group;
        }
        catch
        {
            // Release any pins taken before the failure so they do not leak
            group.Dispose();
            throw;
        }
    }

    /// <summary>
    /// Add a single token to the batch at the same position in several sequences
    /// </summary>
    /// <remarks>https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2</remarks>
    /// <param name="token">The token to add</param>
    /// <param name="pos">The position to add it at</param>
    /// <param name="sequences">The set of sequences to add this token to</param>
    /// <param name="logits">If true, logits will be output for this token</param>
    /// <exception cref="InvalidOperationException">Thrown if the batch is already full</exception>
    /// <exception cref="ArgumentException">Thrown if more sequences are supplied than the batch was created for</exception>
    public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
    {
        // Validate before touching any state, so a failed add leaves the batch unchanged
        if (TokenCount >= _tokens.Length)
        {
            throw new InvalidOperationException($"Cannot add token, batch is full (capacity: {_tokens.Length})");
        }

        var sequenceSlots = _sequenceIds[TokenCount];
        if (sequences.Length > sequenceSlots.Length)
        {
            throw new ArgumentException($"Too many sequences ({sequences.Length}), batch supports at most {sequenceSlots.Length} per token", nameof(sequences));
        }

        _tokens[TokenCount] = token;
        _positions[TokenCount] = pos;

        _sequenceIdCount[TokenCount] = sequences.Length;
        sequences.CopyTo(sequenceSlots);

        _logits[TokenCount] = Convert.ToByte(logits);

        TokenCount++;
    }

    /// <summary>
    /// Add a single token to the batch at a certain position for a single sequence
    /// </summary>
    /// <remarks>https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2</remarks>
    /// <param name="token">The token to add</param>
    /// <param name="pos">The position to add it at</param>
    /// <param name="sequence">The sequence to add this token to</param>
    /// <param name="logits">If true, logits will be output for this token</param>
    public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
    {
        // Create a temporary span to contain 1 item without allocating
        Span<LLamaSeqId> sequences = stackalloc LLamaSeqId[1];
        sequences[0] = sequence;

        // Add it
        LLamaBatchAdd(token, pos, sequences, logits);
    }

    /// <summary>
    /// Set TokenCount to zero for this batch
    /// </summary>
    public void LLamaBatchClear()
    {
        TokenCount = 0;
    }
}
|
|
@ -1,158 +0,0 @@
|
|||
using System;
|
||||
|
||||
namespace LLama.Native;
|
||||
|
||||
/// <summary>
|
||||
/// Input data for llama_decode. A llama_batch object can contain input about one or many sequences.
|
||||
/// </summary>
|
||||
public sealed class LLamaBatchSafeHandle
|
||||
: SafeLLamaHandleBase
|
||||
{
|
||||
private readonly int _embd;
|
||||
|
||||
/// <summary>
|
||||
/// Get the native llama_batch struct
|
||||
/// </summary>
|
||||
public LLamaNativeBatch NativeBatch;
|
||||
|
||||
/// <summary>
|
||||
/// the token ids of the input (used when embd is NULL)
|
||||
/// </summary>
|
||||
public Span<LLamaToken> Token
|
||||
{
|
||||
get
|
||||
{
|
||||
unsafe
|
||||
{
|
||||
if (_embd != 0)
|
||||
return new Span<LLamaToken>(null, 0);
|
||||
else
|
||||
return new Span<LLamaToken>(NativeBatch.token, NativeBatch.n_tokens);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
|
||||
/// </summary>
|
||||
public Span<LLamaToken> Embed
|
||||
{
|
||||
get
|
||||
{
|
||||
unsafe
|
||||
{
|
||||
// If embd != 0, llama_batch.embd will be allocated with size of n_tokens *embd * sizeof(float)
|
||||
// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
|
||||
|
||||
if (_embd != 0)
|
||||
return new Span<LLamaToken>(NativeBatch.embd, NativeBatch.n_tokens * _embd);
|
||||
else
|
||||
return new Span<LLamaToken>(null, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// the positions of the respective token in the sequence
|
||||
/// </summary>
|
||||
public Span<LLamaPos> Pos
|
||||
{
|
||||
get
|
||||
{
|
||||
unsafe
|
||||
{
|
||||
return new Span<LLamaPos>(NativeBatch.pos, NativeBatch.n_tokens);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// the sequence to which the respective token belongs
|
||||
/// </summary>
|
||||
public Span<LLamaSeqId> Sequence_ID
|
||||
{
|
||||
get
|
||||
{
|
||||
unsafe
|
||||
{
|
||||
return new Span<LLamaSeqId>(NativeBatch.seq_id, NativeBatch.n_tokens);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// if zero, the logits for the respective token will not be output
|
||||
/// </summary>
|
||||
public Span<byte> Logits
|
||||
{
|
||||
get
|
||||
{
|
||||
unsafe
|
||||
{
|
||||
return new Span<byte>(NativeBatch.logits, NativeBatch.n_tokens);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Create a safe handle owning a `LLamaNativeBatch`
|
||||
/// </summary>
|
||||
/// <param name="batch"></param>
|
||||
/// <param name="embd"></param>
|
||||
public LLamaBatchSafeHandle(LLamaNativeBatch batch, int embd)
|
||||
: base((nint)1)
|
||||
{
|
||||
_embd = embd;
|
||||
NativeBatch = batch;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Call `llama_batch_init` and create a new batch
|
||||
/// </summary>
|
||||
/// <param name="n_tokens"></param>
|
||||
/// <param name="embd"></param>
|
||||
/// <param name="n_seq_max"></param>
|
||||
/// <returns></returns>
|
||||
public static LLamaBatchSafeHandle Create(int n_tokens, int embd, int n_seq_max)
|
||||
{
|
||||
var batch = NativeApi.llama_batch_init(n_tokens, embd, n_seq_max);
|
||||
return new LLamaBatchSafeHandle(batch, embd);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override bool ReleaseHandle()
|
||||
{
|
||||
NativeApi.llama_batch_free(NativeBatch);
|
||||
NativeBatch = default;
|
||||
SetHandle(IntPtr.Zero);
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2
|
||||
/// </summary>
|
||||
public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
|
||||
{
|
||||
unsafe
|
||||
{
|
||||
NativeBatch.token[NativeBatch.n_tokens] = token;
|
||||
NativeBatch.pos[NativeBatch.n_tokens] = pos;
|
||||
NativeBatch.n_seq_id[NativeBatch.n_tokens] = sequences.Length;
|
||||
|
||||
for (var i = 0; i < sequences.Length; i++)
|
||||
NativeBatch.seq_id[NativeBatch.n_tokens][i] = sequences[i];
|
||||
|
||||
NativeBatch.logits[NativeBatch.n_tokens] = Convert.ToByte(logits);
|
||||
|
||||
NativeBatch.n_tokens++;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L825
|
||||
/// </summary>
|
||||
public void LLamaBatchClear()
|
||||
{
|
||||
NativeBatch.n_tokens = 0;
|
||||
}
|
||||
}
|
|
@ -18,7 +18,7 @@ public unsafe struct LLamaNativeBatch
|
|||
/// <summary>
|
||||
/// Either `n_tokens` of `llama_token`, or `NULL`, depending on how this batch was created
|
||||
/// </summary>
|
||||
public LLamaToken* token;
|
||||
public LLamaToken* tokens;
|
||||
|
||||
/// <summary>
|
||||
/// Either `n_tokens * embd * sizeof(float)` or `NULL`, depending on how this batch was created
|
||||
|
|
|
@ -8,6 +8,11 @@ namespace LLama.Native;
|
|||
[StructLayout(LayoutKind.Sequential)]
|
||||
public record struct LLamaSeqId
|
||||
{
|
||||
/// <summary>
|
||||
/// LLamaSeqId with value 0
|
||||
/// </summary>
|
||||
public static readonly LLamaSeqId Zero = new LLamaSeqId(0);
|
||||
|
||||
/// <summary>
|
||||
/// The raw value
|
||||
/// </summary>
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
using System;
|
||||
using System.Buffers;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Text;
|
||||
using LLama.Exceptions;
|
||||
|
@ -198,9 +197,10 @@ namespace LLama.Native
|
|||
/// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
|
||||
/// - < 0: error<br />
|
||||
/// </returns>
|
||||
public int Decode(LLamaBatchSafeHandle batch)
|
||||
public int Decode(LLamaBatch batch)
|
||||
{
|
||||
return NativeApi.llama_decode(this, batch.NativeBatch);
|
||||
using (batch.ToNativeBatch(out var nb))
|
||||
return NativeApi.llama_decode(this, nb);
|
||||
}
|
||||
|
||||
#region state
|
||||
|
|
Loading…
Reference in New Issue