2023-06-12 02:47:25 +08:00
|
|
|
|
using LLama.Common;
|
2023-06-10 18:37:58 +08:00
|
|
|
|
using LLama.Native;
|
2023-08-06 07:03:45 +08:00
|
|
|
|
using LLama.Abstractions;
|
2023-06-10 18:37:58 +08:00
|
|
|
|
using System;
|
|
|
|
|
using System.Collections.Generic;
|
2023-06-11 09:13:30 +08:00
|
|
|
|
using System.IO;
|
2023-06-10 18:37:58 +08:00
|
|
|
|
using System.Linq;
|
2023-06-11 09:13:30 +08:00
|
|
|
|
using System.Text.Json;
|
|
|
|
|
using System.Text.Json.Serialization;
|
2023-06-10 18:37:58 +08:00
|
|
|
|
|
|
|
|
|
namespace LLama
|
|
|
|
|
{
|
|
|
|
|
using llama_token = Int32;
|
2023-06-20 02:38:57 +08:00
|
|
|
|
/// <summary>
|
|
|
|
|
/// The LLama executor for interactive mode.
|
|
|
|
|
/// </summary>
|
2023-06-12 18:31:37 +08:00
|
|
|
|
public class InteractiveExecutor : StatefulExecutorBase
|
2023-06-10 18:37:58 +08:00
|
|
|
|
{
|
2023-06-11 09:13:30 +08:00
|
|
|
|
// True until the first batch of tokens has been evaluated; forces the
// generation loop to run at least once (see GetLoopCondition) and selects
// the special prompt handling in PreprocessInputs.
bool _is_prompt_run = true;
// Token sequence obtained by tokenizing "\n" with the current model
// (filled in the constructor; used in InferInternal to substitute EOS).
llama_token[] _llama_token_newline;
|
2023-07-30 07:15:52 +08:00
|
|
|
|
|
2023-06-20 02:38:57 +08:00
|
|
|
|
/// <summary>
/// Creates an interactive executor bound to the given model and caches the
/// model's tokenization of a newline character.
/// </summary>
/// <param name="model">The model used for inference.</param>
public InteractiveExecutor(LLamaModel model) : base(model)
{
    var newlineTokens = Utils.Tokenize(_model.NativeHandle, "\n", false, _model.Encoding);
    _llama_token_newline = newlineTokens.ToArray();
}
|
|
|
|
|
|
2023-06-20 02:38:57 +08:00
|
|
|
|
/// <inheritdoc />
public override ExecutorBaseState GetStateData()
{
    // Snapshot every piece of mutable executor state so inference can be
    // resumed later via LoadState. Grouped by concern for readability.
    var state = new InteractiveExecutorState
    {
        // Input / generation progress
        IsPromptRun = _is_prompt_run,
        EmbedInps = _embed_inps,
        Embeds = _embeds,
        ConsumedTokensCount = _consumedTokensCount,
        PastTokensCount = _pastTokensCount,

        // Sampling history and model-derived constants
        LastTokens = _last_n_tokens.ToArray(),
        LastTokensCapacity = _last_n_tokens.Capacity,
        LLamaNewlineTokens = _llama_token_newline,
        MirostateMu = MirostateMu,

        // Session-file bookkeeping
        SessionFilePath = _pathSession,
        SessionTokens = _session_tokens,
        ConsumedSessionCount = _n_session_consumed,
        MatchingSessionTokensCount = _n_matching_session_tokens
    };
    return state;
}
|
2023-06-20 02:38:57 +08:00
|
|
|
|
/// <inheritdoc />
public override void LoadState(ExecutorBaseState data)
{
    // Guard clause: reject anything that is not this executor's own state type.
    if (data is not InteractiveExecutorState state)
        throw new ArgumentException("Invalid state data type.");

    // Input / generation progress
    _is_prompt_run = state.IsPromptRun;
    _embed_inps = state.EmbedInps;
    _embeds = state.Embeds;
    _consumedTokensCount = state.ConsumedTokensCount;
    _pastTokensCount = state.PastTokensCount;

    // Sampling history (rebuild the fixed-size queue at its saved capacity)
    _last_n_tokens = new FixedSizeQueue<llama_token>(state.LastTokensCapacity, state.LastTokens);
    _llama_token_newline = state.LLamaNewlineTokens;

    // Session-file bookkeeping
    _pathSession = state.SessionFilePath;
    _session_tokens = state.SessionTokens;
    _n_session_consumed = state.ConsumedSessionCount;
    _n_matching_session_tokens = state.MatchingSessionTokensCount;
}
|
2023-06-20 02:38:57 +08:00
|
|
|
|
/// <inheritdoc />
public override void SaveState(string filename)
{
    // GetStateData is overridden above and always returns an
    // InteractiveExecutorState; a direct cast fails loudly if that
    // invariant is ever broken, instead of `as` silently yielding null.
    var state = (InteractiveExecutorState)GetStateData();
    // FileMode.Create truncates an existing file. The previous
    // FileMode.OpenOrCreate left trailing bytes behind whenever the new
    // state serialized shorter than the old one, producing corrupt JSON.
    using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
    {
        JsonSerializer.Serialize<InteractiveExecutorState>(fs, state);
    }
}
|
2023-06-20 02:38:57 +08:00
|
|
|
|
/// <inheritdoc />
public override void LoadState(string filename)
{
    // Deserialize the saved state and delegate to the typed overload.
    using var fs = new FileStream(filename, FileMode.Open, FileAccess.Read);
    var state = JsonSerializer.Deserialize<InteractiveExecutorState>(fs);
    LoadState(state);
}
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Define whether to continue the loop to generate responses.
/// </summary>
/// <returns></returns>
protected override bool GetLoopCondition(InferStateArgs args)
{
    // The very first (prompt) run always proceeds, regardless of the
    // remaining-token budget or pending user input.
    if (_is_prompt_run)
        return true;
    // Otherwise keep generating while tokens remain and no input is awaited.
    return args.RemainedTokens != 0 && !args.WaitForInput;
}
|
|
|
|
|
|
2023-06-20 02:38:57 +08:00
|
|
|
|
/// <inheritdoc />
protected override void PreprocessInputs(string text, InferStateArgs args)
{
    if (_is_prompt_run)
    {
        // When running the first input (prompt) in interactive mode, we should specially process it.
        text = " " + text;
        _embed_inps = _model.Tokenize(text, true).ToList();
    }
    else
    {
        // Follow-up input must be newline-terminated so the model sees a
        // completed line. Ordinal comparison: "\n" is not linguistic text.
        if (!text.EndsWith("\n", StringComparison.Ordinal))
        {
            text += "\n";
        }
        // Materialize once: the token sequence is both appended and counted
        // below, so a lazy enumerable would be tokenized twice.
        var line_inp = _model.Tokenize(text, false).ToList();
        _embed_inps.AddRange(line_inp);
        args.RemainedTokens -= line_inp.Count;
    }
}
|
2023-06-10 18:37:58 +08:00
|
|
|
|
|
2023-06-11 05:44:21 +08:00
|
|
|
|
/// <summary>
/// Return whether to break the generation.
/// </summary>
/// <param name="inferenceParams">Parameters controlling antiprompts and token budget.</param>
/// <param name="args">Mutable per-inference state.</param>
/// <param name="extraOutputs">Extra text to emit (e.g. the end-of-text marker), or null.</param>
/// <returns>True to stop generating and wait for input / finish.</returns>
protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
{
    extraOutputs = null;
    if (_embed_inps.Count <= _consumedTokensCount)
    {
        if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
        {
            // Detokenize the recent history in one pass. string.Concat avoids
            // the O(n^2) cost of the previous `+=` concatenation loop.
            var last_output = string.Concat(_last_n_tokens.Select(
                id => Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, id), _model.Encoding)));

            foreach (var antiprompt in args.Antiprompts)
            {
                // Ordinal: antiprompts are token text, not culture-sensitive prose.
                if (last_output.EndsWith(antiprompt, StringComparison.Ordinal))
                {
                    args.WaitForInput = true;
                    break;
                }
            }
        }

        // Pause for user input once something has actually been evaluated.
        if (_pastTokensCount > 0 && args.WaitForInput)
        {
            return true;
        }
    }

    // End-of-stream token: surface a marker and stop.
    if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos())
    {
        extraOutputs = new string[] { " [end of text]\n" };
        return true;
    }

    // Token budget exhausted: refill it and hand control back to the user.
    if (args.RemainedTokens <= 0 && inferenceParams.MaxTokens != -1)
    {
        args.RemainedTokens = inferenceParams.MaxTokens;
        args.WaitForInput = true;
    }
    return false;
}
|
2023-06-10 18:37:58 +08:00
|
|
|
|
|
2023-06-20 02:38:57 +08:00
|
|
|
|
/// <inheritdoc />
protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
{
    // Phase 1: evaluate any pending tokens through the model.
    if (_embeds.Count > 0)
    {
        _is_prompt_run = false;
        // Context window would overflow: shift/trim it, keeping TokensKeep tokens.
        if (_pastTokensCount + _embeds.Count > _model.ContextSize)
        {
            HandleRunOutOfContext(inferenceParams.TokensKeep);
        }

        // Skip re-evaluating tokens already present in the loaded session prefix.
        TryReuseMathingPrefix();
        _pastTokensCount = _model.Eval(_embeds.ToArray(), _pastTokensCount);

        // Record evaluated tokens into the session cache when one is in use.
        if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
        {
            _session_tokens.AddRange(_embeds);
            _n_session_consumed = _session_tokens.Count;
        }
    }

    _embeds.Clear();

    // Phase 2a: all input consumed and not waiting for the user — sample a new token.
    if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput)
    {
        // Negative RepeatLastTokensCount means "use the whole context window".
        var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? _model.ContextSize : inferenceParams.RepeatLastTokensCount;

        // optionally save the session on first sample (for faster prompt loading next time)
        if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession)
        {
            args.NeedToSaveSession = false;
            SaveSessionFile(_pathSession);
        }

        var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
            inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

        // Sample via a local copy of mu because ref properties are not allowed.
        var mu = MirostateMu;
        var id = _model.Sample(
            tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
            inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
        );
        MirostateMu = mu;

        _last_n_tokens.Enqueue(id);

        // In interactive mode EOS is replaced with a newline, and the first
        // antiprompt (if any) is injected so the conversation can continue.
        if (id == NativeApi.llama_token_eos())
        {
            id = _llama_token_newline.First();
            if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
            {
                var first_antiprompt = _model.Tokenize(args.Antiprompts[0], false);
                _embed_inps.AddRange(first_antiprompt);
            }
        }

        _embeds.Add(id);

        args.RemainedTokens--;
        args.ReturnValue = true;
    }
    // Phase 2b: still consuming user/prompt input — feed it in batches.
    else
    {
        while (_embed_inps.Count > _consumedTokensCount)
        {
            _embeds.Add(_embed_inps[_consumedTokensCount]);
            _last_n_tokens.Enqueue(_embed_inps[_consumedTokensCount]);
            _consumedTokensCount++;
            // Stop at the batch size; the remainder is handled on the next call.
            if (_embeds.Count >= _model.Params.BatchSize)
            {
                break;
            }
        }
    }
}
|
2023-06-11 09:13:30 +08:00
|
|
|
|
|
2023-06-20 02:38:57 +08:00
|
|
|
|
/// <summary>
/// The descriptor of the state of the interactive executor.
/// </summary>
public class InteractiveExecutorState : ExecutorBaseState
{
    /// <summary>
    /// Whether the executor is running for the first time (running the prompt).
    /// </summary>
    [JsonPropertyName("is_prompt_run")]
    public bool IsPromptRun { get; set; }
    /// <summary>
    /// Tokens that represent a new line with the current model.
    /// </summary>
    [JsonPropertyName("llama_token_newline")]
    public llama_token[] LLamaNewlineTokens { get; set; }
}
|
2023-06-10 18:37:58 +08:00
|
|
|
|
}
|
|
|
|
|
}
|