Reduced some uses of `NativeApi` in `BatchedDecoding` by adding some helper methods
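
The gist of the change, shown as a before/after excerpt paraphrased from the diff below (`model`, `context`, `candidates_native` and the sampling constants come from the surrounding `BatchedDecoding` example, so this is a side-by-side contrast rather than a compilable block):

```csharp
// Before: token ids and sampling went through NativeApi directly
var eos = NativeApi.llama_token_eos(model.NativeHandle);
NativeApi.llama_sample_top_k(context.NativeHandle, ref candidates_native, top_k, 1);
var new_token_id = NativeApi.llama_sample_token(context.NativeHandle, ref candidates_native);

// After: the same operations via the new helper property and methods
var eos = model.EndOfSentenceToken;
candidates_native.TopK(context.NativeHandle, top_k);
var new_token_id = candidates_native.SampleToken(context.NativeHandle);
```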

Martin Evans 2023-10-25 22:29:19 +01:00
parent b38e3f6fe2
commit 3c5547b2b7
3 changed files with 89 additions and 43 deletions


@@ -11,6 +11,10 @@ public class BatchedDecoding
     private const int n_parallel = 8;
     private const int n_len = 32;
+    private const int top_k = 40;
+    private const float top_p = 0.9f;
+    private const float temp = 0.4f;
     public static async Task Run()
     {
         Console.Write("Please input your model path: ");
@@ -91,8 +95,8 @@ public class BatchedDecoding
         for (var i = 0; i < n_parallel; i++)
             streams[i] = new();
-        var eos = NativeApi.llama_token_eos(model.NativeHandle);
-        var nl = NativeApi.llama_token_nl(model.NativeHandle);
+        var eos = model.EndOfSentenceToken;
+        var nl = model.NewlineToken;
         var timer = new Stopwatch();
         timer.Start();
@@ -106,50 +110,34 @@ public class BatchedDecoding
                 if (i_batch[i] < 0)
                     continue;
+                var n_vocab = model.VocabCount;
+                LLamaTokenDataArray candidates;
                 unsafe
                 {
-                    var n_vocab = model.VocabCount;
-                    var logits = NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]);
-                    var candidates = new LLamaTokenData[n_vocab];
-                    for (var token_id = 0; token_id < n_vocab; token_id++)
-                    {
-                        candidates[token_id] = new LLamaTokenData
-                        {
-                            id = token_id,
-                            logit = logits[token_id]
-                        };
-                    }
-                    var candidates_p = new LLamaTokenDataArray(candidates);
-                    using var pin = LLamaTokenDataArrayNative.Create(candidates_p, out var candidates_native);
-                    const int top_k = 40;
-                    const float top_p = 0.9f;
-                    const float temp = 0.4f;
-                    NativeApi.llama_sample_top_k(context.NativeHandle, ref candidates_native, top_k, 1);
-                    NativeApi.llama_sample_top_p(context.NativeHandle, ref candidates_native, top_p, 1);
-                    NativeApi.llama_sample_temperature(context.NativeHandle, ref candidates_native, temp);
-                    var new_token_id = NativeApi.llama_sample_token(context.NativeHandle, ref candidates_native);
-                    if (new_token_id == eos || new_token_id == nl)
-                    {
-                        i_batch[i] = -1;
-                        Console.WriteLine($"Completed Stream {i} early");
-                        continue;
-                    }
-                    streams[i].Add(new_token_id);
-                    i_batch[i] = batch.NativeBatch.n_tokens;
-                    // push this new token for next evaluation
-                    llama_batch_add(batch, new_token_id, n_cur, new() { (LLamaSeqId)i }, true);
-                    n_decode++;
+                    candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab));
                 }
+                using var pin = LLamaTokenDataArrayNative.Create(candidates, out var candidates_native);
+                candidates_native.TopK(context.NativeHandle, top_k);
+                candidates_native.TopP(context.NativeHandle, top_p);
+                candidates_native.Temperature(context.NativeHandle, temp);
+                var new_token_id = candidates_native.SampleToken(context.NativeHandle);
+                if (new_token_id == eos || new_token_id == nl)
+                {
+                    i_batch[i] = -1;
+                    Console.WriteLine($"Completed Stream {i} early");
+                    continue;
+                }
+                streams[i].Add(new_token_id);
+                i_batch[i] = batch.NativeBatch.n_tokens;
+                // push this new token for next evaluation
+                llama_batch_add(batch, new_token_id, n_cur, new() { (LLamaSeqId)i }, true);
+                n_decode++;
             }
             // all streams are finished


@@ -38,6 +38,21 @@ namespace LLama
         /// </summary>
         public ulong ParameterCount => NativeHandle.ParameterCount;
+        /// <summary>
+        /// Get the newline token for this model
+        /// </summary>
+        public int NewlineToken => NativeApi.llama_token_nl(NativeHandle);
+        /// <summary>
+        /// Get the "end of sentence" token for this model
+        /// </summary>
+        public int EndOfSentenceToken => NativeApi.llama_token_eos(NativeHandle);
+        /// <summary>
+        /// Get the "beginning of sentence" token for this model
+        /// </summary>
+        public int BeginningOfSentenceToken => NativeApi.llama_token_bos(NativeHandle);
         /// <summary>
         /// Dimension of embedding vectors
         /// </summary>


@@ -96,5 +96,48 @@ namespace LLama.Native
             return handle;
         }
+        /// <summary>
+        /// Perform TopK sampling, sorting the data and reducing the size to k
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="k">Number of tokens to keep</param>
+        /// <param name="minKeep">Minimum number to keep</param>
+        public void TopK(SafeLLamaContextHandle context, int k, ulong minKeep = 1)
+        {
+            NativeApi.llama_sample_top_k(context, ref this, k, minKeep);
+        }
+        /// <summary>
+        /// Perform top p sampling, sorting the data and keeping the smallest set of tokens whose cumulative probability exceeds p
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="p">Cumulative probability threshold</param>
+        /// <param name="minKeep">Minimum number to keep</param>
+        public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
+        {
+            NativeApi.llama_sample_top_p(context, ref this, p, minKeep);
+        }
+        /// <summary>
+        /// Apply temperature to logits
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="temp">Temperature value</param>
+        public void Temperature(SafeLLamaContextHandle context, float temp)
+        {
+            NativeApi.llama_sample_temperature(context, ref this, temp);
+        }
+        /// <summary>
+        /// Sample a token from the set of possible tokens
+        /// </summary>
+        /// <param name="context"></param>
+        /// <returns>The sampled token id</returns>
+        public int SampleToken(SafeLLamaContextHandle context)
+        {
+            return NativeApi.llama_sample_token(context, ref this);
+        }
     }
 }
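
Taken together with the token properties added to `LLamaWeights`, these helpers let sampling code avoid `NativeApi` entirely. A minimal sketch of how a caller might use them (the `SampleNext` wrapper and its parameter values are illustrative only, not part of this commit):

```csharp
// Illustrative only: a small wrapper over the helpers added in this commit.
// Assumes `model` is a loaded LLamaWeights, `ctx` is its SafeLLamaContextHandle,
// and `candidates` was built from the logits of the last evaluated token.
static int SampleNext(LLamaWeights model, SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
    // Pin the managed token data so the native sampling functions can mutate it in place
    using var pin = LLamaTokenDataArrayNative.Create(candidates, out var native);

    native.TopK(ctx, 40);           // keep the 40 most likely tokens
    native.TopP(ctx, 0.9f);         // then trim to the top 90% of cumulative probability
    native.Temperature(ctx, 0.4f);  // sharpen the distribution
    var token = native.SampleToken(ctx);

    // The new LLamaWeights properties replace NativeApi.llama_token_eos/llama_token_nl lookups
    if (token == model.EndOfSentenceToken || token == model.NewlineToken)
        return -1; // caller can treat this as "stream finished"

    return token;
}
```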