// LLamaSharp/LLama/Native/LLamaBatchSafeHandle.cs

using System;

namespace LLama.Native;

using llama_token = Int32;
/// <summary>
/// Input data for llama_decode. A llama_batch object can contain input about one or many sequences.
/// </summary>
public sealed class LLamaBatchSafeHandle
    : SafeLLamaHandleBase
{
    private readonly int _embd;

    /// <summary>
    /// Get the native llama_batch struct
    /// </summary>
    public LLamaNativeBatch NativeBatch;
    /// <summary>
    /// the token ids of the input (used when embd is NULL)
    /// </summary>
    public Span<llama_token> Token
    {
        get
        {
            unsafe
            {
                if (_embd != 0)
                    return new Span<llama_token>(null, 0);
                else
                    return new Span<llama_token>(NativeBatch.token, NativeBatch.n_tokens);
            }
        }
    }
    /// <summary>
    /// token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
    /// </summary>
    public Span<float> Embed
    {
        get
        {
            unsafe
            {
                // If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
                // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
                if (_embd != 0)
                    return new Span<float>(NativeBatch.embd, NativeBatch.n_tokens * _embd);
                else
                    return new Span<float>(null, 0);
            }
        }
    }
    /// <summary>
    /// the positions of the respective token in the sequence
    /// </summary>
    public Span<LLamaPos> Pos
    {
        get
        {
            unsafe
            {
                return new Span<LLamaPos>(NativeBatch.pos, NativeBatch.n_tokens);
            }
        }
    }
    /// <summary>
    /// the sequences to which the respective token belongs. Note that `llama_batch.seq_id`
    /// holds one pointer per token, each pointing at that token's array of sequence ids
    /// (see <see cref="LLamaBatchAdd"/>), so this span exposes the per-token pointers
    /// rather than the ids themselves.
    /// </summary>
    public Span<IntPtr> Sequence_ID
    {
        get
        {
            unsafe
            {
                return new Span<IntPtr>(NativeBatch.seq_id, NativeBatch.n_tokens);
            }
        }
    }
    /// <summary>
    /// if zero, the logits for the respective token will not be output
    /// </summary>
    public Span<byte> Logits
    {
        get
        {
            unsafe
            {
                return new Span<byte>(NativeBatch.logits, NativeBatch.n_tokens);
            }
        }
    }
    /// <summary>
    /// Create a safe handle owning a `LLamaNativeBatch`
    /// </summary>
    /// <param name="batch">The native batch to take ownership of</param>
    /// <param name="embd">The embedding dimension the batch was created with (0 if it stores tokens)</param>
    public LLamaBatchSafeHandle(LLamaNativeBatch batch, int embd)
        : base((nint)1)
    {
        // There is no native pointer to wrap here (the llama_batch struct is stored by
        // value in NativeBatch), so a dummy non-zero handle is passed to the base class
        // to mark the handle valid and ensure ReleaseHandle runs.
        _embd = embd;
        NativeBatch = batch;
    }
    /// <summary>
    /// Call `llama_batch_init` and create a new batch
    /// </summary>
    /// <param name="n_tokens">The maximum number of tokens the batch can hold</param>
    /// <param name="embd">If non-zero, allocate space for embeddings of this dimension instead of token ids</param>
    /// <param name="n_seq_max">The maximum number of sequences a single token can be assigned to</param>
    /// <returns></returns>
    public static LLamaBatchSafeHandle Create(int n_tokens, int embd, int n_seq_max)
    {
        var batch = NativeApi.llama_batch_init(n_tokens, embd, n_seq_max);
        return new LLamaBatchSafeHandle(batch, embd);
    }
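
    // A minimal usage sketch for Create. Whether the batch carries token ids or raw
    // embeddings is fixed at creation time by `embd`, mirroring llama_batch_init
    // (the 4096 below is an illustrative embedding dimension, not a required value):
    //
    //     // token batch: Token is usable, Embed is empty
    //     using var tokens = LLamaBatchSafeHandle.Create(n_tokens: 512, embd: 0, n_seq_max: 1);
    //
    //     // embedding batch: Embed is usable (n_tokens * 4096 floats), Token is empty
    //     using var embeds = LLamaBatchSafeHandle.Create(n_tokens: 8, embd: 4096, n_seq_max: 1);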
    /// <inheritdoc />
    protected override bool ReleaseHandle()
    {
        NativeApi.llama_batch_free(NativeBatch);
        NativeBatch = default;
        SetHandle(IntPtr.Zero);
        return true;
    }
    /// <summary>
    /// Add a single token to the batch, assigned to the given sequences, optionally
    /// requesting logits for it. Port of llama_batch_add:
    /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2
    /// </summary>
    /// <remarks>
    /// The caller must ensure the batch has spare capacity (fewer tokens added than the
    /// `n_tokens` it was created with) and that `sequences` is no longer than `n_seq_max`.
    /// </remarks>
    public void LLamaBatchAdd(llama_token token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
    {
        unsafe
        {
            NativeBatch.token[NativeBatch.n_tokens] = token;
            NativeBatch.pos[NativeBatch.n_tokens] = pos;
            NativeBatch.n_seq_id[NativeBatch.n_tokens] = sequences.Length;
            for (var i = 0; i < sequences.Length; i++)
                NativeBatch.seq_id[NativeBatch.n_tokens][i] = sequences[i];
            NativeBatch.logits[NativeBatch.n_tokens] = Convert.ToByte(logits);
            NativeBatch.n_tokens++;
        }
    }
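
    // A sketch of feeding a prompt through LLamaBatchAdd. It assumes `prompt` holds
    // token ids from a prior tokenization, `ctx` is an existing context handle, a
    // llama_decode binding taking a LLamaNativeBatch exists on NativeApi, and LLamaPos
    // converts implicitly from int. Only the last token requests logits, since only
    // those are needed to sample the next token:
    //
    //     using var batch = LLamaBatchSafeHandle.Create(prompt.Length, embd: 0, n_seq_max: 1);
    //     ReadOnlySpan<LLamaSeqId> seq = stackalloc LLamaSeqId[] { default }; // sequence 0
    //     for (var i = 0; i < prompt.Length; i++)
    //         batch.LLamaBatchAdd(prompt[i], i, seq, logits: i == prompt.Length - 1);
    //     NativeApi.llama_decode(ctx, batch.NativeBatch);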
    /// <summary>
    /// Reset the batch so it can be refilled; the underlying arrays stay allocated.
    /// Port of llama_batch_clear:
    /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L825
    /// </summary>
    public void LLamaBatchClear()
    {
        NativeBatch.n_tokens = 0;
    }
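
    // A sketch of the clear-then-refill pattern used during incremental generation:
    // after the prompt has been decoded, each step reuses the same batch for one new
    // token (`next`, `n_cur`, `seq` and `ctx` are illustrative names, continuing the
    // sketch above):
    //
    //     batch.LLamaBatchClear();
    //     batch.LLamaBatchAdd(next, n_cur, seq, logits: true);
    //     NativeApi.llama_decode(ctx, batch.NativeBatch);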
}