// LLamaSharp/LLama/LLamaExecutorBase.cs

using LLama.Abstractions;
using LLama.Common;
using LLama.Exceptions;
using LLama.Native;
using Microsoft.Extensions.Logging;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
namespace LLama
{
using llama_token = Int32;
/// <summary>
/// The base class for stateful LLama executors.
/// </summary>
public abstract class StatefulExecutorBase : ILLamaExecutor
{
/// <summary>
/// The logger used by this executor.
/// </summary>
protected ILogger? _logger;
/// <summary>
/// The tokens that were already processed by the model.
/// </summary>
protected int _pastTokensCount; // n_past
/// <summary>
/// The tokens that were consumed by the model during the current inference.
/// </summary>
protected int _consumedTokensCount; // n_consume
/// <summary>
/// The number of tokens from the loaded session that have been consumed so far. (n_session_consumed)
/// </summary>
protected int _n_session_consumed;
/// <summary>
/// The number of leading tokens of the loaded session that match the current input. (n_matching_session_tokens)
/// </summary>
protected int _n_matching_session_tokens;
/// <summary>
/// The path of the session file.
/// </summary>
protected string? _pathSession;
/// <summary>
/// The tokens to be processed in the current batch. (embd)
/// </summary>
protected List<llama_token> _embeds = new(); // embd
/// <summary>
/// The tokens of the input prompt. (embd_inp)
/// </summary>
protected List<llama_token> _embed_inps = new();
/// <summary>
/// The tokens loaded from, and to be saved to, the session file.
/// </summary>
protected List<llama_token> _session_tokens = new();
/// <summary>
/// The most recent tokens seen by the model, both input and generated. (last_n_tokens)
/// </summary>
protected FixedSizeQueue<llama_token> _last_n_tokens;
/// <summary>
/// The context used by the executor.
/// </summary>
public LLamaContext Context { get; }
/// <summary>
/// Current "mu" value for mirostat sampling
/// </summary>
protected float? MirostatMu { get; set; }
/// <summary>
/// Create a new stateful executor that runs inference on the given context.
/// </summary>
/// <param name="context">The context to run inference on.</param>
/// <param name="logger">The logger to use, or null to disable logging.</param>
protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null)
{
_logger = logger;
Context = context;
_pastTokensCount = 0;
_consumedTokensCount = 0;
_n_session_consumed = 0;
_last_n_tokens = new FixedSizeQueue<llama_token>(Context.ContextSize).FillWith(0);
}
/// <summary>
/// This API is currently not verified.
/// </summary>
/// <param name="filename"></param>
/// <returns></returns>
/// <exception cref="ArgumentNullException"></exception>
/// <exception cref="RuntimeError"></exception>
public unsafe StatefulExecutorBase WithSessionFile(string filename)
{
_pathSession = filename;
if (string.IsNullOrEmpty(filename))
{
throw new ArgumentNullException(nameof(filename), "File name cannot be empty.");
}
if (File.Exists(filename))
{
_logger?.LogInformation($"[LLamaExecutor] Attempting to load saved session from {filename}");
llama_token[] session_tokens = new llama_token[Context.ContextSize];
ulong n_token_count_out = 0;
if (!NativeApi.llama_load_session_file(Context.NativeHandle, _pathSession, session_tokens, (ulong)Context.ContextSize, &n_token_count_out))
{
_logger?.LogError($"[LLamaExecutor] Failed to load session file {filename}");
throw new RuntimeError($"Failed to load session file {_pathSession}");
}
_session_tokens = session_tokens.Take((int)n_token_count_out).ToList();
_logger?.LogInformation($"[LLamaExecutor] Loaded a session with prompt size of {session_tokens.Length} tokens");
}
else
{
_logger?.LogWarning("[LLamaExecutor] Session file does not exist, will create");
}
_n_matching_session_tokens = 0;
if (_session_tokens.Count > 0)
{
foreach (var id in _session_tokens)
{
if (_n_matching_session_tokens >= _embed_inps.Count || id != _embed_inps[_n_matching_session_tokens])
{
break;
}
_n_matching_session_tokens++;
}
if (_n_matching_session_tokens >= _embed_inps.Count)
{
_logger?.LogInformation("[LLamaExecutor] Session file has exact match for prompt!");
}
else if (_n_matching_session_tokens < _embed_inps.Count / 2)
{
_logger?.LogWarning($"[LLamaExecutor] Session file has low similarity to prompt ({_n_matching_session_tokens} / {_embed_inps.Count} tokens) will mostly be reevaluated");
}
else
{
_logger?.LogInformation($"[LLamaExecutor] Session file matches {_n_matching_session_tokens} / {_embed_inps.Count} tokens of prompt");
}
}
return this;
}
/// <summary>
/// This API is currently not verified.
/// </summary>
/// <param name="filename">The path to save the session file to.</param>
public void SaveSessionFile(string filename)
{
var session_token_array = _session_tokens.ToArray();
NativeApi.llama_save_session_file(Context.NativeHandle, filename, session_token_array, (ulong)session_token_array.Length);
}
/// <summary>
/// When the context runs out, keep the first <paramref name="tokensToKeep"/> tokens of the original prompt
/// and recompute the logits of half of the remaining tokens in batches.
/// </summary>
/// <param name="tokensToKeep">The number of tokens from the original prompt to keep.</param>
protected virtual void HandleRunOutOfContext(int tokensToKeep)
{
// if we run out of context:
// - take the tokensToKeep first tokens from the original prompt (via n_past)
// - take half of the last (n_ctx - tokensToKeep) tokens and recompute the logits in batches
int n_left = _pastTokensCount - tokensToKeep;
_pastTokensCount = Math.Max(1, tokensToKeep);
// insert n_left/2 tokens at the start of embed from last_n_tokens
_embeds.InsertRange(0, _last_n_tokens.Take(_last_n_tokens.Count - _embeds.Count).Skip(Context.ContextSize - n_left / 2 - _embeds.Count));
// stop saving session if we run out of context
_pathSession = string.Empty;
}
/// <summary>
/// Try to reuse the matching prefix from the session file.
/// </summary>
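/// <remarks>
/// Tokens at the front of <see cref="_embeds"/> that match the next unconsumed session tokens
/// are treated as already evaluated: they are counted into <see cref="_pastTokensCount"/> and
/// removed from the pending batch. The first mismatch truncates <see cref="_session_tokens"/>
/// at that position.
/// </remarks>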
protected virtual void TryReuseMathingPrefix()
{
if (_n_session_consumed < _session_tokens.Count)
{
int i = 0;
for (; i < _embeds.Count; i++)
{
if (_embeds[i] != _session_tokens[_n_session_consumed])
{
_session_tokens = _session_tokens.Take(_n_session_consumed).ToList();
break;
}
_pastTokensCount++;
_n_session_consumed++;
if (_n_session_consumed >= _session_tokens.Count)
{
i++;
break;
}
}
if (i > 0)
{
_embeds.RemoveRange(0, i);
}
}
}
/// <summary>
/// Decide whether to continue the loop.
/// </summary>
/// <param name="args"></param>
/// <returns></returns>
protected abstract Task<bool> GetLoopCondition(InferStateArgs args);
/// <summary>
/// Preprocess the inputs before the inference.
/// </summary>
/// <param name="text"></param>
/// <param name="args"></param>
protected abstract Task PreprocessInputs(string text, InferStateArgs args);
/// <summary>
/// Do some post processing after the inference.
/// </summary>
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
/// <returns></returns>
protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args);
/// <summary>
/// The core inference logic.
/// </summary>
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
/// <summary>
/// Save the current state to a file.
/// </summary>
/// <param name="filename"></param>
public abstract Task SaveState(string filename);
/// <summary>
/// Get the current state data.
/// </summary>
/// <returns></returns>
public abstract ExecutorBaseState GetStateData();
/// <summary>
/// Load the state from data.
/// </summary>
/// <param name="data"></param>
public abstract Task LoadState(ExecutorBaseState data);
/// <summary>
/// Load the state from a file.
/// </summary>
/// <param name="filename"></param>
public abstract Task LoadState(string filename);
/// <summary>
/// Execute the inference.
/// </summary>
/// <param name="text"></param>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
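/// <example>
/// Minimal usage sketch; it assumes <c>executor</c> is a concrete executor derived from this class:
/// <code>
/// var inferenceParams = new InferenceParams { MaxTokens = 128, AntiPrompts = new[] { "User:" } };
/// await foreach (var token in executor.InferAsync("Hello, ", inferenceParams))
/// {
///     Console.Write(token);
/// }
/// </code>
/// </example>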
public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();
inferenceParams ??= new InferenceParams();
var args = new InferStateArgs
{
Antiprompts = inferenceParams.AntiPrompts.ToList(),
RemainedTokens = inferenceParams.MaxTokens,
ReturnValue = false,
WaitForInput = false,
NeedToSaveSession = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count
};
await PreprocessInputs(text, args);
while (await GetLoopCondition(args))
{
if (cancellationToken.IsCancellationRequested)
{
break;
}
await InferInternal(inferenceParams, args);
if (args.ReturnValue)
yield return Context.DeTokenize(_embeds);
var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
if (extraOutputs is { Count: > 0 })
{
foreach (var item in extraOutputs)
{
yield return item;
}
}
if (breakGeneration)
{
break;
}
}
}
/// <summary>
/// State arguments that are used in single inference
/// </summary>
protected class InferStateArgs
{
/// <summary>
/// Strings that mark where generation should stop or hand control back to the user.
/// </summary>
public IList<string>? Antiprompts { get; set; }
/// <summary>
/// The number of tokens that may still be generated. (n_remain)
/// </summary>
public int RemainedTokens { get; set; }
/// <summary>
/// Whether the tokens decoded in the current step should be returned to the caller.
/// </summary>
public bool ReturnValue { get; set; }
/// <summary>
/// Whether the executor should pause and wait for further user input.
/// </summary>
public bool WaitForInput { get; set; }
/// <summary>
/// Whether the session file still needs to be saved.
/// </summary>
public bool NeedToSaveSession { get; set; }
}
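/// <summary>
/// The serializable state of a <see cref="StatefulExecutorBase"/>, produced by
/// <see cref="GetStateData"/> and consumed by <see cref="LoadState(ExecutorBaseState)"/>.
/// </summary>
/// <example>
/// Illustrative round-trip sketch using System.Text.Json; the concrete executors may persist
/// their state differently:
/// <code>
/// var json = System.Text.Json.JsonSerializer.Serialize(executor.GetStateData());
/// var state = System.Text.Json.JsonSerializer.Deserialize&lt;StatefulExecutorBase.ExecutorBaseState&gt;(json);
/// </code>
/// </example>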
public class ExecutorBaseState
{
[JsonPropertyName("n_past")]
public int PastTokensCount { get; set; }
[JsonPropertyName("n_consumed")]
public int ConsumedTokensCount { get; set; }
[JsonPropertyName("n_session_consumed")]
public int ConsumedSessionCount { get; set; }
[JsonPropertyName("n_matching_session_tokens")]
public int MatchingSessionTokensCount { get; set; }
[JsonPropertyName("path_session")]
public string? SessionFilePath { get; set; }
[JsonPropertyName("embd")]
public List<llama_token> Embeds { get; set; }
[JsonPropertyName("embd_inps")]
public List<llama_token> EmbedInps { get; set; }
[JsonPropertyName("session_tokens")]
public List<llama_token> SessionTokens { get; set; }
[JsonPropertyName("last_n_tokens")]
public llama_token[] LastTokens { get; set; }
[JsonPropertyName("last_tokens_maximum_count")]
public int LastTokensCapacity { get; set; }
[JsonPropertyName("mirostat_mu")]
public float? MirostatMu { get; set; }
}
}
}