Moved `Eval` out of `Utils` and into `SafeLLamaContextHandle`
This commit is contained in:
parent
7fabcc1849
commit
2b2d3af26b
|
@ -335,7 +335,7 @@ namespace LLama
|
|||
n_eval = Params.BatchSize;
|
||||
}
|
||||
|
||||
if(Utils.Eval(_ctx, tokens, i, n_eval, pastTokensCount, Params.Threads) != 0)
|
||||
if (!_ctx.Eval(tokens.AsMemory(i, n_eval), pastTokensCount, Params.Threads))
|
||||
{
|
||||
_logger?.Log(nameof(LLamaModel), "Failed to eval.", ILLamaLogger.LogLevel.Error);
|
||||
throw new RuntimeError("Failed to eval.");
|
||||
|
|
|
@ -31,7 +31,7 @@ namespace LLama
|
|||
_model = model;
|
||||
|
||||
var tokens = model.Tokenize(" ", true).ToArray();
|
||||
Utils.Eval(_model.NativeHandle, tokens, 0, tokens.Length, 0, _model.Params.Threads);
|
||||
_model.NativeHandle.Eval(tokens.AsMemory(0, tokens.Length), 0, _model.Params.Threads);
|
||||
_originalState = model.GetState();
|
||||
}
|
||||
|
||||
|
@ -52,7 +52,7 @@ namespace LLama
|
|||
List<llama_token> tokens = _model.Tokenize(text, true).ToList();
|
||||
int n_prompt_tokens = tokens.Count;
|
||||
|
||||
Utils.Eval(_model.NativeHandle, tokens.ToArray(), 0, n_prompt_tokens, n_past, _model.Params.Threads);
|
||||
_model.NativeHandle.Eval(tokens.ToArray().AsMemory(0, n_prompt_tokens), n_past, _model.Params.Threads);
|
||||
|
||||
lastTokens.AddRange(tokens);
|
||||
n_past += n_prompt_tokens;
|
||||
|
|
|
@ -207,6 +207,17 @@ namespace LLama.Native
|
|||
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
|
||||
public static extern int llama_eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int n_tokens, int n_past, int n_threads);
|
||||
|
||||
/// <summary>
|
||||
/// Run the llama inference to obtain the logits and probabilities for the next token.
|
||||
/// tokens + n_tokens is the provided batch of new tokens to process
|
||||
/// n_past is the number of tokens to use from previous eval calls
|
||||
/// </summary>
|
||||
/// <param name="ctx"></param>
|
||||
/// <param name="tokens"></param>
|
||||
/// <param name="n_tokens"></param>
|
||||
/// <param name="n_past"></param>
|
||||
/// <param name="n_threads"></param>
|
||||
/// <returns>Returns 0 on success</returns>
|
||||
[DllImport(libraryName, EntryPoint = "llama_eval", CallingConvention = CallingConvention.Cdecl)]
|
||||
public static extern int llama_eval_with_pointer(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past, int n_threads);
|
||||
|
||||
|
|
|
@ -169,5 +169,21 @@ namespace LLama.Native
|
|||
{
|
||||
return ThrowIfDisposed().TokenToSpan(token);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Run the llama inference to obtain the logits and probabilities for the next token.
|
||||
/// </summary>
|
||||
/// <param name="tokens">The provided batch of new tokens to process</param>
|
||||
/// <param name="n_past">the number of tokens to use from previous eval calls</param>
|
||||
/// <param name="n_threads"></param>
|
||||
/// <returns>Returns true on success</returns>
|
||||
public bool Eval(Memory<int> tokens, int n_past, int n_threads)
|
||||
{
|
||||
using var pin = tokens.Pin();
|
||||
unsafe
|
||||
{
|
||||
return NativeApi.llama_eval_with_pointer(this, (int*)pin.Pointer, tokens.Length, n_past, n_threads) == 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -43,14 +43,11 @@ namespace LLama
|
|||
return ctx.GetLogits();
|
||||
}
|
||||
|
||||
public static unsafe int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads)
|
||||
[Obsolete("Use SafeLLamaContextHandle Eval method instead")]
|
||||
public static int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads)
|
||||
{
|
||||
int result;
|
||||
fixed(llama_token* p = tokens)
|
||||
{
|
||||
result = NativeApi.llama_eval_with_pointer(ctx, p + startIndex, n_tokens, n_past, n_threads);
|
||||
}
|
||||
return result;
|
||||
var slice = tokens.AsMemory().Slice(startIndex, n_tokens);
|
||||
return ctx.Eval(slice, n_past, n_threads) ? 0 : 1;
|
||||
}
|
||||
|
||||
[Obsolete("Use SafeLLamaContextHandle TokenToString method instead")]
|
||||
|
|
Loading…
Reference in New Issue