// LLamaSharp/LLama/LLamaInteractExecutor.cs

using LLama.Common;
using LLama.Native;
using LLama.Abstractions;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
namespace LLama
{
using llama_token = Int32;
/// <summary>
/// The LLama executor for interactive mode.
/// </summary>
public class InteractiveExecutor : StatefulExecutorBase
{
// True until the first input (the prompt) has been processed.
bool _is_prompt_run = true;
// The token sequence that represents "\n" for the loaded model.
llama_token[] _llama_token_newline;
/// <summary>
/// Create an executor for interactive mode with the given model.
/// </summary>
/// <param name="model">The model to run inference with.</param>
public InteractiveExecutor(LLamaModel model) : base(model)
{
_llama_token_newline = Utils.Tokenize(_model.NativeHandle, "\n", false, _model.Encoding).ToArray();
}
/// <inheritdoc />
public override ExecutorBaseState GetStateData()
{
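// Snapshot all mutable inference state so the session can be serialized and resumed later.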
InteractiveExecutorState state = new()
{
ConsumedSessionCount = _n_session_consumed,
EmbedInps = _embed_inps,
IsPromptRun = _is_prompt_run,
ConsumedTokensCount = _consumedTokensCount,
Embeds = _embeds,
LastTokens = _last_n_tokens.ToArray(),
LLamaNewlineTokens = _llama_token_newline,
MatchingSessionTokensCount = _n_matching_session_tokens,
PastTokensCount = _pastTokensCount,
SessionFilePath = _pathSession,
SessionTokens = _session_tokens,
LastTokensCapacity = _last_n_tokens.Capacity,
MirostateMu = MirostateMu
};
return state;
}
/// <inheritdoc />
public override void LoadState(ExecutorBaseState data)
{
if (data is InteractiveExecutorState state)
{
_n_session_consumed = state.ConsumedSessionCount;
_embed_inps = state.EmbedInps;
_is_prompt_run = state.IsPromptRun;
_consumedTokensCount = state.ConsumedTokensCount;
_embeds = state.Embeds;
_last_n_tokens = new FixedSizeQueue<llama_token>(state.LastTokensCapacity, state.LastTokens);
_llama_token_newline = state.LLamaNewlineTokens;
_n_matching_session_tokens = state.MatchingSessionTokensCount;
_pastTokensCount = state.PastTokensCount;
_pathSession = state.SessionFilePath;
_session_tokens = state.SessionTokens;
}
else
throw new ArgumentException("Invalid state data type.");
}
/// <inheritdoc />
public override void SaveState(string filename)
{
var state = (InteractiveExecutorState)GetStateData();
using(FileStream fs = new FileStream(filename, FileMode.OpenOrCreate, FileAccess.Write))
{
JsonSerializer.Serialize<InteractiveExecutorState>(fs, state);
}
}
/// <inheritdoc />
public override void LoadState(string filename)
{
using (FileStream fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = JsonSerializer.Deserialize<InteractiveExecutorState>(fs);
LoadState(state);
}
}
/// <summary>
/// Decide whether to continue the generation loop.
/// </summary>
/// <param name="args">The mutable state of the current inference loop.</param>
/// <returns><c>true</c> to continue generating; otherwise <c>false</c>.</returns>
protected override bool GetLoopCondition(InferStateArgs args)
{
return (args.RemainedTokens != 0 && !args.WaitForInput) || _is_prompt_run;
}
/// <inheritdoc />
protected override void PreprocessInputs(string text, InferStateArgs args)
{
if (_is_prompt_run)
{
// The first input (the prompt) in interactive mode needs special handling:
// prepend a space and tokenize it with the BOS token.
text = " " + text;
_embed_inps = _model.Tokenize(text, true).ToList();
}
else
{
if (!text.EndsWith("\n"))
{
text += "\n";
}
var line_inp = _model.Tokenize(text, false);
_embed_inps.AddRange(line_inp);
args.RemainedTokens -= line_inp.Count();
}
}
/// <summary>
/// Decide whether to stop generating after the current step.
/// </summary>
/// <param name="inferenceParams">The inference parameters in use.</param>
/// <param name="args">The mutable state of the current inference loop.</param>
/// <param name="extraOutputs">Extra text to emit alongside the generated tokens, if any.</param>
/// <returns><c>true</c> to break out of the generation loop; otherwise <c>false</c>.</returns>
protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
{
extraOutputs = null;
if (_embed_inps.Count <= _consumedTokensCount)
{
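// All queued input has been consumed; check whether the recent output ends with an antiprompt.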
if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
{
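// Rebuild the recent output text from the last-tokens window.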
string last_output = "";
foreach (var id in _last_n_tokens)
{
last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, id), _model.Encoding);
}
foreach (var antiprompt in args.Antiprompts)
{
if (last_output.EndsWith(antiprompt))
{
args.WaitForInput = true;
break;
}
}
}
if (_pastTokensCount > 0 && args.WaitForInput)
{
return true;
}
}
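// Stop generating once the model has just emitted the end-of-sequence token.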
if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos())
{
extraOutputs = new string[] { " [end of text]\n" };
return true;
}
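// Out of token budget: refill it and pause to wait for more user input.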
if (args.RemainedTokens <= 0 && inferenceParams.MaxTokens != -1)
{
args.RemainedTokens = inferenceParams.MaxTokens;
args.WaitForInput = true;
}
return false;
}
/// <inheritdoc />
protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
{
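// Evaluate any pending tokens first, then either sample a new token or move queued input into the batch.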
if (_embeds.Count > 0)
{
_is_prompt_run = false;
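// The context window is about to overflow: compact it, keeping the first TokensKeep tokens.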
if (_pastTokensCount + _embeds.Count > _model.ContextSize)
{
HandleRunOutOfContext(inferenceParams.TokensKeep);
}
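// Skip re-evaluating tokens that match a prefix of the loaded session.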
TryReuseMathingPrefix();
_pastTokensCount = _model.Eval(_embeds.ToArray(), _pastTokensCount);
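// Track evaluated tokens so they can be written to the session file later.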
if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
{
_session_tokens.AddRange(_embeds);
_n_session_consumed = _session_tokens.Count;
}
}
_embeds.Clear();
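// All queued input has been consumed and we are not waiting for user input, so sample the next token.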
if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput)
{
var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? _model.ContextSize : inferenceParams.RepeatLastTokensCount;
// optionally save the session on first sample (for faster prompt loading next time)
if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession)
{
args.NeedToSaveSession = false;
SaveSessionFile(_pathSession);
}
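// Apply repetition, frequency and presence penalties to the candidate tokens, then sample.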
var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
var mu = MirostateMu;
var id = _model.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
);
MirostateMu = mu;
_last_n_tokens.Enqueue(id);
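// If EOS was sampled, emit a newline instead and queue the first antiprompt so the dialogue can continue.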
if (id == NativeApi.llama_token_eos())
{
id = _llama_token_newline.First();
if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
{
var first_antiprompt = _model.Tokenize(args.Antiprompts[0], false);
_embed_inps.AddRange(first_antiprompt);
}
}
_embeds.Add(id);
args.RemainedTokens--;
args.ReturnValue = true;
}
else
{
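// Feed queued input tokens into the pending batch, at most BatchSize at a time.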
while (_embed_inps.Count > _consumedTokensCount)
{
_embeds.Add(_embed_inps[_consumedTokensCount]);
_last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]);
_consumedTokensCount++;
if (_embeds.Count >= _model.Params.BatchSize)
{
break;
}
}
}
}
/// <summary>
/// The descriptor of the state of the interactive executor.
/// </summary>
public class InteractiveExecutorState : ExecutorBaseState
{
/// <summary>
/// Whether the executor is running for the first time (running the prompt).
/// </summary>
[JsonPropertyName("is_prompt_run")]
public bool IsPromptRun { get; set; }
/// <summary>
/// Tokens that represent a new line with the current model.
/// </summary>
[JsonPropertyName("llama_token_newline")]
public llama_token[] LLamaNewlineTokens { get; set; }
}
}
}