Moved `Eval` out of `Utils` and into `SafeLLamaContextHandle`

Martin Evans 2023-08-07 15:10:47 +01:00
parent 7fabcc1849
commit 2b2d3af26b
5 changed files with 34 additions and 10 deletions
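
A rough sketch of the call-site change this commit makes, for orientation only; the variable names `ctx`, `tokens`, `n_past` and `threads` are placeholders and are not taken from the diff:

    // Old: static helper on Utils returning the raw llama.cpp status code (0 == success).
    if (Utils.Eval(ctx, tokens, 0, tokens.Length, n_past, threads) != 0)
        throw new RuntimeError("Failed to eval.");

    // New: instance method on SafeLLamaContextHandle returning a bool.
    if (!ctx.Eval(tokens.AsMemory(0, tokens.Length), n_past, threads))
        throw new RuntimeError("Failed to eval.");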

View File

@@ -335,7 +335,7 @@ namespace LLama
                     n_eval = Params.BatchSize;
                 }
-                if(Utils.Eval(_ctx, tokens, i, n_eval, pastTokensCount, Params.Threads) != 0)
+                if (!_ctx.Eval(tokens.AsMemory(i, n_eval), pastTokensCount, Params.Threads))
                 {
                     _logger?.Log(nameof(LLamaModel), "Failed to eval.", ILLamaLogger.LogLevel.Error);
                     throw new RuntimeError("Failed to eval.");

View File

@@ -31,7 +31,7 @@ namespace LLama
             _model = model;
             var tokens = model.Tokenize(" ", true).ToArray();
-            Utils.Eval(_model.NativeHandle, tokens, 0, tokens.Length, 0, _model.Params.Threads);
+            _model.NativeHandle.Eval(tokens.AsMemory(0, tokens.Length), 0, _model.Params.Threads);
             _originalState = model.GetState();
         }
@@ -52,7 +52,7 @@ namespace LLama
             List<llama_token> tokens = _model.Tokenize(text, true).ToList();
             int n_prompt_tokens = tokens.Count;
-            Utils.Eval(_model.NativeHandle, tokens.ToArray(), 0, n_prompt_tokens, n_past, _model.Params.Threads);
+            _model.NativeHandle.Eval(tokens.ToArray().AsMemory(0, n_prompt_tokens), n_past, _model.Params.Threads);
             lastTokens.AddRange(tokens);
             n_past += n_prompt_tokens;

View File

@@ -207,6 +207,17 @@ namespace LLama.Native
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern int llama_eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int n_tokens, int n_past, int n_threads);
+
+        /// <summary>
+        /// Run the llama inference to obtain the logits and probabilities for the next token.
+        /// tokens + n_tokens is the provided batch of new tokens to process
+        /// n_past is the number of tokens to use from previous eval calls
+        /// </summary>
+        /// <param name="ctx"></param>
+        /// <param name="tokens"></param>
+        /// <param name="n_tokens"></param>
+        /// <param name="n_past"></param>
+        /// <param name="n_threads"></param>
+        /// <returns>Returns 0 on success</returns>
+        [DllImport(libraryName, EntryPoint = "llama_eval", CallingConvention = CallingConvention.Cdecl)]
+        public static extern int llama_eval_with_pointer(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past, int n_threads);

View File

@@ -169,5 +169,21 @@ namespace LLama.Native
         {
             return ThrowIfDisposed().TokenToSpan(token);
         }
+
+        /// <summary>
+        /// Run the llama inference to obtain the logits and probabilities for the next token.
+        /// </summary>
+        /// <param name="tokens">The provided batch of new tokens to process</param>
+        /// <param name="n_past">the number of tokens to use from previous eval calls</param>
+        /// <param name="n_threads"></param>
+        /// <returns>Returns true on success</returns>
+        public bool Eval(Memory<int> tokens, int n_past, int n_threads)
+        {
+            using var pin = tokens.Pin();
+            unsafe
+            {
+                return NativeApi.llama_eval_with_pointer(this, (int*)pin.Pointer, tokens.Length, n_past, n_threads) == 0;
+            }
+        }
     }
 }

View File

@@ -43,14 +43,11 @@ namespace LLama
             return ctx.GetLogits();
         }
-        public static unsafe int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads)
+        [Obsolete("Use SafeLLamaContextHandle Eval method instead")]
+        public static int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads)
         {
-            int result;
-            fixed(llama_token* p = tokens)
-            {
-                result = NativeApi.llama_eval_with_pointer(ctx, p + startIndex, n_tokens, n_past, n_threads);
-            }
-            return result;
+            var slice = tokens.AsMemory().Slice(startIndex, n_tokens);
+            return ctx.Eval(slice, n_past, n_threads) ? 0 : 1;
         }
 
         [Obsolete("Use SafeLLamaContextHandle TokenToString method instead")]