Reduced some uses of `NativeApi` in `BatchedDecoding` by adding some helper methods
commit 3c5547b2b7
parent b38e3f6fe2
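In short: the special-token lookups move from static `NativeApi` calls onto `LLamaWeights` properties, and the sampling calls move onto `LLamaTokenDataArrayNative` methods. A condensed before/after sketch (`model`, `context`, and `candidates_native` are the variables from the example code in the hunks below):

    // Before: everything goes through static NativeApi functions
    var eos = NativeApi.llama_token_eos(model.NativeHandle);
    NativeApi.llama_sample_top_k(context.NativeHandle, ref candidates_native, 40, 1);

    // After: the same operations via the new helpers
    var eos = model.EndOfSentenceToken;
    candidates_native.TopK(context.NativeHandle, 40);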
@@ -11,6 +11,10 @@ public class BatchedDecoding
     private const int n_parallel = 8;
     private const int n_len = 32;
 
+    private const int top_k = 40;
+    private const float top_p = 0.9f;
+    private const float temp = 0.4f;
+
     public static async Task Run()
     {
         Console.Write("Please input your model path: ");
@@ -91,8 +95,8 @@ public class BatchedDecoding
         for (var i = 0; i < n_parallel; i++)
             streams[i] = new();
 
-        var eos = NativeApi.llama_token_eos(model.NativeHandle);
-        var nl = NativeApi.llama_token_nl(model.NativeHandle);
+        var eos = model.EndOfSentenceToken;
+        var nl = model.NewlineToken;
 
         var timer = new Stopwatch();
         timer.Start();
@@ -106,50 +110,34 @@ public class BatchedDecoding
            if (i_batch[i] < 0)
                continue;
 
+           var n_vocab = model.VocabCount;
+           LLamaTokenDataArray candidates;
            unsafe
            {
-               var n_vocab = model.VocabCount;
-               var logits = NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]);
-
-               var candidates = new LLamaTokenData[n_vocab];
-               for (var token_id = 0; token_id < n_vocab; token_id++)
-               {
-                   candidates[token_id] = new LLamaTokenData
-                   {
-                       id = token_id,
-                       logit = logits[token_id]
-                   };
-               }
-
-               var candidates_p = new LLamaTokenDataArray(candidates);
-               using var pin = LLamaTokenDataArrayNative.Create(candidates_p, out var candidates_native);
-
-               const int top_k = 40;
-               const float top_p = 0.9f;
-               const float temp = 0.4f;
-
-               NativeApi.llama_sample_top_k(context.NativeHandle, ref candidates_native, top_k, 1);
-               NativeApi.llama_sample_top_p(context.NativeHandle, ref candidates_native, top_p, 1);
-               NativeApi.llama_sample_temperature(context.NativeHandle, ref candidates_native, temp);
-
-               var new_token_id = NativeApi.llama_sample_token(context.NativeHandle, ref candidates_native);
-
-               if (new_token_id == eos || new_token_id == nl)
-               {
-                   i_batch[i] = -1;
-                   Console.WriteLine($"Completed Stream {i} early");
-                   continue;
-               }
-
-               streams[i].Add(new_token_id);
-
-               i_batch[i] = batch.NativeBatch.n_tokens;
-
-               // push this new token for next evaluation
-               llama_batch_add(batch, new_token_id, n_cur, new() { (LLamaSeqId)i }, true);
-
-               n_decode++;
+               candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab));
            }
+           using var pin = LLamaTokenDataArrayNative.Create(candidates, out var candidates_native);
+
+           candidates_native.TopK(context.NativeHandle, top_k);
+           candidates_native.TopP(context.NativeHandle, top_p);
+           candidates_native.Temperature(context.NativeHandle, temp);
+           var new_token_id = candidates_native.SampleToken(context.NativeHandle);
+
+           if (new_token_id == eos || new_token_id == nl)
+           {
+               i_batch[i] = -1;
+               Console.WriteLine($"Completed Stream {i} early");
+               continue;
+           }
+
+           streams[i].Add(new_token_id);
+
+           i_batch[i] = batch.NativeBatch.n_tokens;
+
+           // push this new token for next evaluation
+           llama_batch_add(batch, new_token_id, n_cur, new() { (LLamaSeqId)i }, true);
+
+           n_decode++;
        }
 
        // all streams are finished
@@ -38,6 +38,21 @@ namespace LLama
        /// </summary>
        public ulong ParameterCount => NativeHandle.ParameterCount;
 
+       /// <summary>
+       /// Get the newline token for this model
+       /// </summary>
+       public int NewlineToken => NativeApi.llama_token_nl(NativeHandle);
+
+       /// <summary>
+       /// Get the "end of sentence" token for this model
+       /// </summary>
+       public int EndOfSentenceToken => NativeApi.llama_token_eos(NativeHandle);
+
+       /// <summary>
+       /// Get the "beginning of sentence" token for this model
+       /// </summary>
+       public int BeginningOfSentenceToken => NativeApi.llama_token_bos(NativeHandle);
+
        /// <summary>
        /// Dimension of embedding vectors
        /// </summary>
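For reference, a minimal sketch of the new `LLamaWeights` properties in use (the model path and `ModelParams` setup are placeholders, not part of this commit):

    using var model = LLamaWeights.LoadFromFile(new ModelParams("path/to/model.gguf"));

    var bos = model.BeginningOfSentenceToken; // wraps NativeApi.llama_token_bos
    var eos = model.EndOfSentenceToken;       // wraps NativeApi.llama_token_eos
    var nl  = model.NewlineToken;             // wraps NativeApi.llama_token_nl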
@@ -96,5 +96,48 @@ namespace LLama.Native
 
            return handle;
        }
+
+       /// <summary>
+       /// Perform TopK sampling, sorting the data and reducing the size to k
+       /// </summary>
+       /// <param name="context"></param>
+       /// <param name="k">Number of tokens to keep</param>
+       /// <param name="minKeep">Minimum number to keep</param>
+       public void TopK(SafeLLamaContextHandle context, int k, ulong minKeep = 1)
+       {
+           NativeApi.llama_sample_top_k(context, ref this, k, minKeep);
+       }
+
+       /// <summary>
+       /// Perform top p sampling, sorting the data and keeping only the smallest set of tokens whose cumulative probability exceeds p
+       /// </summary>
+       /// <param name="context"></param>
+       /// <param name="p">Cumulative probability threshold</param>
+       /// <param name="minKeep">Minimum number to keep</param>
+       public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
+       {
+           NativeApi.llama_sample_top_p(context, ref this, p, minKeep);
+       }
+
+       /// <summary>
+       /// Apply temperature to logits
+       /// </summary>
+       /// <param name="context"></param>
+       /// <param name="temp"></param>
+       public void Temperature(SafeLLamaContextHandle context, float temp)
+       {
+           NativeApi.llama_sample_temperature(context, ref this, temp);
+       }
+
+       /// <summary>
+       /// Sample a token from the set of possible tokens
+       /// </summary>
+       /// <param name="context"></param>
+       /// <returns>The sampled token id</returns>
+       public int SampleToken(SafeLLamaContextHandle context)
+       {
+           return NativeApi.llama_sample_token(context, ref this);
+       }
    }
 }
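Taken together, these methods let the sampling chain from the example read fluently. A minimal sketch, assuming `ctx` is a `SafeLLamaContextHandle` and `logits` is a `Span<float>` over the vocabulary logits:

    var candidates = LLamaTokenDataArray.Create(logits);
    using var pin = LLamaTokenDataArrayNative.Create(candidates, out var candidates_native);

    candidates_native.TopK(ctx, 40);          // keep the 40 most likely tokens
    candidates_native.TopP(ctx, 0.9f);        // then cut to the 0.9 probability mass
    candidates_native.Temperature(ctx, 0.4f); // sharpen the remaining distribution
    var token = candidates_native.SampleToken(ctx);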