Added methods to `SafeLLamaContextHandle` for KV cache manipulation

commit 92b9bbe779 (parent 8dfd07f67b)
BatchedDecoding.cs

@@ -2,6 +2,7 @@
 using System.Text;
 using LLama.Common;
 using LLama.Native;
+using LLama.Sampling;
 
 namespace LLama.Examples.Examples;
 
@@ -14,10 +15,6 @@ public class BatchedDecoding
     private const int n_parallel = 8;
     private const int n_len = 32;
 
-    private const int top_k = 80;
-    private const float top_p = 0.8f;
-    private const float temp = 0.75f;
-
     public static async Task Run()
     {
         Console.Write("Please input your model path: ");
@@ -55,10 +52,9 @@ public class BatchedDecoding
         var batch = new LLamaBatch();
 
         // evaluate the initial prompt
-        for (var i = 0; i < prompt_tokens.Length; i++)
-            batch.Add(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);
+        batch.AddRange(prompt_tokens, 0, LLamaSeqId.Zero, true);
 
-        if (await context.DecodeAsync(batch) != 0)
+        if (await context.DecodeAsync(batch) != DecodeResult.Ok)
         {
             await Console.Error.WriteLineAsync("llama_decode failed");
             return;
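
The two changes in this hunk pair up: the old per-token loop collapses into the new `LLamaBatch.AddRange` (added further down in this commit), and `DecodeAsync` is now compared against a typed `DecodeResult` rather than a raw integer. A minimal sketch of the resulting idiom, assuming `context` and a tokenized `prompt_tokens` are in scope as earlier in this example:

```csharp
var batch = new LLamaBatch();

// Queue the whole prompt on sequence 0; `true` requests logits for the
// final token only, matching what the removed loop did.
batch.AddRange(prompt_tokens, 0, LLamaSeqId.Zero, true);

// DecodeAsync reports a typed result instead of a bare int.
if (await context.DecodeAsync(batch) != DecodeResult.Ok)
{
    await Console.Error.WriteLineAsync("llama_decode failed");
    return;
}
```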
@@ -68,7 +64,7 @@ public class BatchedDecoding
         // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
         for (var i = 1; i < n_parallel; ++i)
         {
-            NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.TokenCount);
+            context.NativeHandle.KvCacheSequenceCopy((LLamaSeqId)0, (LLamaSeqId)i, 0, batch.TokenCount);
        }
 
        if (n_parallel > 1)
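
The reuse works because llama.cpp's KV cache tags each cached token with the sequences it belongs to: "copying" a sequence only adds tags, it never duplicates KV memory (the doc comment on the new `KvCacheSequenceCopy` below says as much). A sketch of the fork-and-diverge pattern this enables, with illustrative sequence ids and an assumed `promptLength` already evaluated on sequence 0:

```csharp
var ctx = context.NativeHandle;
var src = (LLamaSeqId)0;
var fork = (LLamaSeqId)1;

// Tag the prompt's cache entries (positions [0, promptLength)) with the
// second sequence id as well; the prompt is still stored only once.
ctx.KvCacheSequenceCopy(src, fork, 0, promptLength);

// From here on, tokens batched under seq 0 and seq 1 extend two
// independent continuations of the shared prompt.
```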
@@ -83,16 +79,22 @@ public class BatchedDecoding
         for (var i = 0; i < n_parallel; i++)
             i_batch.Add(batch.TokenCount - 1);
 
+        // Create per-stream decoder and sampler
+        var decoders = new StreamingTokenDecoder[n_parallel];
+        var samplers = new ISamplingPipeline[n_parallel];
+        for (var i = 0; i < n_parallel; i++)
+        {
+            decoders[i] = new StreamingTokenDecoder(context);
+            samplers[i] = new DefaultSamplingPipeline
+            {
+                Temperature = 0.1f + (float)i / n_parallel,
+                MinP = 0.25f,
+            };
+        }
+
         var n_cur = batch.TokenCount;
         var n_decode = 0;
 
-        var streams = new StreamingTokenDecoder[n_parallel];
-        for (var i = 0; i < n_parallel; i++)
-            streams[i] = new StreamingTokenDecoder(context);
-
-        var eos = model.EndOfSentenceToken;
-        var nl = model.NewlineToken;
-
         var timer = new Stopwatch();
         timer.Start();
         while (n_cur <= n_len)
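
Each stream now owns its own sampling pipeline, and the temperatures fan out linearly: with `n_parallel = 8`, `0.1f + (float)i / n_parallel` yields 0.100, 0.225, 0.350, ..., 0.975, so stream 0 is near-greedy while stream 7 samples much more freely. A one-liner to verify the spread:

```csharp
const int n_parallel = 8;
for (var i = 0; i < n_parallel; i++)
    Console.WriteLine($"stream {i}: temperature = {0.1f + (float)i / n_parallel:0.000}");
```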
@@ -105,31 +107,33 @@ public class BatchedDecoding
                 if (i_batch[i] < 0)
                     continue;
 
-                var candidates = LLamaTokenDataArray.Create(context.NativeHandle.GetLogitsIth(i_batch[i]));
+                // Use the sampling pipeline to select a token
+                var new_token_id = samplers[i].Sample(
+                    context.NativeHandle,
+                    context.NativeHandle.GetLogitsIth(i_batch[i]),
+                    Array.Empty<LLamaToken>()
+                );
 
-                candidates.TopK(context.NativeHandle, top_k);
-                candidates.TopP(context.NativeHandle, top_p);
-                candidates.Temperature(context.NativeHandle, temp);
-                var new_token_id = candidates.SampleToken(context.NativeHandle);
-
-                if (new_token_id == eos || new_token_id == nl)
+                // Finish this stream early if necessary
+                if (new_token_id == model.EndOfSentenceToken || new_token_id == model.NewlineToken)
                 {
                     i_batch[i] = -1;
                     Console.WriteLine($"Completed Stream {i} early");
                     continue;
                 }
 
-                streams[i].Add(new_token_id);
+                // Add this token to the decoder, so it will be turned into text
+                decoders[i].Add(new_token_id);
 
                 i_batch[i] = batch.TokenCount;
 
                 // push this new token for next evaluation
-                batch.Add(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);
+                batch.Add(new_token_id, n_cur, (LLamaSeqId)i, true);
 
                 n_decode++;
             }
 
-            // all streams are finished
+            // Check if all streams are finished
             if (batch.TokenCount == 0)
             {
                 break;
@@ -152,7 +156,7 @@ public class BatchedDecoding
         Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second");
 
         var index = 0;
-        foreach (var stream in streams)
+        foreach (var stream in decoders)
         {
             var text = stream.Read();
LLamaBatch.cs

@@ -1,4 +1,6 @@
 using System;
+using System.Collections.Generic;
+using System.Runtime.InteropServices;
 
 namespace LLama.Native;
 
@@ -22,14 +24,14 @@ public class LLamaBatch
     public int TokenCount { get; private set; }
 
     /// <summary>
-    /// Maximum number of tokens that can be added to this batch
+    /// Maximum number of tokens that can be added to this batch (automatically grows if exceeded)
     /// </summary>
     private int TokenCapacity { get; set; }
 
     /// <summary>
-    /// Maximum number of sequences a token can be assigned to
+    /// Maximum number of sequences a token can be assigned to (automatically grows if exceeded)
     /// </summary>
-    public int MaxSequences { get; private set; }
+    public int SequenceCapacity { get; private set; }
 
     /// <summary>
     /// Create a new batch for submitting inputs to llama.cpp
@@ -40,7 +42,7 @@ public class LLamaBatch
         const int n_tokens = 128;
         const int n_seq_max = 1;
 
-        MaxSequences = n_seq_max;
+        SequenceCapacity = n_seq_max;
         TokenCapacity = n_tokens;
 
         _logits = new byte[n_tokens];
@@ -52,9 +54,10 @@ public class LLamaBatch
 
         _sequenceIds = new LLamaSeqId[n_tokens][];
         for (var i = 0; i < _sequenceIds.Length; i++)
-            _sequenceIds[i] = new LLamaSeqId[MaxSequences];
+            _sequenceIds[i] = new LLamaSeqId[SequenceCapacity];
     }
 
+    #region grow
     private void GrowTokenCapacity()
     {
         var n_tokens = TokenCount * 2;
@@ -73,18 +76,19 @@ public class LLamaBatch
             // Growing the array filled elements with null, temporarily violating the nullability contract!
             // ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract
             if (_sequenceIds[i] == null)
-                _sequenceIds[i] = new LLamaSeqId[MaxSequences];
+                _sequenceIds[i] = new LLamaSeqId[SequenceCapacity];
         }
     }
 
     private void GrowMaxSequences(int atLeast)
     {
-        var n_seq = Math.Max(MaxSequences * 2, atLeast);
-        MaxSequences = n_seq;
+        var n_seq = Math.Max(SequenceCapacity * 2, atLeast);
+        SequenceCapacity = n_seq;
 
         for (var i = 0; i < _sequenceIds.Length; i++)
-            Array.Resize(ref _sequenceIds[i], MaxSequences);
+            Array.Resize(ref _sequenceIds[i], SequenceCapacity);
     }
+    #endregion
 
     internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
     {
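
Two details in the grow path are worth noting. `Math.Max(SequenceCapacity * 2, atLeast)` guarantees at-least-doubling, so repeated `Add` calls grow the batch in amortized constant time. And `Array.Resize` allocates a new array whose extra slots are default-initialized, so when the jagged `_sequenceIds` array itself is grown the new rows come back `null`, which is exactly what the nullability-contract comment above is defending against. A standalone demonstration of the latter:

```csharp
using System;

// Growing a jagged array leaves the new rows null, which is why the grow
// path re-checks `_sequenceIds[i] == null` and allocates fresh rows.
var rows = new[] { new int[4], new int[4] };
Array.Resize(ref rows, 4);
Console.WriteLine(rows[2] is null); // True: the new tail is default-initialized
```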
@@ -117,6 +121,7 @@ public class LLamaBatch
         return group;
     }
 
+    #region add
     /// <summary>
     /// Add a single token to the batch at the same position in several sequences
     /// </summary>
@@ -129,7 +134,7 @@ public class LLamaBatch
     {
         if (TokenCount == TokenCapacity)
             GrowTokenCapacity();
-        if (sequences.Length > MaxSequences)
+        if (sequences.Length > SequenceCapacity)
             GrowMaxSequences(sequences.Length);
 
         _tokens[TokenCount] = token;
@@ -144,6 +149,37 @@ public class LLamaBatch
         TokenCount++;
     }
 
+    /// <summary>
+    /// Add a single token to the batch at the same position in several sequences
+    /// </summary>
+    /// <remarks>https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2</remarks>
+    /// <param name="token">The token to add</param>
+    /// <param name="pos">The position to add it at</param>
+    /// <param name="sequences">The set of sequences to add this token to</param>
+    /// <param name="logits"></param>
+    public void Add(LLamaToken token, LLamaPos pos, List<LLamaSeqId> sequences, bool logits)
+    {
+#if NET5_0_OR_GREATER
+        var seqSpan = CollectionsMarshal.AsSpan(sequences);
+        Add(token, pos, seqSpan, logits);
+#else
+        // on netstandard2.0 we can't use CollectionsMarshal to get directly at the internal memory of
+        // the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
+        // avoid the copying.
+
+        var rented = System.Buffers.ArrayPool<LLamaSeqId>.Shared.Rent(sequences.Count);
+        try
+        {
+            sequences.CopyTo(rented, 0);
+            Add(token, pos, rented.AsSpan(0, sequences.Count), logits);
+        }
+        finally
+        {
+            System.Buffers.ArrayPool<LLamaSeqId>.Shared.Return(rented);
+        }
+#endif
+    }
+
     /// <summary>
     /// Add a single token to the batch at a certain position for a single sequence
     /// </summary>
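
The `#if NET5_0_OR_GREATER` split in the new overload is a tidy pattern: `CollectionsMarshal.AsSpan` exposes the `List<T>`'s backing array as a span with no copy at all, while the netstandard2.0 fallback rents from `ArrayPool<T>` so only a copy, not an allocation, is paid. (A span from `CollectionsMarshal.AsSpan` is invalidated if the list is mutated while the span is alive; here it is consumed immediately, so that cannot happen.) A hypothetical caller, assuming `batch`, `token` and `pos` are in scope:

```csharp
using System.Collections.Generic;

// Add one token at the same position in two sequences at once.
var sequences = new List<LLamaSeqId> { (LLamaSeqId)0, (LLamaSeqId)1 };
batch.Add(token, pos, sequences, logits: true);
```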
@@ -162,6 +198,23 @@ public class LLamaBatch
         Add(token, pos, sequences, logits);
     }
 
+    /// <summary>
+    /// Add a range of tokens to a single sequence, starting at the given position.
+    /// </summary>
+    /// <param name="tokens">The tokens to add</param>
+    /// <param name="start">The starting position to add tokens at</param>
+    /// <param name="sequence">The sequence to add these tokens to</param>
+    /// <param name="logitsLast">Whether the final token should generate logits</param>
+    public void AddRange(ReadOnlySpan<LLamaToken> tokens, LLamaPos start, LLamaSeqId sequence, bool logitsLast)
+    {
+        for (var i = 0; i < tokens.Length; i++)
+        {
+            var logits = (i == tokens.Length - 1) & logitsLast;
+            Add(tokens[i], start.Value + i, sequence, logits);
+        }
+    }
+    #endregion
 
     /// <summary>
     /// Set TokenCount to zero for this batch
     /// </summary>
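
One subtlety in `AddRange`: `(i == tokens.Length - 1) & logitsLast` yields `true` only for the last token, and only when `logitsLast` is set (the non-short-circuiting `&` is harmless on two cheap bools). So for a 3-token span starting at position 10 on sequence `seq`, the call expands to:

```csharp
// AddRange(tokens, 10, seq, logitsLast: true) is equivalent to:
Add(tokens[0], 10, seq, false);
Add(tokens[1], 11, seq, false);
Add(tokens[2], 12, seq, true); // only the final token produces logits
```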
LLamaKvCacheViewSafeHandle.cs

@@ -112,7 +112,7 @@ public class LLamaKvCacheViewSafeHandle
     }
 
     /// <summary>
-    /// Count the number of tokens in the KV cache. If a token is assigned to multiple sequences it will be countered multiple times
+    /// Count the number of tokens in the KV cache. If a token is assigned to multiple sequences it will be counted multiple times
     /// </summary>
     /// <returns></returns>
     public int CountTokens()
NativeApi.cs

@@ -422,6 +422,6 @@ namespace LLama.Native
         /// <param name="n_threads_batch">n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)</param>
         /// <returns></returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern int llama_set_n_threads(SafeLLamaContextHandle ctx, uint n_threads, uint n_threads_batch);
+        public static extern void llama_set_n_threads(SafeLLamaContextHandle ctx, uint n_threads, uint n_threads_batch);
     }
 }
SafeLLamaContextHandle.cs

@@ -167,8 +167,9 @@ namespace LLama.Native
         {
             return ThrowIfDisposed().TokenToSpan(token, dest);
         }
         #endregion
 
+        #region infer
         /// <summary>
         /// Run the llama inference to obtain the logits and probabilities for the next token.
         /// </summary>
@@ -202,6 +203,7 @@ namespace LLama.Native
             using (batch.ToNativeBatch(out var nb))
                 return NativeApi.llama_decode(this, nb);
         }
+        #endregion
 
         #region state
         /// <summary>
@@ -275,5 +277,73 @@ namespace LLama.Native
         {
             NativeApi.llama_set_rng_seed(this, seed);
         }
+
+        /// <summary>
+        /// Set the number of threads used for decoding
+        /// </summary>
+        /// <param name="threads">n_threads is the number of threads used for generation (single token)</param>
+        /// <param name="threadsBatch">n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)</param>
+        public void SetThreads(uint threads, uint threadsBatch)
+        {
+            NativeApi.llama_set_n_threads(this, threads, threadsBatch);
+        }
+
+        #region KV Cache Management
+        /// <summary>
+        /// Clear the KV cache
+        /// </summary>
+        public void KvCacheClear()
+        {
+            NativeApi.llama_kv_cache_clear(this);
+        }
+
+        /// <summary>
+        /// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
+        /// </summary>
+        /// <param name="seq"></param>
+        /// <param name="p0"></param>
+        /// <param name="p1"></param>
+        public void KvCacheRemove(LLamaSeqId seq, LLamaPos p0, LLamaPos p1)
+        {
+            NativeApi.llama_kv_cache_seq_rm(this, seq, p0, p1);
+        }
+
+        /// <summary>
+        /// Copy all tokens that belong to the specified sequence to another sequence. Note that
+        /// this does not allocate extra KV cache memory - it simply assigns the tokens to the
+        /// new sequence
+        /// </summary>
+        /// <param name="src"></param>
+        /// <param name="dest"></param>
+        /// <param name="p0"></param>
+        /// <param name="p1"></param>
+        public void KvCacheSequenceCopy(LLamaSeqId src, LLamaSeqId dest, LLamaPos p0, LLamaPos p1)
+        {
+            NativeApi.llama_kv_cache_seq_cp(this, src, dest, p0, p1);
+        }
+
+        /// <summary>
+        /// Removes all tokens that do not belong to the specified sequence
+        /// </summary>
+        /// <param name="seq"></param>
+        public void KvCacheSequenceKeep(LLamaSeqId seq)
+        {
+            NativeApi.llama_kv_cache_seq_keep(this, seq);
+        }
+
+        /// <summary>
+        /// Adds relative position "delta" to all tokens that belong to the specified sequence
+        /// and have positions in [p0, p1). If the KV cache is RoPEd, the KV data is updated
+        /// accordingly
+        /// </summary>
+        /// <param name="seq"></param>
+        /// <param name="p0"></param>
+        /// <param name="p1"></param>
+        /// <param name="delta"></param>
+        public void KvCacheSequenceShift(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta)
+        {
+            NativeApi.llama_kv_cache_seq_shift(this, seq, p0, p1, delta);
+        }
+        #endregion
     }
 }
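
Taken together, the new methods cover the common "sliding window" maneuver from llama.cpp's examples: drop the oldest tokens of a sequence, then shift the survivors so positions stay contiguous. A hedged sketch, assuming `ctx` is a `SafeLLamaContextHandle`, `keep` and `n_past` are hypothetical bookkeeping values, and `LLamaPos` converts implicitly from int as elsewhere in this commit:

```csharp
// Slide sequence 0's window once the cache fills up.
var seq = (LLamaSeqId)0;
var discard = (n_past - keep) / 2;

// Drop cached positions [keep, keep + discard)...
ctx.KvCacheRemove(seq, keep, keep + discard);

// ...then shift the remaining tokens left by `discard` so the sequence's
// positions are contiguous again (RoPE'd KV data is updated accordingly).
ctx.KvCacheSequenceShift(seq, keep + discard, n_past, -discard);
```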