- Added various convenience overloads to `LLamaContext.Eval`
- Converted `SafeLLamaContextHandle.Eval` to take a `ReadOnlySpan`; the narrower type better represents what's actually needed
parent 4d0c044b9f
commit ae8ef17a4a

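As quick orientation before the diffs: a minimal sketch (mine, not part of the commit) of how the new overload surface might be called. It assumes an already-constructed `LLamaContext` named `context`, a hypothetical `FeedTokens` caller, and treats `llama_token` as the library's usual `int` alias.

```csharp
using System;
using System.Collections.Generic;
using LLama;

static class EvalOverloadSketch
{
    // Exercises each new overload once; `context` is assumed to be set up
    // elsewhere, and llama_token is treated as its int alias.
    static int FeedTokens(LLamaContext context, int[] tokens)
    {
        var past = 0;

        // Array overload: forwards to the span overload via AsSpan().
        past = context.Eval(tokens, past);

        // List overload: reads the list's backing array on net5.0+,
        // rents-and-copies on netstandard2.0.
        past = context.Eval(new List<int>(tokens), past);

        // Memory overload: forwards to the span overload via .Span.
        past = context.Eval(tokens.AsMemory(), past);

        // Span overload: the one that does the real batched work.
        past = context.Eval(tokens.AsSpan(), past);

        return past;
    }
}
```

The first three overloads all funnel into the `ReadOnlySpan` version, which owns the batching loop.
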
@@ -29,7 +29,7 @@ namespace LLama.Examples.NewVersion
                 Console.Write("\nQuestion: ");
                 Console.ForegroundColor = ConsoleColor.Green;
                 string prompt = Console.ReadLine();
-                Console.ForegroundColor = ConsoleColor.White;
+                Console.ForegroundColor = ConsoleColor.White;
                 Console.Write("Answer: ");
                 prompt = $"Question: {prompt.Trim()} Answer: ";
                 foreach (var text in ex.Infer(prompt, inferenceParams))

@@ -1,6 +1,7 @@
 using LLama.Exceptions;
 using LLama.Native;
 using System;
+using System.Buffers;
 using System.Collections.Generic;
 using System.Linq;
 using System.Text;

@@ -384,6 +385,7 @@ namespace LLama
             return candidates_p;
         }

+        #region eval overloads
         /// <summary>
         ///
         /// </summary>

@@ -391,7 +393,61 @@ namespace LLama
         /// <param name="pastTokensCount"></param>
         /// <returns>The updated `pastTokensCount`.</returns>
         /// <exception cref="RuntimeError"></exception>
-        public llama_token Eval(llama_token[] tokens, llama_token pastTokensCount)
+        public int Eval(llama_token[] tokens, llama_token pastTokensCount)
+        {
+            return Eval(tokens.AsSpan(), pastTokensCount);
+        }
+
+        /// <summary>
+        ///
+        /// </summary>
+        /// <param name="tokens"></param>
+        /// <param name="pastTokensCount"></param>
+        /// <returns>The updated `pastTokensCount`.</returns>
+        /// <exception cref="RuntimeError"></exception>
+        public int Eval(List<llama_token> tokens, llama_token pastTokensCount)
+        {
+#if NET5_0_OR_GREATER
+            var span = CollectionsMarshal.AsSpan(tokens);
+            return Eval(span, pastTokensCount);
+#else
+            // On netstandard2.0 we can't use CollectionsMarshal to get directly at the internal
+            // memory of the list. Instead rent an array and copy the data into it. This avoids an
+            // allocation, but can't avoid the copying.

+            var rented = ArrayPool<llama_token>.Shared.Rent(tokens.Count);
+            try
+            {
+                tokens.CopyTo(rented, 0);
+                return Eval(rented.AsSpan(0, tokens.Count), pastTokensCount);
+            }
+            finally
+            {
+                ArrayPool<llama_token>.Shared.Return(rented);
+            }
+#endif
+        }
+
+        /// <summary>
+        ///
+        /// </summary>
+        /// <param name="tokens"></param>
+        /// <param name="pastTokensCount"></param>
+        /// <returns>The updated `pastTokensCount`.</returns>
+        /// <exception cref="RuntimeError"></exception>
+        public int Eval(ReadOnlyMemory<llama_token> tokens, llama_token pastTokensCount)
+        {
+            return Eval(tokens.Span, pastTokensCount);
+        }
+
+        /// <summary>
+        ///
+        /// </summary>
+        /// <param name="tokens"></param>
+        /// <param name="pastTokensCount"></param>
+        /// <returns>The updated `pastTokensCount`.</returns>
+        /// <exception cref="RuntimeError"></exception>
+        public int Eval(ReadOnlySpan<llama_token> tokens, llama_token pastTokensCount)
         {
             int total = tokens.Length;
             for(int i = 0; i < total; i += Params.BatchSize)

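One detail in the netstandard2.0 branch above: `ArrayPool<T>.Shared.Rent` may hand back an array longer than requested, so the copied data must be sliced back to `tokens.Count` before it is consumed. A self-contained sketch of the same fallback pattern (my example, with a trivial `Sum` standing in for `Eval`):

```csharp
using System;
using System.Buffers;
using System.Collections.Generic;
#if NET5_0_OR_GREATER
using System.Runtime.InteropServices;
#endif

static class ListSpanSketch
{
    // Process a List<int> as a span without allocating a fresh array.
    static long Sum(List<int> tokens)
    {
#if NET5_0_OR_GREATER
        // View the list's backing array directly; valid only while the
        // list is not resized.
        return Sum(CollectionsMarshal.AsSpan(tokens));
#else
        // Fallback: rent a buffer (possibly longer than tokens.Count),
        // copy in, slice to the real length, and always return the buffer.
        var rented = ArrayPool<int>.Shared.Rent(tokens.Count);
        try
        {
            tokens.CopyTo(rented, 0);
            return Sum(rented.AsSpan(0, tokens.Count));
        }
        finally
        {
            ArrayPool<int>.Shared.Return(rented);
        }
#endif
    }

    static long Sum(ReadOnlySpan<int> tokens)
    {
        long total = 0;
        foreach (var t in tokens)
            total += t;
        return total;
    }
}
```
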
@@ -402,7 +458,7 @@ namespace LLama
                     n_eval = Params.BatchSize;
                 }

-                if (!_ctx.Eval(tokens.AsMemory(i, n_eval), pastTokensCount, Params.Threads))
+                if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount, Params.Threads))
                 {
                     _logger?.Log(nameof(LLamaContext), "Failed to eval.", ILLamaLogger.LogLevel.Error);
                     throw new RuntimeError("Failed to eval.");

@@ -412,6 +468,7 @@ namespace LLama
             }
             return pastTokensCount;
         }
+        #endregion

         internal IEnumerable<string> GenerateResult(IEnumerable<llama_token> ids)
         {

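The eval loop, spread across the two hunks above, walks the token span in `Params.BatchSize` chunks and clamps the final chunk. A standalone sketch of that slicing pattern (my example; the `Math.Min` clamp is equivalent to the diff's `if`/assignment):

```csharp
using System;

static class BatchSketch
{
    // Visit a token span in fixed-size batches, clamping the last one,
    // the same shape as the Eval loop in the hunks above.
    static void ProcessInBatches(ReadOnlySpan<int> tokens, int batchSize)
    {
        for (var i = 0; i < tokens.Length; i += batchSize)
        {
            var n_eval = Math.Min(batchSize, tokens.Length - i);
            var batch = tokens.Slice(i, n_eval);
            Console.WriteLine($"batch at {i}: {batch.Length} tokens");
        }
    }
}
```
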
@@ -419,6 +476,16 @@ namespace LLama
                 yield return _ctx.TokenToString(id, _encoding);
         }

+        /// <summary>
+        /// Convert a token into a string
+        /// </summary>
+        /// <param name="token"></param>
+        /// <returns></returns>
+        public string TokenToString(llama_token token)
+        {
+            return NativeHandle.TokenToString(token, Encoding);
+        }
+
         /// <inheritdoc />
         public virtual void Dispose()
         {

@@ -179,12 +179,14 @@ namespace LLama.Native
         /// <param name="n_past">the number of tokens to use from previous eval calls</param>
         /// <param name="n_threads"></param>
         /// <returns>Returns true on success</returns>
-        public bool Eval(Memory<int> tokens, int n_past, int n_threads)
+        public bool Eval(ReadOnlySpan<int> tokens, int n_past, int n_threads)
         {
-            using var pin = tokens.Pin();
             unsafe
             {
-                return NativeApi.llama_eval_with_pointer(this, (int*)pin.Pointer, tokens.Length, n_past, n_threads) == 0;
+                fixed (int* pinned = tokens)
+                {
+                    return NativeApi.llama_eval_with_pointer(this, pinned, tokens.Length, n_past, n_threads) == 0;
+                }
             }
         }

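For reference, a side-by-side sketch (not from the commit) of the two pinning styles this hunk switches between. `Consume` is a hypothetical stand-in for `NativeApi.llama_eval_with_pointer`, and both methods need the project's unsafe blocks enabled:

```csharp
using System;
using System.Buffers;

static class PinningSketch
{
    // Hypothetical native-call stand-in.
    static unsafe int Consume(int* ptr, int length) => length;

    // New style: `fixed` pins the span's memory for the block's duration,
    // with no handle object to dispose. (An empty span pins to null.)
    static unsafe int EvalSpan(ReadOnlySpan<int> tokens)
    {
        fixed (int* ptr = tokens)
        {
            return Consume(ptr, tokens.Length);
        }
    }

    // Old style: Memory<T>.Pin() returns a MemoryHandle that must be
    // disposed to unpin the memory.
    static unsafe int EvalMemory(ReadOnlyMemory<int> tokens)
    {
        using MemoryHandle pin = tokens.Pin();
        return Consume((int*)pin.Pointer, tokens.Length);
    }
}
```

Beyond the pinning mechanics, `ReadOnlySpan` is the narrower contract: it promises the callee will neither write to nor hold onto the tokens, which `Memory<int>` could not express.
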
@@ -37,7 +37,7 @@ namespace LLama
         [Obsolete("Use SafeLLamaContextHandle Eval method instead")]
         public static int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads)
         {
-            var slice = tokens.AsMemory().Slice(startIndex, n_tokens);
+            var slice = tokens.AsSpan().Slice(startIndex, n_tokens);
             return ctx.Eval(slice, n_past, n_threads) ? 0 : 1;
         }

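Callers migrating off this obsolete helper can fold its 0/1 convention into the bool directly. A hypothetical before/after, assuming the enclosing static class is the library's `Utils` helper, a `SafeLLamaContextHandle` named `ctx`, and a tokenized prompt in `tokens`:

```csharp
using System;
using LLama.Native;

static class MigrationSketch
{
    static bool FeedTokens(SafeLLamaContextHandle ctx, int[] tokens)
    {
        // Before (obsolete shim): 0 meant success, mirroring raw llama.cpp.
        //   var code = Utils.Eval(ctx, tokens, 0, tokens.Length, n_past: 0, n_threads: 4);
        //   return code == 0;

        // After: call the handle directly; the bool already encodes success.
        return ctx.Eval(tokens.AsSpan(), n_past: 0, n_threads: 4);
    }
}
```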