Removed `LLamaBatchSafeHandle` (which wrapped unmanaged memory allocated by llama.cpp) and replaced it with a fully managed `LLamaBatch`. Modified the `BatchedDecoding` example to use the new managed batch.

Martin Evans 2024-01-19 23:26:36 +00:00
parent 4b11feddef
commit 36a9335588
6 changed files with 137 additions and 176 deletions
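For callers, the change comes down to constructing a managed LLamaBatch, requesting logits per token as it is added, and letting Decode handle the native side. A condensed sketch, reusing the prompt_tokens, n_parallel and context names from the BatchedDecoding diff below:

// Before: a SafeHandle wrapping memory allocated by llama_batch_init
// using var batch = LLamaBatchSafeHandle.Create(Math.Max(prompt_tokens.Length, n_parallel), 0, 1);

// After: a plain managed object, with no unmanaged memory to free
var batch = new LLamaBatch(Math.Max(prompt_tokens.Length, n_parallel), 1);

// Logits are requested per token as it is added, so the unsafe write into
// batch.NativeBatch.logits[...] is no longer needed
for (var i = 0; i < prompt_tokens.Length; i++)
    batch.LLamaBatchAdd(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);

// The native llama_batch is only built (and the arrays pinned) inside Decode
if (context.NativeHandle.Decode(batch) != 0)
    return; // decode failed, as in the example below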

@@ -52,18 +52,11 @@ public class BatchedDecoding
return;
}
using var batch = LLamaBatchSafeHandle.Create(Math.Max(prompt_tokens.Length, n_parallel), 0, 1);
var batch = new LLamaBatch(Math.Max(prompt_tokens.Length, n_parallel), 1);
// evaluate the initial prompt
for (var i = 0; i < prompt_tokens.Length; i++)
batch.LLamaBatchAdd(prompt_tokens[i], i, new[] { (LLamaSeqId)0 }, false);
Debug.Assert(batch.NativeBatch.n_tokens == prompt_tokens.Length);
// llama_decode will output logits only for the last token of the prompt
unsafe
{
batch.NativeBatch.logits[batch.NativeBatch.n_tokens - 1] = 1;
}
batch.LLamaBatchAdd(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);
if (context.NativeHandle.Decode(batch) != 0)
{
@@ -75,7 +68,7 @@ public class BatchedDecoding
// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
for (var i = 1; i < n_parallel; ++i)
{
NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.NativeBatch.n_tokens);
NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.TokenCount);
}
if (n_parallel > 1)
@@ -88,9 +81,9 @@ public class BatchedDecoding
// we need this to determine which logits to sample from
List<int> i_batch = new();
for (var i = 0; i < n_parallel; i++)
i_batch.Add(batch.NativeBatch.n_tokens - 1);
i_batch.Add(batch.TokenCount - 1);
var n_cur = batch.NativeBatch.n_tokens;
var n_cur = batch.TokenCount;
var n_decode = 0;
var streams = new List<LLamaToken>[n_parallel];
@@ -133,7 +126,7 @@ public class BatchedDecoding
streams[i].Add(new_token_id);
i_batch[i] = batch.NativeBatch.n_tokens;
i_batch[i] = batch.TokenCount;
// push this new token for next evaluation
batch.LLamaBatchAdd(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);
@@ -142,7 +135,7 @@ public class BatchedDecoding
}
// all streams are finished
if (batch.NativeBatch.n_tokens == 0)
if (batch.TokenCount == 0)
{
break;
}

LLama/Native/LLamaBatch.cs (new file, 121 lines)

@@ -0,0 +1,121 @@
using System;
namespace LLama.Native;
/// <summary>
/// A batch allows submitting multiple tokens to multiple sequences simultaneously
/// </summary>
public class LLamaBatch
{
private readonly byte[] _logits;
private readonly LLamaToken[] _tokens;
private readonly LLamaPos[] _positions;
private readonly int[] _sequenceIdCount;
private readonly LLamaSeqId[][] _sequenceIds;
private readonly IntPtr[] _sequenceIdsPtrs;
/// <summary>
/// The number of tokens in this batch
/// </summary>
public int TokenCount { get; private set; }
/// <summary>
/// Create a new batch for submitting inputs to llama.cpp
/// </summary>
/// <param name="n_tokens"></param>
/// <param name="n_seq_max"></param>
public LLamaBatch(int n_tokens, int n_seq_max)
{
_logits = new byte[n_tokens];
_tokens = new LLamaToken[n_tokens];
_positions = new LLamaPos[n_tokens];
_sequenceIdCount = new int[n_tokens];
_sequenceIdsPtrs = new IntPtr[_sequenceIdCount.Length];
_sequenceIds = new LLamaSeqId[n_tokens][];
for (var i = 0; i < _sequenceIds.Length; i++)
_sequenceIds[i] = new LLamaSeqId[n_seq_max];
}
internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
{
// This group holds all of the memory pins
var group = new GroupDisposable();
unsafe
{
batch = new LLamaNativeBatch
{
n_tokens = TokenCount,
logits = (byte*)group.Add(_logits.AsMemory().Pin()).Pointer,
n_seq_id = (int*)group.Add(_sequenceIdCount.AsMemory().Pin()).Pointer,
pos = (LLamaPos*)group.Add(_positions.AsMemory().Pin()).Pointer,
seq_id = (LLamaSeqId**)group.Add(_sequenceIdsPtrs.AsMemory().Pin()).Pointer,
// embd is not currently supported, so this is always null!
embd = null,
// Note that if embd is **not null** then this will be null!
tokens = (LLamaToken*)group.Add(_tokens.AsMemory().Pin()).Pointer,
};
// Create pointers to each of the arrays in turn
for (var i = 0; i < _sequenceIdsPtrs.Length; i++)
_sequenceIdsPtrs[i] = (IntPtr)group.Add(_sequenceIds[i].AsMemory().Pin()).Pointer;
}
return group;
}
/// <summary>
/// Add a single token to the batch at the same position in several sequences
/// </summary>
/// <remarks>https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2</remarks>
/// <param name="token">The token to add</param>
/// <param name="pos">The position to add it att</param>
/// <param name="sequences">The set of sequences to add this token to</param>
/// <param name="logits"></param>
public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
{
_tokens[TokenCount] = token;
_positions[TokenCount] = pos;
_sequenceIdCount[TokenCount] = sequences.Length;
for (var i = 0; i < sequences.Length; i++)
_sequenceIds[TokenCount][i] = sequences[i];
_logits[TokenCount] = Convert.ToByte(logits);
TokenCount++;
}
/// <summary>
/// Add a single token to the batch at a certain position for a single sequence
/// </summary>
/// <remarks>https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2</remarks>
/// <param name="token">The token to add</param>
/// <param name="pos">The position to add it att</param>
/// <param name="sequence">The sequence to add this token to</param>
/// <param name="logits"></param>
public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
{
// Create a temporary span to contain 1 item without allocating
Span<LLamaSeqId> sequences = stackalloc LLamaSeqId[1];
sequences[0] = sequence;
// Add it
LLamaBatchAdd(token, pos, sequences, logits);
}
/// <summary>
/// Set TokenCount to zero for this batch
/// </summary>
public void LLamaBatchClear()
{
TokenCount = 0;
}
}
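ToNativeBatch relies on a GroupDisposable to keep every pinned array alive until llama_decode has finished. That class is not part of this diff, so the following is only a sketch of a helper that would satisfy the calls above; the name, signature and behaviour are assumptions:

using System;
using System.Collections.Generic;

// Hypothetical sketch: collects disposables (here, MemoryHandle pins) and
// releases them all together when the group itself is disposed.
internal sealed class GroupDisposable : IDisposable
{
    private readonly List<IDisposable> _members = new();

    // Track a pin (or any other disposable) and hand it straight back, so the
    // caller can immediately read its pinned Pointer, as ToNativeBatch does above.
    public T Add<T>(T member)
        where T : IDisposable
    {
        _members.Add(member);
        return member;
    }

    // Unpin/dispose everything that was added
    public void Dispose()
    {
        foreach (var member in _members)
            member.Dispose();
        _members.Clear();
    }
}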

@@ -1,158 +0,0 @@
using System;
namespace LLama.Native;
/// <summary>
/// Input data for llama_decode. A llama_batch object can contain input about one or many sequences.
/// </summary>
public sealed class LLamaBatchSafeHandle
: SafeLLamaHandleBase
{
private readonly int _embd;
/// <summary>
/// Get the native llama_batch struct
/// </summary>
public LLamaNativeBatch NativeBatch;
/// <summary>
/// the token ids of the input (used when embd is NULL)
/// </summary>
public Span<LLamaToken> Token
{
get
{
unsafe
{
if (_embd != 0)
return new Span<LLamaToken>(null, 0);
else
return new Span<LLamaToken>(NativeBatch.token, NativeBatch.n_tokens);
}
}
}
/// <summary>
/// token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
/// </summary>
public Span<LLamaToken> Embed
{
get
{
unsafe
{
// If embd != 0, llama_batch.embd will be allocated with size of n_tokens *embd * sizeof(float)
// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
if (_embd != 0)
return new Span<LLamaToken>(NativeBatch.embd, NativeBatch.n_tokens * _embd);
else
return new Span<LLamaToken>(null, 0);
}
}
}
/// <summary>
/// the positions of the respective token in the sequence
/// </summary>
public Span<LLamaPos> Pos
{
get
{
unsafe
{
return new Span<LLamaPos>(NativeBatch.pos, NativeBatch.n_tokens);
}
}
}
/// <summary>
/// the sequence to which the respective token belongs
/// </summary>
public Span<LLamaSeqId> Sequence_ID
{
get
{
unsafe
{
return new Span<LLamaSeqId>(NativeBatch.seq_id, NativeBatch.n_tokens);
}
}
}
/// <summary>
/// if zero, the logits for the respective token will not be output
/// </summary>
public Span<byte> Logits
{
get
{
unsafe
{
return new Span<byte>(NativeBatch.logits, NativeBatch.n_tokens);
}
}
}
/// <summary>
/// Create a safe handle owning a `LLamaNativeBatch`
/// </summary>
/// <param name="batch"></param>
/// <param name="embd"></param>
public LLamaBatchSafeHandle(LLamaNativeBatch batch, int embd)
: base((nint)1)
{
_embd = embd;
NativeBatch = batch;
}
/// <summary>
/// Call `llama_batch_init` and create a new batch
/// </summary>
/// <param name="n_tokens"></param>
/// <param name="embd"></param>
/// <param name="n_seq_max"></param>
/// <returns></returns>
public static LLamaBatchSafeHandle Create(int n_tokens, int embd, int n_seq_max)
{
var batch = NativeApi.llama_batch_init(n_tokens, embd, n_seq_max);
return new LLamaBatchSafeHandle(batch, embd);
}
/// <inheritdoc />
protected override bool ReleaseHandle()
{
NativeApi.llama_batch_free(NativeBatch);
NativeBatch = default;
SetHandle(IntPtr.Zero);
return true;
}
/// <summary>
/// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2
/// </summary>
public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
{
unsafe
{
NativeBatch.token[NativeBatch.n_tokens] = token;
NativeBatch.pos[NativeBatch.n_tokens] = pos;
NativeBatch.n_seq_id[NativeBatch.n_tokens] = sequences.Length;
for (var i = 0; i < sequences.Length; i++)
NativeBatch.seq_id[NativeBatch.n_tokens][i] = sequences[i];
NativeBatch.logits[NativeBatch.n_tokens] = Convert.ToByte(logits);
NativeBatch.n_tokens++;
}
}
/// <summary>
/// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L825
/// </summary>
public void LLamaBatchClear()
{
NativeBatch.n_tokens = 0;
}
}

@@ -18,7 +18,7 @@ public unsafe struct LLamaNativeBatch
/// <summary>
/// Either `n_tokens` of `llama_token`, or `NULL`, depending on how this batch was created
/// </summary>
public LLamaToken* token;
public LLamaToken* tokens;
/// <summary>
/// Either `n_tokens * embd * sizeof(float)` or `NULL`, depending on how this batch was created

@@ -8,6 +8,11 @@ namespace LLama.Native;
[StructLayout(LayoutKind.Sequential)]
public record struct LLamaSeqId
{
/// <summary>
/// LLamaSeqId with value 0
/// </summary>
public static readonly LLamaSeqId Zero = new LLamaSeqId(0);
/// <summary>
/// The raw value
/// </summary>

@@ -1,5 +1,4 @@
using System;
using System.Buffers;
using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;
@@ -198,9 +197,10 @@ namespace LLama.Native
/// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
/// - &lt; 0: error<br />
/// </returns>
public int Decode(LLamaBatchSafeHandle batch)
public int Decode(LLamaBatch batch)
{
return NativeApi.llama_decode(this, batch.NativeBatch);
using (batch.ToNativeBatch(out var nb))
return NativeApi.llama_decode(this, nb);
}
#region state
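The using statement in the new Decode is just shorthand for pinning, calling into llama.cpp, and unpinning again; nothing unmanaged outlives the call. An equivalent expansion, shown only to make that lifetime explicit:

public int Decode(LLamaBatch batch)
{
    // Pin the batch's managed arrays and build a temporary native view over them
    GroupDisposable pins = batch.ToNativeBatch(out LLamaNativeBatch nb);
    try
    {
        // llama.cpp reads token/position/sequence data straight from the pinned arrays
        return NativeApi.llama_decode(this, nb);
    }
    finally
    {
        // Release every pin; the managed LLamaBatch can be reused for the next call
        pins.Dispose();
    }
}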