- Added various convenience overloads to `LLamaContext.Eval`

- Converted `SafeLLamaContextHandle.Eval` to take a `ReadOnlySpan`; the narrower type better represents what is actually needed
Martin Evans 2023-08-18 00:39:25 +01:00
parent 4d0c044b9f
commit ae8ef17a4a
4 changed files with 76 additions and 7 deletions
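For orientation, a minimal usage sketch (not part of the commit) showing the call sites these overloads enable. The `Tokenize` helper is an assumption standing in for the usual LLamaSharp setup; `llama_token` is the library's alias for `int`.

using System;
using System.Collections.Generic;
using LLama;
using llama_token = System.Int32;

static class EvalOverloadsDemo
{
    // Hedged sketch: the same logical eval call now compiles for several
    // container types. `Tokenize` is assumed to exist on LLamaContext as in
    // the surrounding LLamaSharp code. Evaluating the same tokens four times
    // is pointless at runtime; this only demonstrates overload resolution.
    public static int EvalAllWays(LLamaContext ctx, string prompt)
    {
        llama_token[] tokens = ctx.Tokenize(prompt);
        int pastTokensCount = 0;

        pastTokensCount = ctx.Eval(tokens, pastTokensCount);                        // llama_token[]
        pastTokensCount = ctx.Eval(new List<llama_token>(tokens), pastTokensCount); // List<llama_token>
        pastTokensCount = ctx.Eval(tokens.AsMemory(), pastTokensCount);             // ReadOnlyMemory<llama_token>
        pastTokensCount = ctx.Eval(tokens.AsSpan(), pastTokensCount);               // ReadOnlySpan<llama_token>
        return pastTokensCount;
    }
}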


@@ -29,7 +29,7 @@ namespace LLama.Examples.NewVersion
Console.Write("\nQuestion: ");
Console.ForegroundColor = ConsoleColor.Green;
string prompt = Console.ReadLine();
Console.ForegroundColor = ConsoleColor.White;
Console.Write("Answer: ");
prompt = $"Question: {prompt.Trim()} Answer: ";
foreach (var text in ex.Infer(prompt, inferenceParams))


@@ -1,6 +1,7 @@
using LLama.Exceptions;
using LLama.Native;
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using System.Text;
@@ -384,6 +385,7 @@ namespace LLama
return candidates_p;
}
#region eval overloads
/// <summary>
/// Eval a batch of tokens with the model, updating the context state.
/// </summary>
@@ -391,7 +393,61 @@ namespace LLama
/// <param name="pastTokensCount">the number of tokens already evaluated in previous calls</param>
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError">Thrown if the native eval call fails</exception>
public int Eval(llama_token[] tokens, llama_token pastTokensCount)
{
return Eval(tokens.AsSpan(), pastTokensCount);
}
/// <summary>
/// Eval a list of tokens with the model, updating the context state.
/// </summary>
/// <param name="tokens">the tokens to evaluate</param>
/// <param name="pastTokensCount">the number of tokens already evaluated in previous calls</param>
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError">Thrown if the native eval call fails</exception>
public int Eval(List<llama_token> tokens, llama_token pastTokensCount)
{
#if NET5_0_OR_GREATER
var span = CollectionsMarshal.AsSpan(tokens);
return Eval(span, pastTokensCount);
#else
// On netstandard2.0 we can't use CollectionsMarshal to get directly at the internal memory of
// the list. Instead, rent an array and copy the data into it. This avoids an allocation, but
// cannot avoid the copy.
var rented = ArrayPool<llama_token>.Shared.Rent(tokens.Count);
try
{
tokens.CopyTo(rented, 0);
// note: the rented array may be longer than requested, so slice to the real token count
return Eval(rented.AsSpan(0, tokens.Count), pastTokensCount);
}
finally
{
ArrayPool<llama_token>.Shared.Return(rented);
}
#endif
}
/// <summary>
/// Eval a memory of tokens with the model, updating the context state.
/// </summary>
/// <param name="tokens">the tokens to evaluate</param>
/// <param name="pastTokensCount">the number of tokens already evaluated in previous calls</param>
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError">Thrown if the native eval call fails</exception>
public int Eval(ReadOnlyMemory<llama_token> tokens, llama_token pastTokensCount)
{
return Eval(tokens.Span, pastTokensCount);
}
/// <summary>
/// Eval a span of tokens with the model, updating the context state.
/// Tokens are processed in batches of at most `Params.BatchSize`.
/// </summary>
/// <param name="tokens">the tokens to evaluate</param>
/// <param name="pastTokensCount">the number of tokens already evaluated in previous calls</param>
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError">Thrown if the native eval call fails</exception>
public int Eval(ReadOnlySpan<llama_token> tokens, llama_token pastTokensCount)
{
int total = tokens.Length;
for(int i = 0; i < total; i += Params.BatchSize)
@@ -402,7 +458,7 @@ namespace LLama
n_eval = Params.BatchSize;
}
if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount, Params.Threads))
{
_logger?.Log(nameof(LLamaContext), "Failed to eval.", ILLamaLogger.LogLevel.Error);
throw new RuntimeError("Failed to eval.");
@@ -412,6 +468,7 @@ namespace LLama
}
return pastTokensCount;
}
#endregion
internal IEnumerable<string> GenerateResult(IEnumerable<llama_token> ids)
{
@@ -419,6 +476,16 @@ namespace LLama
yield return _ctx.TokenToString(id, _encoding);
}
/// <summary>
/// Convert a token into a string
/// </summary>
/// <param name="token"></param>
/// <returns></returns>
public string TokenToString(llama_token token)
{
return NativeHandle.TokenToString(token, Encoding);
}
/// <inheritdoc />
public virtual void Dispose()
{

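A note on the `List<llama_token>` overload above: on .NET 5+ `CollectionsMarshal.AsSpan` hands back a span over the list's internal backing array, so the eval call is zero-copy. The standalone sketch below (not part of the commit) illustrates that aliasing; the usual caveat applies that the span is invalidated if the list is resized afterwards.

using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;

static class AsSpanDemo
{
    static void Main()
    {
        var list = new List<int> { 1, 2, 3 };

        // The span aliases the list's backing array - no copy is made.
        Span<int> span = CollectionsMarshal.AsSpan(list);
        span[0] = 42;

        Console.WriteLine(list[0]); // prints 42: the write is visible through the list

        // Caveat: growing the list may reallocate its backing array,
        // leaving `span` pointing at the old, stale storage.
        list.Add(4);
    }
}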

@@ -179,12 +179,14 @@ namespace LLama.Native
/// <param name="n_past">the number of tokens to use from previous eval calls</param>
/// <param name="n_threads"></param>
/// <returns>Returns true on success</returns>
public bool Eval(ReadOnlySpan<int> tokens, int n_past, int n_threads)
{
    unsafe
    {
        // pin the span's memory for the duration of the native call
        fixed (int* pinned = tokens)
        {
            return NativeApi.llama_eval_with_pointer(this, pinned, tokens.Length, n_past, n_threads) == 0;
        }
    }
}
}
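The switch from `Memory<int>` plus `Pin()` to `ReadOnlySpan<int>` plus `fixed` drops the `MemoryHandle` and also widens what callers can pass, since a span can wrap stack memory that a `Memory<T>` cannot. A small standalone sketch (not part of the commit; requires `<AllowUnsafeBlocks>` in the project file):

using System;

static class FixedSpanDemo
{
    static unsafe void Main()
    {
        // A span can wrap stack memory, which Memory<T> cannot represent.
        Span<int> tokens = stackalloc int[] { 1, 2, 3, 4 };

        // `fixed` pins only for the duration of this block and allocates nothing;
        // the compiler calls the span's GetPinnableReference under the hood.
        fixed (int* ptr = tokens)
        {
            for (var i = 0; i < tokens.Length; i++)
                Console.Write($"{ptr[i]} ");
        }
    }
}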


@@ -37,7 +37,7 @@ namespace LLama
[Obsolete("Use SafeLLamaContextHandle Eval method instead")]
public static int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads)
{
var slice = tokens.AsSpan().Slice(startIndex, n_tokens);
return ctx.Eval(slice, n_past, n_threads) ? 0 : 1;
}
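For existing callers of the obsolete shim, migration is mechanical; a hedged before/after sketch (variable names are assumptions):

using System;
using LLama.Native;

static class EvalMigration
{
    public static void Example(SafeLLamaContextHandle ctx, int[] tokens, int n_past, int n_threads)
    {
        // Before (obsolete shim): returned 0 on success, 1 on failure.
        // var rc = Utils.Eval(ctx, tokens, 0, tokens.Length, n_past, n_threads);

        // After: slice with AsSpan and call the handle directly; returns true on success.
        bool ok = ctx.Eval(tokens.AsSpan(0, tokens.Length), n_past, n_threads);
        if (!ok)
            throw new InvalidOperationException("Failed to eval.");
    }
}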