Added methods to `SafeLLamaContextHandle` for KV cache manipulation

This commit is contained in:
Martin Evans 2024-01-23 16:16:02 +00:00
parent 8dfd07f67b
commit 92b9bbe779
5 changed files with 166 additions and 39 deletions

View File

@ -2,6 +2,7 @@
using System.Text; using System.Text;
using LLama.Common; using LLama.Common;
using LLama.Native; using LLama.Native;
using LLama.Sampling;
namespace LLama.Examples.Examples; namespace LLama.Examples.Examples;
@ -14,10 +15,6 @@ public class BatchedDecoding
private const int n_parallel = 8; private const int n_parallel = 8;
private const int n_len = 32; private const int n_len = 32;
private const int top_k = 80;
private const float top_p = 0.8f;
private const float temp = 0.75f;
public static async Task Run() public static async Task Run()
{ {
Console.Write("Please input your model path: "); Console.Write("Please input your model path: ");
@ -55,10 +52,9 @@ public class BatchedDecoding
var batch = new LLamaBatch(); var batch = new LLamaBatch();
// evaluate the initial prompt // evaluate the initial prompt
for (var i = 0; i < prompt_tokens.Length; i++) batch.AddRange(prompt_tokens, 0, LLamaSeqId.Zero, true);
batch.Add(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);
if (await context.DecodeAsync(batch) != 0) if (await context.DecodeAsync(batch) != DecodeResult.Ok)
{ {
await Console.Error.WriteLineAsync("llama_decode failed"); await Console.Error.WriteLineAsync("llama_decode failed");
return; return;
@ -68,7 +64,7 @@ public class BatchedDecoding
// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
for (var i = 1; i < n_parallel; ++i) for (var i = 1; i < n_parallel; ++i)
{ {
NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.TokenCount); context.NativeHandle.KvCacheSequenceCopy((LLamaSeqId)0, (LLamaSeqId)i, 0, batch.TokenCount);
} }
if (n_parallel > 1) if (n_parallel > 1)
@ -83,16 +79,22 @@ public class BatchedDecoding
for (var i = 0; i < n_parallel; i++) for (var i = 0; i < n_parallel; i++)
i_batch.Add(batch.TokenCount - 1); i_batch.Add(batch.TokenCount - 1);
// Create per-stream decoder and sampler
var decoders = new StreamingTokenDecoder[n_parallel];
var samplers = new ISamplingPipeline[n_parallel];
for (var i = 0; i < n_parallel; i++)
{
decoders[i] = new StreamingTokenDecoder(context);
samplers[i] = new DefaultSamplingPipeline
{
Temperature = 0.1f + (float)i / n_parallel,
MinP = 0.25f,
};
}
var n_cur = batch.TokenCount; var n_cur = batch.TokenCount;
var n_decode = 0; var n_decode = 0;
var streams = new StreamingTokenDecoder[n_parallel];
for (var i = 0; i < n_parallel; i++)
streams[i] = new StreamingTokenDecoder(context);
var eos = model.EndOfSentenceToken;
var nl = model.NewlineToken;
var timer = new Stopwatch(); var timer = new Stopwatch();
timer.Start(); timer.Start();
while (n_cur <= n_len) while (n_cur <= n_len)
@ -105,31 +107,33 @@ public class BatchedDecoding
if (i_batch[i] < 0) if (i_batch[i] < 0)
continue; continue;
var candidates = LLamaTokenDataArray.Create(context.NativeHandle.GetLogitsIth(i_batch[i])); // Use the sampling pipeline to select a token
var new_token_id = samplers[i].Sample(
context.NativeHandle,
context.NativeHandle.GetLogitsIth(i_batch[i]),
Array.Empty<LLamaToken>()
);
candidates.TopK(context.NativeHandle, top_k); // Finish this stream early if necessary
candidates.TopP(context.NativeHandle, top_p); if (new_token_id == model.EndOfSentenceToken || new_token_id == model.NewlineToken)
candidates.Temperature(context.NativeHandle, temp);
var new_token_id = candidates.SampleToken(context.NativeHandle);
if (new_token_id == eos || new_token_id == nl)
{ {
i_batch[i] = -1; i_batch[i] = -1;
Console.WriteLine($"Completed Stream {i} early"); Console.WriteLine($"Completed Stream {i} early");
continue; continue;
} }
streams[i].Add(new_token_id); // Add this token to the decoder, so it will be turned into text
decoders[i].Add(new_token_id);
i_batch[i] = batch.TokenCount; i_batch[i] = batch.TokenCount;
// push this new token for next evaluation // push this new token for next evaluation
batch.Add(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true); batch.Add(new_token_id, n_cur, (LLamaSeqId)i, true);
n_decode++; n_decode++;
} }
// all streams are finished // Check if all streams are finished
if (batch.TokenCount == 0) if (batch.TokenCount == 0)
{ {
break; break;
@ -152,7 +156,7 @@ public class BatchedDecoding
Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second"); Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second");
var index = 0; var index = 0;
foreach (var stream in streams) foreach (var stream in decoders)
{ {
var text = stream.Read(); var text = stream.Read();

View File

@ -1,4 +1,6 @@
using System; using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
namespace LLama.Native; namespace LLama.Native;
@ -22,14 +24,14 @@ public class LLamaBatch
public int TokenCount { get; private set; } public int TokenCount { get; private set; }
/// <summary> /// <summary>
/// Maximum number of tokens that can be added to this batch /// Maximum number of tokens that can be added to this batch (automatically grows if exceeded)
/// </summary> /// </summary>
private int TokenCapacity { get; set; } private int TokenCapacity { get; set; }
/// <summary> /// <summary>
/// Maximum number of sequences a token can be assigned to /// Maximum number of sequences a token can be assigned to (automatically grows if exceeded)
/// </summary> /// </summary>
public int MaxSequences { get; private set; } public int SequenceCapacity { get; private set; }
/// <summary> /// <summary>
/// Create a new batch for submitting inputs to llama.cpp /// Create a new batch for submitting inputs to llama.cpp
@ -40,7 +42,7 @@ public class LLamaBatch
const int n_tokens = 128; const int n_tokens = 128;
const int n_seq_max = 1; const int n_seq_max = 1;
MaxSequences = n_seq_max; SequenceCapacity = n_seq_max;
TokenCapacity = n_tokens; TokenCapacity = n_tokens;
_logits = new byte[n_tokens]; _logits = new byte[n_tokens];
@ -52,9 +54,10 @@ public class LLamaBatch
_sequenceIds = new LLamaSeqId[n_tokens][]; _sequenceIds = new LLamaSeqId[n_tokens][];
for (var i = 0; i < _sequenceIds.Length; i++) for (var i = 0; i < _sequenceIds.Length; i++)
_sequenceIds[i] = new LLamaSeqId[MaxSequences]; _sequenceIds[i] = new LLamaSeqId[SequenceCapacity];
} }
#region grow
private void GrowTokenCapacity() private void GrowTokenCapacity()
{ {
var n_tokens = TokenCount * 2; var n_tokens = TokenCount * 2;
@ -73,18 +76,19 @@ public class LLamaBatch
// Growing the array filled elements with null, temporarily violating the nullability contract! // Growing the array filled elements with null, temporarily violating the nullability contract!
// ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract // ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract
if (_sequenceIds[i] == null) if (_sequenceIds[i] == null)
_sequenceIds[i] = new LLamaSeqId[MaxSequences]; _sequenceIds[i] = new LLamaSeqId[SequenceCapacity];
} }
} }
private void GrowMaxSequences(int atLeast) private void GrowMaxSequences(int atLeast)
{ {
var n_seq = Math.Max(MaxSequences * 2, atLeast); var n_seq = Math.Max(SequenceCapacity * 2, atLeast);
MaxSequences = n_seq; SequenceCapacity = n_seq;
for (var i = 0; i < _sequenceIds.Length; i++) for (var i = 0; i < _sequenceIds.Length; i++)
Array.Resize(ref _sequenceIds[i], MaxSequences); Array.Resize(ref _sequenceIds[i], SequenceCapacity);
} }
#endregion
internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch) internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
{ {
@ -117,6 +121,7 @@ public class LLamaBatch
return group; return group;
} }
#region add
/// <summary> /// <summary>
/// Add a single token to the batch at the same position in several sequences /// Add a single token to the batch at the same position in several sequences
/// </summary> /// </summary>
@ -129,7 +134,7 @@ public class LLamaBatch
{ {
if (TokenCount == TokenCapacity) if (TokenCount == TokenCapacity)
GrowTokenCapacity(); GrowTokenCapacity();
if (sequences.Length > MaxSequences) if (sequences.Length > SequenceCapacity)
GrowMaxSequences(sequences.Length); GrowMaxSequences(sequences.Length);
_tokens[TokenCount] = token; _tokens[TokenCount] = token;
@ -144,6 +149,37 @@ public class LLamaBatch
TokenCount++; TokenCount++;
} }
/// <summary>
/// Add a single token to the batch at the same position in several sequences
/// </summary>
/// <remarks>https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2</remarks>
/// <param name="token">The token to add</param>
/// <param name="pos">The position to add it att</param>
/// <param name="sequences">The set of sequences to add this token to</param>
/// <param name="logits"></param>
public void Add(LLamaToken token, LLamaPos pos, List<LLamaSeqId> sequences, bool logits)
{
#if NET5_0_OR_GREATER
var seqSpan = CollectionsMarshal.AsSpan(sequences);
Add(token, pos, seqSpan, logits);
#else
// on netstandard2.0 we can't use CollectionsMarshal to get directly at the internal memory of
// the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
// avoid the copying.
var rented = System.Buffers.ArrayPool<LLamaSeqId>.Shared.Rent(sequences.Count);
try
{
sequences.CopyTo(rented, 0);
Add(token, pos, rented.AsSpan(0, sequences.Count), logits);
}
finally
{
System.Buffers.ArrayPool<LLamaSeqId>.Shared.Return(rented);
}
#endif
}
/// <summary> /// <summary>
/// Add a single token to the batch at a certain position for a single sequences /// Add a single token to the batch at a certain position for a single sequences
/// </summary> /// </summary>
@ -162,6 +198,23 @@ public class LLamaBatch
Add(token, pos, sequences, logits); Add(token, pos, sequences, logits);
} }
/// <summary>
/// Add a range of tokens to a single sequence, start at the given position.
/// </summary>
/// <param name="tokens">The tokens to add</param>
/// <param name="start">The starting position to add tokens at</param>
/// <param name="sequence">The sequence to add this token to</param>
/// <param name="logitsLast">Whether the final token should generate logits</param>
public void AddRange(ReadOnlySpan<LLamaToken> tokens, LLamaPos start, LLamaSeqId sequence, bool logitsLast)
{
for (var i = 0; i < tokens.Length; i++)
{
var logits = (i == tokens.Length - 1) & logitsLast;
Add(tokens[i], start.Value + i, sequence, logits);
}
}
#endregion
/// <summary> /// <summary>
/// Set TokenCount to zero for this batch /// Set TokenCount to zero for this batch
/// </summary> /// </summary>

View File

@ -112,7 +112,7 @@ public class LLamaKvCacheViewSafeHandle
} }
/// <summary> /// <summary>
/// Count the number of tokens in the KV cache. If a token is assigned to multiple sequences it will be countered multiple times /// Count the number of tokens in the KV cache. If a token is assigned to multiple sequences it will be counted multiple times
/// </summary> /// </summary>
/// <returns></returns> /// <returns></returns>
public int CountTokens() public int CountTokens()

View File

@ -422,6 +422,6 @@ namespace LLama.Native
/// <param name="n_threads_batch">n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)</param> /// <param name="n_threads_batch">n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)</param>
/// <returns></returns> /// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_set_n_threads(SafeLLamaContextHandle ctx, uint n_threads, uint n_threads_batch); public static extern void llama_set_n_threads(SafeLLamaContextHandle ctx, uint n_threads, uint n_threads_batch);
} }
} }

View File

@ -167,8 +167,9 @@ namespace LLama.Native
{ {
return ThrowIfDisposed().TokenToSpan(token, dest); return ThrowIfDisposed().TokenToSpan(token, dest);
} }
#endregion #endregion
#region infer
/// <summary> /// <summary>
/// Run the llama inference to obtain the logits and probabilities for the next token. /// Run the llama inference to obtain the logits and probabilities for the next token.
/// </summary> /// </summary>
@ -202,6 +203,7 @@ namespace LLama.Native
using (batch.ToNativeBatch(out var nb)) using (batch.ToNativeBatch(out var nb))
return NativeApi.llama_decode(this, nb); return NativeApi.llama_decode(this, nb);
} }
#endregion
#region state #region state
/// <summary> /// <summary>
@ -275,5 +277,73 @@ namespace LLama.Native
{ {
NativeApi.llama_set_rng_seed(this, seed); NativeApi.llama_set_rng_seed(this, seed);
} }
/// <summary>
/// Set the number of threads used for decoding
/// </summary>
/// <param name="threads">n_threads is the number of threads used for generation (single token)</param>
/// <param name="threadsBatch">n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)</param>
public void SetThreads(uint threads, uint threadsBatch)
{
NativeApi.llama_set_n_threads(this, threads, threadsBatch);
}
#region KV Cache Management
/// <summary>
/// Clear the KV cache
/// </summary>
public void KvCacheClear()
{
NativeApi.llama_kv_cache_clear(this);
}
/// <summary>
/// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
/// </summary>
/// <param name="seq"></param>
/// <param name="p0"></param>
/// <param name="p1"></param>
public void KvCacheRemove(LLamaSeqId seq, LLamaPos p0, LLamaPos p1)
{
NativeApi.llama_kv_cache_seq_rm(this, seq, p0, p1);
}
/// <summary>
/// Copy all tokens that belong to the specified sequence to another sequence. Note that
/// this does not allocate extra KV cache memory - it simply assigns the tokens to the
/// new sequence
/// </summary>
/// <param name="src"></param>
/// <param name="dest"></param>
/// <param name="p0"></param>
/// <param name="p1"></param>
public void KvCacheSequenceCopy(LLamaSeqId src, LLamaSeqId dest, LLamaPos p0, LLamaPos p1)
{
NativeApi.llama_kv_cache_seq_cp(this, src, dest, p0, p1);
}
/// <summary>
/// Removes all tokens that do not belong to the specified sequence
/// </summary>
/// <param name="seq"></param>
public void KvCacheSequenceKeep(LLamaSeqId seq)
{
NativeApi.llama_kv_cache_seq_keep(this, seq);
}
/// <summary>
/// Adds relative position "delta" to all tokens that belong to the specified sequence
/// and have positions in [p0, p1. If the KV cache is RoPEd, the KV data is updated
/// accordingly
/// </summary>
/// <param name="seq"></param>
/// <param name="p0"></param>
/// <param name="p1"></param>
/// <param name="delta"></param>
public void KvCacheSequenceShift(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta)
{
NativeApi.llama_kv_cache_seq_shift(this, seq, p0, p1, delta);
}
#endregion
} }
} }