using LLama.Abstractions;
using LLama.Common;
using LLama.Exceptions;
using LLama.Native;
using Microsoft.Extensions.Logging;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
namespace LLama
{
using llama_token = Int32;
///
/// The base class for stateful LLama executors.
///
/// <summary>
/// The base class for stateful LLama executors.
/// </summary>
public abstract class StatefulExecutorBase : ILLamaExecutor
{
    /// <summary>
    /// The logger used by this executor.
    /// </summary>
    protected ILogger? _logger;

    /// <summary>
    /// The tokens that were already processed by the model.
    /// </summary>
    protected int _pastTokensCount; // n_past

    /// <summary>
    /// The tokens that were consumed by the model during the current inference.
    /// </summary>
    protected int _consumedTokensCount; // n_consume

    /// <summary>
    /// Number of tokens of <see cref="_session_tokens"/> that have been matched against
    /// <see cref="_embeds"/> and reused by <see cref="TryReuseMathingPrefix"/> (n_session_consumed).
    /// </summary>
    protected int _n_session_consumed;

    /// <summary>
    /// Length of the common prefix of the loaded session tokens and <see cref="_embed_inps"/>,
    /// computed by <see cref="WithSessionFile"/> (n_matching_session_tokens).
    /// </summary>
    protected int _n_matching_session_tokens;

    /// <summary>
    /// The path of the session file. Cleared by <see cref="HandleRunOutOfContext"/> to stop session saving.
    /// </summary>
    protected string? _pathSession;

    /// <summary>
    /// A container of the tokens to be processed and after processed (embd).
    /// </summary>
    protected List<llama_token> _embeds = new(); // embd

    /// <summary>
    /// A container for the tokens of input.
    /// </summary>
    protected List<llama_token> _embed_inps = new();

    /// <summary>
    /// Tokens loaded from the session file by <see cref="WithSessionFile"/>; truncated by
    /// <see cref="TryReuseMathingPrefix"/> on the first mismatch and written by <see cref="SaveSessionFile"/>.
    /// </summary>
    protected List<llama_token> _session_tokens = new();

    /// <summary>
    /// The last tokens generated by the model. Capacity is <see cref="LLamaContext.ContextSize"/>.
    /// </summary>
    protected FixedSizeQueue<llama_token> _last_n_tokens;

    /// <summary>
    /// The context used by the executor.
    /// </summary>
    public LLamaContext Context { get; }

    /// <summary>
    /// Current "mu" value for mirostat sampling
    /// </summary>
    protected float? MirostatMu { get; set; }

    // Only assigned in the constructor.
    private readonly StreamingTokenDecoder _decoder;

    /// <summary>
    /// Initialize the executor state for the given context.
    /// </summary>
    /// <param name="context">The context that this executor runs inference on.</param>
    /// <param name="logger">Optional logger; when null, nothing is logged.</param>
    protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null)
    {
        _logger = logger;
        Context = context;
        _pastTokensCount = 0;
        _consumedTokensCount = 0;
        _n_session_consumed = 0;
        _last_n_tokens = new FixedSizeQueue<llama_token>(Context.ContextSize);
        _decoder = new StreamingTokenDecoder(context);
    }

    /// <summary>
    /// Attach a session file to this executor. This API is currently not verified.
    /// </summary>
    /// <remarks>
    /// If <paramref name="filename"/> exists, its tokens are loaded into <see cref="_session_tokens"/> and
    /// compared with <see cref="_embed_inps"/> to compute <see cref="_n_matching_session_tokens"/>.
    /// If it does not exist, only the path is recorded and a warning is logged.
    /// </remarks>
    /// <param name="filename">Path of the session file. Must not be null or empty.</param>
    /// <returns>This executor, to allow call chaining.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="filename"/> is null or empty.</exception>
    /// <exception cref="RuntimeError">The native loader reported a failure reading the file.</exception>
    public unsafe StatefulExecutorBase WithSessionFile(string filename)
    {
        // Validate before mutating any state so a rejected argument leaves the executor unchanged.
        if (string.IsNullOrEmpty(filename))
        {
            throw new ArgumentNullException(nameof(filename), "File name cannot be empty.");
        }

        _pathSession = filename;

        if (File.Exists(filename))
        {
            _logger?.LogInformation("[LLamaExecutor] Attempting to load saved session from {Filename}", filename);
            var sessionTokens = new llama_token[Context.ContextSize];
            ulong loadedTokenCount = 0;
            if (!NativeApi.llama_load_session_file(Context.NativeHandle, filename, sessionTokens, (ulong)Context.ContextSize, &loadedTokenCount))
            {
                _logger?.LogError("[LLamaExecutor] Failed to load session file {Filename}", filename);
                throw new RuntimeError($"Failed to load session file {filename}");
            }

            _session_tokens = sessionTokens.Take((int)loadedTokenCount).ToList();
            // Report the number of tokens actually loaded, not the size of the scratch buffer.
            _logger?.LogInformation("[LLamaExecutor] Loaded a session with prompt size of {TokenCount} tokens", _session_tokens.Count);
        }
        else
        {
            _logger?.LogWarning("[LLamaExecutor] Session file does not exist, will create");
        }

        // Count how many leading session tokens coincide with the prompt tokens.
        _n_matching_session_tokens = 0;
        if (_session_tokens.Count > 0)
        {
            foreach (var id in _session_tokens)
            {
                if (_n_matching_session_tokens >= _embed_inps.Count || id != _embed_inps[_n_matching_session_tokens])
                {
                    break;
                }
                _n_matching_session_tokens++;
            }

            if (_n_matching_session_tokens >= _embed_inps.Count)
            {
                _logger?.LogInformation("[LLamaExecutor] Session file has exact match for prompt!");
            }
            else if (_n_matching_session_tokens < _embed_inps.Count / 2)
            {
                _logger?.LogWarning(
                    "[LLamaExecutor] Session file has low similarity to prompt ({MatchingTokens} / {PromptTokens} tokens) will mostly be reevaluated",
                    _n_matching_session_tokens, _embed_inps.Count);
            }
            else
            {
                _logger?.LogInformation(
                    "[LLamaExecutor] Session file matches {MatchingTokens} / {PromptTokens} tokens of prompt",
                    _n_matching_session_tokens, _embed_inps.Count);
            }
        }

        return this;
    }

    /// <summary>
    /// Save <see cref="_session_tokens"/> to a session file via the native session saver.
    /// This API has not been verified currently.
    /// </summary>
    /// <param name="filename">Destination path of the session file.</param>
    /// <remarks>
    /// A failure reported by the native call is logged as an error rather than thrown. This keeps
    /// session saving best-effort while no longer letting a failed save go unnoticed.
    /// </remarks>
    public void SaveSessionFile(string filename)
    {
        var sessionTokenArray = _session_tokens.ToArray();
        if (!NativeApi.llama_save_session_file(Context.NativeHandle, filename, sessionTokenArray, (ulong)sessionTokenArray.Length))
        {
            _logger?.LogError("[LLamaExecutor] Failed to save session file {Filename}", filename);
        }
    }

    /// <summary>
    /// After running out of the context, take some tokens from the original prompt and recompute the logits in batches.
    /// </summary>
    /// <param name="tokensToKeep">Number of leading tokens to keep; <see cref="_pastTokensCount"/> is reset to at least 1 and at most this value.</param>
    protected virtual void HandleRunOutOfContext(int tokensToKeep)
    {
        // if we run out of context:
        // - take the tokensToKeep first tokens from the original prompt (via n_past)
        // - take half of the last (n_ctx - tokensToKeep) tokens and recompute the logits in batches
        int n_left = _pastTokensCount - tokensToKeep;
        _pastTokensCount = Math.Max(1, tokensToKeep);

        // Insert n_left/2 tokens at the start of _embeds, taken from _last_n_tokens immediately before
        // its final _embeds.Count entries (presumably the tokens already queued in _embeds).
        // Both ends are derived from the queue's actual Count, so the window stays correct even when the
        // queue is not full. When full (Count == ContextSize) this selects exactly the same tokens as
        // Take(Count - E).Skip(ContextSize - n_left / 2 - E).
        int end = _last_n_tokens.Count - _embeds.Count;
        int start = Math.Max(0, end - n_left / 2);
        _embeds.InsertRange(0, _last_n_tokens.Skip(start).Take(end - start));

        // stop saving session if we run out of context
        _pathSession = string.Empty;
    }

    /// <summary>
    /// Try to reuse the matching prefix from the session file.
    /// </summary>
    /// <remarks>
    /// Leading tokens of <see cref="_embeds"/> that equal the next unconsumed session tokens are removed
    /// from <see cref="_embeds"/>, advancing <see cref="_pastTokensCount"/> and <see cref="_n_session_consumed"/>
    /// once per match. On the first mismatch, <see cref="_session_tokens"/> is truncated to the consumed prefix.
    /// The misspelled name is kept for compatibility with existing overrides.
    /// </remarks>
    protected virtual void TryReuseMathingPrefix()
    {
        if (_n_session_consumed < _session_tokens.Count)
        {
            int i = 0;
            for (; i < _embeds.Count; i++)
            {
                if (_embeds[i] != _session_tokens[_n_session_consumed])
                {
                    // Discard the diverging tail; GetRange returns a new list like Take().ToList().
                    _session_tokens = _session_tokens.GetRange(0, _n_session_consumed);
                    break;
                }

                _pastTokensCount++;
                _n_session_consumed++;

                if (_n_session_consumed >= _session_tokens.Count)
                {
                    // Count the token just matched before leaving the loop.
                    i++;
                    break;
                }
            }

            if (i > 0)
            {
                _embeds.RemoveRange(0, i);
            }
        }
    }

    /// <summary>
    /// Decide whether to continue the loop.
    /// </summary>
    /// <param name="args">The per-inference state.</param>
    /// <returns>True to run another iteration of <see cref="InferAsync"/>.</returns>
    protected abstract Task<bool> GetLoopCondition(InferStateArgs args);

    /// <summary>
    /// Preprocess the inputs before the inference.
    /// </summary>
    /// <param name="text">The input text passed to <see cref="InferAsync"/>.</param>
    /// <param name="args">The per-inference state.</param>
    protected abstract Task PreprocessInputs(string text, InferStateArgs args);

    /// <summary>
    /// Do some post processing after the inference.
    /// </summary>
    /// <param name="inferenceParams">The inference parameters.</param>
    /// <param name="args">The per-inference state.</param>
    /// <returns>
    /// A flag that stops generation when true, and extra text to yield (may be empty).
    /// </returns>
    protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args);

    /// <summary>
    /// The core inference logic.
    /// </summary>
    /// <param name="inferenceParams">The inference parameters.</param>
    /// <param name="args">The per-inference state.</param>
    protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args);

    /// <summary>
    /// Save the current state to a file.
    /// </summary>
    /// <param name="filename">Destination file path.</param>
    public abstract Task SaveState(string filename);

    /// <summary>
    /// Get the current state data.
    /// </summary>
    /// <returns>A snapshot of the executor state.</returns>
    public abstract ExecutorBaseState GetStateData();

    /// <summary>
    /// Load the state from data.
    /// </summary>
    /// <param name="data">State previously produced by <see cref="GetStateData"/>.</param>
    public abstract Task LoadState(ExecutorBaseState data);

    /// <summary>
    /// Load the state from a file.
    /// </summary>
    /// <param name="filename">Source file path.</param>
    public abstract Task LoadState(string filename);

    /// <summary>
    /// Execute the inference.
    /// </summary>
    /// <param name="text">The prompt text.</param>
    /// <param name="inferenceParams">Inference parameters; a default <see cref="InferenceParams"/> is used when null.</param>
    /// <param name="cancellationToken">
    /// Checked before starting (throws if already cancelled) and at the top of each loop iteration
    /// (stops the enumeration without throwing).
    /// </param>
    /// <returns>The generated text, streamed piece by piece.</returns>
    public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        cancellationToken.ThrowIfCancellationRequested();
        inferenceParams ??= new InferenceParams();

        var args = new InferStateArgs
        {
            Antiprompts = inferenceParams.AntiPrompts.ToList(),
            RemainedTokens = inferenceParams.MaxTokens,
            ReturnValue = false,
            WaitForInput = false,
            NeedToSaveSession = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count
        };

        await PreprocessInputs(text, args);

        while (await GetLoopCondition(args))
        {
            if (cancellationToken.IsCancellationRequested)
            {
                break;
            }

            await InferInternal(inferenceParams, args);

            if (args.ReturnValue)
            {
                _decoder.AddRange(_embeds);
                yield return _decoder.Read();
            }

            var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
            if (extraOutputs is { Count: > 0 })
            {
                foreach (var item in extraOutputs)
                {
                    yield return item;
                }
            }

            if (breakGeneration)
            {
                break;
            }
        }
    }

    /// <summary>
    /// State arguments that are used in single inference
    /// </summary>
    protected class InferStateArgs
    {
        /// <summary>
        /// Anti-prompts for this inference; <see cref="InferAsync"/> copies them from the inference parameters.
        /// </summary>
        public IList<string>? Antiprompts { get; set; }

        /// <summary>
        /// Tokens count remained to be used. (n_remain)
        /// </summary>
        public int RemainedTokens { get; set; }

        /// <summary>
        /// When true after <see cref="InferInternal"/>, <see cref="InferAsync"/> decodes the current
        /// embeddings and yields the resulting text.
        /// </summary>
        public bool ReturnValue { get; set; }

        /// <summary>
        /// Set to false at the start of <see cref="InferAsync"/>; its meaning is defined by derived executors
        /// (presumably "waiting for user input" — TODO confirm).
        /// </summary>
        public bool WaitForInput { get; set; }

        /// <summary>
        /// Initialised by <see cref="InferAsync"/> to true when a session path is set and the session file
        /// did not match the whole prompt.
        /// </summary>
        public bool NeedToSaveSession { get; set; }
    }

    /// <summary>
    /// Serializable snapshot of the executor state. JSON property names follow the short names noted on the
    /// corresponding executor fields (e.g. <c>n_past</c>, <c>embd</c>).
    /// </summary>
    public class ExecutorBaseState
    {
        /// <summary>Number of tokens already processed by the model (n_past).</summary>
        [JsonPropertyName("n_past")]
        public int PastTokensCount { get; set; }

        /// <summary>Number of input tokens consumed (n_consumed).</summary>
        [JsonPropertyName("n_consumed")]
        public int ConsumedTokensCount { get; set; }

        /// <summary>Number of session tokens consumed (n_session_consumed).</summary>
        [JsonPropertyName("n_session_consumed")]
        public int ConsumedSessionCount { get; set; }

        /// <summary>Length of the prefix shared by the session tokens and the prompt.</summary>
        [JsonPropertyName("n_matching_session_tokens")]
        public int MatchingSessionTokensCount { get; set; }

        /// <summary>Path of the session file, if any.</summary>
        [JsonPropertyName("path_session")]
        public string? SessionFilePath { get; set; }

        /// <summary>Pending tokens (embd).</summary>
        [JsonPropertyName("embd")]
        public List<llama_token> Embeds { get; set; }

        /// <summary>Input tokens.</summary>
        [JsonPropertyName("embd_inps")]
        public List<llama_token> EmbedInps { get; set; }

        /// <summary>Tokens of the session file.</summary>
        [JsonPropertyName("session_tokens")]
        public List<llama_token> SessionTokens { get; set; }

        /// <summary>Contents of the last-tokens queue.</summary>
        [JsonPropertyName("last_n_tokens")]
        public llama_token[] LastTokens { get; set; }

        /// <summary>Presumably the capacity of the last-tokens queue when the state was captured — TODO confirm.</summary>
        [JsonPropertyName("last_tokens_maximum_count")]
        public int LastTokensCapacity { get; set; }

        /// <summary>Current "mu" value for mirostat sampling.</summary>
        [JsonPropertyName("mirostat_mu")]
        public float? MirostatMu { get; set; }
    }
}
}