diff --git a/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs b/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs
index 65ac8d91..c3cae930 100644
--- a/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs
+++ b/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs
@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class ChatSessionStripRoleName
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
             Console.Write(prompt);
             while (true)
             {
-                foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
+                await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
                 {
                     Console.Write(text);
                 }
diff --git a/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs b/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs
index dcbcc07b..70dbe757 100644
--- a/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs
+++ b/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs
@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class ChatSessionWithRoleName
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
             Console.Write(prompt);
             while (true)
             {
-                foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
+                await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
                 {
                     Console.Write(text);
                 }
diff --git a/LLama.Examples/NewVersion/GrammarJsonResponse.cs b/LLama.Examples/NewVersion/GrammarJsonResponse.cs
index a3c147f5..55405d14 100644
--- a/LLama.Examples/NewVersion/GrammarJsonResponse.cs
+++ b/LLama.Examples/NewVersion/GrammarJsonResponse.cs
@@ -5,9 +5,9 @@ namespace LLama.Examples.NewVersion
 {
     public class GrammarJsonResponse
     {
-        public static void Run()
+        public static async Task Run()
         {
-            var gbnf = File.ReadAllText("Assets/json.gbnf").Trim();
+            var gbnf = (await File.ReadAllTextAsync("Assets/json.gbnf")).Trim();
             var grammar = Grammar.Parse(gbnf, "root");

             Console.Write("Please input your model path: ");
@@ -43,7 +43,7 @@ namespace LLama.Examples.NewVersion
                 Console.ForegroundColor = ConsoleColor.White;
                 Console.Write("Answer: ");
                 prompt = $"Question: {prompt?.Trim()} Answer: ";
-                foreach (var text in ex.Infer(prompt, inferenceParams))
+                await foreach (var text in ex.InferAsync(prompt, inferenceParams))
                 {
                     Console.Write(text);
                 }
diff --git a/LLama.Examples/NewVersion/InstructModeExecute.cs b/LLama.Examples/NewVersion/InstructModeExecute.cs
index b0e325f1..6a9912d9 100644
--- a/LLama.Examples/NewVersion/InstructModeExecute.cs
+++ b/LLama.Examples/NewVersion/InstructModeExecute.cs
@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class InstructModeExecute
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -29,7 +29,7 @@ namespace LLama.Examples.NewVersion

             while (true)
             {
-                foreach (var text in executor.Infer(prompt, inferenceParams))
+                await foreach (var text in executor.InferAsync(prompt, inferenceParams))
                 {
                     Console.Write(text);
                 }
diff --git a/LLama.Examples/NewVersion/LoadAndSaveSession.cs b/LLama.Examples/NewVersion/LoadAndSaveSession.cs
index 948ac6cd..33774b13 100644
--- a/LLama.Examples/NewVersion/LoadAndSaveSession.cs
+++ b/LLama.Examples/NewVersion/LoadAndSaveSession.cs
@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class SaveAndLoadSession
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
             Console.Write(prompt);
             while (true)
             {
-                foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
+                await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
                 {
                     Console.Write(text);
                 }
diff --git a/LLama.Examples/NewVersion/LoadAndSaveState.cs b/LLama.Examples/NewVersion/LoadAndSaveState.cs
index e7e0d4ef..28ee30d6 100644
--- a/LLama.Examples/NewVersion/LoadAndSaveState.cs
+++ b/LLama.Examples/NewVersion/LoadAndSaveState.cs
@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class LoadAndSaveState
    {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion

             while (true)
             {
-                foreach (var text in ex.Infer(prompt, inferenceParams))
+                await foreach (var text in ex.InferAsync(prompt, inferenceParams))
                 {
                     Console.Write(text);
                 }
diff --git a/LLama.Examples/NewVersion/StatelessModeExecute.cs b/LLama.Examples/NewVersion/StatelessModeExecute.cs
index ddd6227f..7b75e373 100644
--- a/LLama.Examples/NewVersion/StatelessModeExecute.cs
+++ b/LLama.Examples/NewVersion/StatelessModeExecute.cs
@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class StatelessModeExecute
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -35,7 +35,7 @@ namespace LLama.Examples.NewVersion
                 Console.ForegroundColor = ConsoleColor.White;
                 Console.Write("Answer: ");
                 prompt = $"Question: {prompt?.Trim()} Answer: ";
-                foreach (var text in ex.Infer(prompt, inferenceParams))
+                await foreach (var text in ex.InferAsync(prompt, inferenceParams))
                 {
                     Console.Write(text);
                 }
diff --git a/LLama.Examples/NewVersion/TestRunner.cs b/LLama.Examples/NewVersion/TestRunner.cs
index 07f61422..5c6bf6f3 100644
--- a/LLama.Examples/NewVersion/TestRunner.cs
+++ b/LLama.Examples/NewVersion/TestRunner.cs
@@ -29,11 +29,11 @@
             if (choice == 0)
             {
-                ChatSessionWithRoleName.Run();
+                await ChatSessionWithRoleName.Run();
             }
             else if (choice == 1)
             {
-                ChatSessionStripRoleName.Run();
+                await ChatSessionStripRoleName.Run();
             }
             else if(choice == 2)
             {
@@ -41,19 +41,19 @@
             }
             else if(choice == 3)
             {
-                InstructModeExecute.Run();
+                await InstructModeExecute.Run();
             }
             else if(choice == 4)
             {
-                StatelessModeExecute.Run();
+                await StatelessModeExecute.Run();
             }
             else if(choice == 5)
             {
-                SaveAndLoadSession.Run();
+                await SaveAndLoadSession.Run();
             }
             else if(choice == 6)
             {
-                LoadAndSaveState.Run();
+                await LoadAndSaveState.Run();
             }
             else if(choice == 7)
             {
@@ -69,7 +69,7 @@
             }
             else if (choice == 10)
             {
-                GrammarJsonResponse.Run();
+                await GrammarJsonResponse.Run();
             }
             else if (choice == 11)
             {
diff --git a/LLama.Unittest/GrammarTest.cs b/LLama.Unittest/GrammarTest.cs
index b86a0f40..7bd012b8 100644
--- a/LLama.Unittest/GrammarTest.cs
+++ b/LLama.Unittest/GrammarTest.cs
@@ -41,7 +41,7 @@ namespace LLama.Unittest
         }

         [Fact]
-        public void SampleWithTrivialGrammar()
+        public async Task SampleWithTrivialGrammar()
         {
             // Create a grammar that constrains the output to be "cat" and nothing else. This is a nonsense answer, so
             // we can be confident it's not what the LLM would say if not constrained by the grammar!
@@ -66,7 +66,7 @@ namespace LLama.Unittest
                 Grammar = grammar,
             };

-            var result = executor.Infer("Q. 7 + 12\nA. ", inferenceParams).ToList();
+            var result = await executor.InferAsync("Q. 7 + 12\nA. ", inferenceParams).ToListAsync();

             Assert.Equal("cat", result[0]);
         }
diff --git a/LLama.Unittest/LLama.Unittest.csproj b/LLama.Unittest/LLama.Unittest.csproj
index a4680f6d..b6b3f0b0 100644
--- a/LLama.Unittest/LLama.Unittest.csproj
+++ b/LLama.Unittest/LLama.Unittest.csproj
@@ -12,6 +12,7 @@
+    <PackageReference Include="System.Linq.Async" />
       <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs
index 1748e02d..9cfd3fca 100644
--- a/LLama.Unittest/StatelessExecutorTest.cs
+++ b/LLama.Unittest/StatelessExecutorTest.cs
@@ -27,15 +27,15 @@ namespace LLama.Unittest
         }

         [Fact]
-        public void Stateless()
+        public async Task Stateless()
         {
             var executor = new StatelessExecutor(_weights, _params);

             const string question = "Question. what is a cat?\nAnswer: ";
             var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };

-            var result1 = string.Join("", executor.Infer(question, @params));
-            var result2 = string.Join("", executor.Infer(question, @params));
+            var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
+            var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());

             _testOutputHelper.WriteLine(result1);
@@ -44,7 +44,7 @@ namespace LLama.Unittest
         }

         [Fact]
-        public void OutOfContext()
+        public async Task OutOfContext()
         {
             var executor = new StatelessExecutor(_weights, _params);
@@ -58,8 +58,8 @@ namespace LLama.Unittest
                 TokensKeep = question.Length,
             };

-            var result1 = string.Join("", executor.Infer(question, @params));
-            var result2 = string.Join("", executor.Infer(question, @params));
+            var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
+            var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());

             _testOutputHelper.WriteLine(result1);
diff --git a/LLama.WebAPI/Controllers/ChatController.cs b/LLama.WebAPI/Controllers/ChatController.cs
index 001a3224..9643ccf8 100644
--- a/LLama.WebAPI/Controllers/ChatController.cs
+++ b/LLama.WebAPI/Controllers/ChatController.cs
@@ -18,7 +18,7 @@ namespace LLama.WebAPI.Controllers
         }

         [HttpPost("Send")]
-        public string SendMessage([FromBody] SendMessageInput input, [FromServices] StatefulChatService _service)
+        public Task<string> SendMessage([FromBody] SendMessageInput input, [FromServices] StatefulChatService _service)
         {
             return _service.Send(input);
         }
diff --git a/LLama.WebAPI/Services/StatefulChatService.cs b/LLama.WebAPI/Services/StatefulChatService.cs
index a9ac3a44..ab542694 100644
--- a/LLama.WebAPI/Services/StatefulChatService.cs
+++ b/LLama.WebAPI/Services/StatefulChatService.cs
@@ -28,7 +28,7 @@ public class StatefulChatService : IDisposable
         _context?.Dispose();
     }

-    public string Send(SendMessageInput input)
+    public async Task<string> Send(SendMessageInput input)
     {
         var userInput = input.Text;
         if (!_continue)
@@ -42,13 +42,13 @@ public class StatefulChatService : IDisposable
         Console.Write(input.Text);
         Console.ForegroundColor = ConsoleColor.White;
-        var outputs = _session.Chat(userInput, new Common.InferenceParams()
+        var outputs = _session.ChatAsync(userInput, new Common.InferenceParams()
         {
             RepeatPenalty = 1.0f,
             AntiPrompts = new string[] { "User:" },
         });

         var result = "";
-        foreach (var output in outputs)
+        await foreach (var output in outputs)
         {
             Console.Write(output);
             result += output;
diff --git a/LLama/Abstractions/ILLamaExecutor.cs b/LLama/Abstractions/ILLamaExecutor.cs
index a7af0243..ef5453a7 100644
--- a/LLama/Abstractions/ILLamaExecutor.cs
+++ b/LLama/Abstractions/ILLamaExecutor.cs
@@ -13,15 +13,6 @@ namespace LLama.Abstractions
         /// </summary>
         public LLamaContext Context { get; }

-        /// <summary>
-        /// Infers a response from the model.
-        /// </summary>
-        /// <param name="text">Your prompt</param>
-        /// <param name="inferenceParams">Any additional parameters</param>
-        /// <param name="token">A cancellation token.</param>
-        /// <returns></returns>
-        IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
-
         /// <summary>
         /// Asynchronously infers a response from the model.
         /// </summary>
diff --git a/LLama/Abstractions/ITextStreamTransform.cs b/LLama/Abstractions/ITextStreamTransform.cs
index e96febcf..2725214f 100644
--- a/LLama/Abstractions/ITextStreamTransform.cs
+++ b/LLama/Abstractions/ITextStreamTransform.cs
@@ -7,13 +7,6 @@ namespace LLama.Abstractions
     /// </summary>
     public interface ITextStreamTransform
     {
-        /// <summary>
-        /// Takes a stream of tokens and transforms them, returning a new stream of tokens.
-        /// </summary>
-        /// <param name="tokens"></param>
-        /// <returns></returns>
-        IEnumerable<string> Transform(IEnumerable<string> tokens);
-
         /// <summary>
         /// Takes a stream of tokens and transforms them, returning a new stream of tokens asynchronously.
         /// </summary>
diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index 5ed6a459..457e7e48 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -134,26 +134,6 @@ namespace LLama
             }
         }

-        /// <summary>
-        /// Get the response from the LLama model with chat histories.
-        /// </summary>
-        /// <param name="history"></param>
-        /// <param name="inferenceParams"></param>
-        /// <param name="cancellationToken"></param>
-        public IEnumerable<string> Chat(ChatHistory history, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
-        {
-            var prompt = HistoryTransform.HistoryToText(history);
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
-            StringBuilder sb = new();
-            foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
-            {
-                yield return result;
-                sb.Append(result);
-            }
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
-        }
-
         /// <summary>
         /// Get the response from the LLama model. Note that prompt could not only be the preset words,
         /// but also the question you want to ask.
@@ -162,15 +142,14 @@ namespace LLama
         /// </summary>
         /// <param name="prompt"></param>
         /// <param name="inferenceParams"></param>
         /// <param name="cancellationToken"></param>
-        public IEnumerable<string> Chat(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+        public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             foreach(var inputTransform in InputTransformPipeline)
-            {
                 prompt = inputTransform.Transform(prompt);
-            }
+
             History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
             StringBuilder sb = new();
-            foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
+            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
             {
                 yield return result;
                 sb.Append(result);
             }
@@ -198,35 +177,6 @@ namespace LLama
             History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
         }

-        /// <summary>
-        /// Get the response from the LLama model with chat histories asynchronously.
-        /// </summary>
-        /// <param name="prompt"></param>
-        /// <param name="inferenceParams"></param>
-        /// <param name="cancellationToken"></param>
-        /// <returns></returns>
-        public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
-        {
-            foreach (var inputTransform in InputTransformPipeline)
-            {
-                prompt = inputTransform.Transform(prompt);
-            }
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
-            StringBuilder sb = new();
-            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
-            {
-                yield return result;
-                sb.Append(result);
-            }
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
-        }
-
-        private IEnumerable<string> ChatInternal(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
-        {
-            var results = _executor.Infer(prompt, inferenceParams, cancellationToken);
-            return OutputTransform.Transform(results);
-        }
-
         private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index 49bd3170..df972e47 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -10,6 +10,7 @@ using System.Linq;
 using System.Runtime.CompilerServices;
 using System.Text.Json.Serialization;
 using System.Threading;
+using System.Threading.Tasks;

 namespace LLama
 {
@@ -212,47 +213,53 @@ namespace LLama
         /// </summary>
         /// <param name="args"></param>
         /// <returns></returns>
-        protected abstract bool GetLoopCondition(InferStateArgs args);
+        protected abstract Task<bool> GetLoopCondition(InferStateArgs args);
+
         /// <summary>
         /// Preprocess the inputs before the inference.
         /// </summary>
         /// <param name="text"></param>
         /// <param name="args"></param>
-        protected abstract void PreprocessInputs(string text, InferStateArgs args);
+        protected abstract Task PreprocessInputs(string text, InferStateArgs args);
+
         /// <summary>
         /// Do some post processing after the inference.
         /// </summary>
         /// <param name="inferenceParams"></param>
         /// <param name="args"></param>
-        /// <param name="extraOutputs"></param>
         /// <returns></returns>
-        protected abstract bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs);
+        protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args);
+
         /// <summary>
         /// The core inference logic.
         /// </summary>
         /// <param name="inferenceParams"></param>
         /// <param name="args"></param>
-        protected abstract void InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
+        protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
+
         /// <summary>
         /// Save the current state to a file.
         /// </summary>
         /// <param name="filename"></param>
-        public abstract void SaveState(string filename);
+        public abstract Task SaveState(string filename);
+
         /// <summary>
         /// Get the current state data.
         /// </summary>
         /// <returns></returns>
         public abstract ExecutorBaseState GetStateData();
+
         /// <summary>
         /// Load the state from data.
         /// </summary>
         /// <param name="data"></param>
-        public abstract void LoadState(ExecutorBaseState data);
+        public abstract Task LoadState(ExecutorBaseState data);
+
         /// <summary>
         /// Load the state from a file.
         /// </summary>
         /// <param name="filename"></param>
-        public abstract void LoadState(string filename);
+        public abstract Task LoadState(string filename);

         /// <summary>
@@ -262,12 +269,12 @@ namespace LLama
         /// <param name="inferenceParams"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public virtual IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+        public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             cancellationToken.ThrowIfCancellationRequested();

             inferenceParams ??= new InferenceParams();

-            InferStateArgs args = new InferStateArgs()
+            var args = new InferStateArgs
             {
                 Antiprompts = inferenceParams.AntiPrompts.ToList(),
                 RemainedTokens = inferenceParams.MaxTokens,
@@ -276,15 +283,15 @@ namespace LLama
                 NeedToSaveSession = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count
             };

-            PreprocessInputs(text, args);
+            await PreprocessInputs(text, args);

-            while (GetLoopCondition(args))
+            while (await GetLoopCondition(args))
             {
                 if (cancellationToken.IsCancellationRequested)
                 {
                     break;
                 }
-                InferInternal(inferenceParams, args);
+                await InferInternal(inferenceParams, args);

                 if (args.ReturnValue)
                 {
@@ -292,8 +299,8 @@ namespace LLama
                     yield return Context.TokenToString(id);
                 }

-                var breakGeneration = PostProcess(inferenceParams, args, out var extraOutputs);
-                if (extraOutputs is not null)
+                var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
+                if (extraOutputs is { Count: > 0 })
                 {
                     foreach (var item in extraOutputs)
                     {
@@ -307,21 +314,6 @@ namespace LLama
             }
         }

-        /// <summary>
-        /// Execute the inference asynchronously.
-        /// </summary>
-        /// <param name="text"></param>
-        /// <param name="inferenceParams"></param>
-        /// <param name="cancellationToken"></param>
-        /// <returns></returns>
-        public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
-        {
-            foreach (var result in Infer(text, inferenceParams, cancellationToken))
-            {
-                yield return result;
-            }
-        }
-
         /// <summary>
         /// State arguments that are used in single inference
         /// </summary>
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 2d46728f..6faa3db2 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -5,9 +5,9 @@ using System;
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
-using System.Text;
 using System.Text.Json;
 using System.Text.Json.Serialization;
+using System.Threading.Tasks;
 using LLama.Extensions;

 namespace LLama
@@ -60,7 +60,7 @@ namespace LLama
             return state;
         }
         /// <inheritdoc />
-        public override void LoadState(ExecutorBaseState data)
+        public override Task LoadState(ExecutorBaseState data)
         {
             if(data is InstructExecutorState state)
             {
@@ -81,34 +81,37 @@ namespace LLama
             {
                 throw new ArgumentException("Invalid state data type.");
             }
+
+            return Task.CompletedTask;
         }
         /// <inheritdoc />
-        public override void SaveState(string filename)
+        public override async Task SaveState(string filename)
         {
             var state = (InstructExecutorState)GetStateData();
             using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
             {
-                JsonSerializer.Serialize(fs, state);
+                await JsonSerializer.SerializeAsync(fs, state);
             }
         }
         /// <inheritdoc />
-        public override void LoadState(string filename)
+        public override async Task LoadState(string filename)
         {
             using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
             {
-                var state = JsonSerializer.Deserialize<InstructExecutorState>(fs);
-                LoadState(state);
+                var state = await JsonSerializer.DeserializeAsync<InstructExecutorState>(fs);
+                await LoadState(state);
             }
         }

         /// <inheritdoc />
-        protected override bool GetLoopCondition(InferStateArgs args)
+        protected override Task<bool> GetLoopCondition(InferStateArgs args)
         {
-            return args.RemainedTokens != 0 || _is_prompt_run;
+            return Task.FromResult(args.RemainedTokens != 0 || _is_prompt_run);
         }
+
         /// <inheritdoc />
-        protected override void PreprocessInputs(string text, InferStateArgs args)
+        protected override Task PreprocessInputs(string text, InferStateArgs args)
         {
             args.Antiprompts ??= new List<string>();
             args.Antiprompts.Add(_instructionPrefix);
@@ -133,23 +136,24 @@ namespace LLama
                 args.RemainedTokens -= line_inp.Length;
             }
+
+            return Task.CompletedTask;
         }
+
         /// <inheritdoc />
-        protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
+        protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
         {
-            extraOutputs = null;
             if (_embed_inps.Count <= _consumedTokensCount)
             {
                 if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
                 {
                     args.WaitForInput = true;
-                    return true;
+                    return (true, Array.Empty<string>());
                 }

                 if (_pastTokensCount > 0 && args.WaitForInput)
                 {
-                    extraOutputs = new[] { "\n> " };
-                    return true;
+                    return (true, new[] { "\n> " });
                 }
             }
@@ -163,10 +167,11 @@ namespace LLama
                 args.RemainedTokens = inferenceParams.MaxTokens;
                 args.WaitForInput = true;
             }
-            return false;
+            return (false, Array.Empty<string>());
         }
+
         /// <inheritdoc />
-        protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
         {
             if (_embeds.Count > 0)
             {
@@ -230,6 +235,8 @@ namespace LLama
                 }
             }
         }
+
+            return Task.CompletedTask;
         }
         /// <summary>
         /// The desciptor of the state of the instruct executor.
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 6b4c2104..ab403212 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -7,7 +7,7 @@ using System.IO;
 using System.Linq;
 using System.Text.Json;
 using System.Text.Json.Serialization;
-using System.Text;
+using System.Threading.Tasks;
 using LLama.Extensions;

 namespace LLama
@@ -51,7 +51,7 @@ namespace LLama
             return state;
         }
         /// <inheritdoc />
-        public override void LoadState(ExecutorBaseState data)
+        public override Task LoadState(ExecutorBaseState data)
         {
             if (data is InteractiveExecutorState state)
             {
@@ -68,23 +68,25 @@ namespace LLama
             }
             else
                 throw new ArgumentException("Invalid state data type.");
+
+            return Task.CompletedTask;
         }
         /// <inheritdoc />
-        public override void SaveState(string filename)
+        public override async Task SaveState(string filename)
         {
-            InteractiveExecutorState state = (InteractiveExecutorState)GetStateData();
-            using(FileStream fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
+            var state = (InteractiveExecutorState)GetStateData();
+            using(var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
             {
-                JsonSerializer.Serialize(fs, state);
+                await JsonSerializer.SerializeAsync(fs, state);
             }
         }
         /// <inheritdoc />
-        public override void LoadState(string filename)
+        public override async Task LoadState(string filename)
         {
-            using (FileStream fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
+            using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
             {
-                var state = JsonSerializer.Deserialize<InteractiveExecutorState>(fs);
-                LoadState(state);
+                var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
+                await LoadState(state);
             }
         }

@@ -92,13 +94,13 @@ namespace LLama
         /// <summary>
         /// Define whether to continue the loop to generate responses.
         /// </summary>
         /// <returns></returns>
-        protected override bool GetLoopCondition(InferStateArgs args)
+        protected override Task<bool> GetLoopCondition(InferStateArgs args)
         {
-            return args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run;
+            return Task.FromResult(args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run);
         }

         /// <inheritdoc />
-        protected override void PreprocessInputs(string text, InferStateArgs args)
+        protected override Task PreprocessInputs(string text, InferStateArgs args)
         {
             if (_is_prompt_run)
             {
@@ -115,6 +117,8 @@ namespace LLama
                 _embed_inps.AddRange(line_inp);
                 args.RemainedTokens -= line_inp.Length;
             }
+
+            return Task.CompletedTask;
         }

         /// <summary>
@@ -122,24 +126,21 @@ namespace LLama
         /// </summary>
         /// <param name="inferenceParams"></param>
         /// <param name="args"></param>
-        /// <param name="extraOutputs"></param>
         /// <returns></returns>
-        protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
+        protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
         {
-            extraOutputs = null;
             if (_embed_inps.Count <= _consumedTokensCount)
             {
                 if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
                     args.WaitForInput = true;

                 if (_pastTokensCount > 0 && args.WaitForInput)
-                    return true;
+                    return (true, Array.Empty<string>());
             }

             if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
             {
-                extraOutputs = new[] { " [end of text]\n" };
-                return true;
+                return (true, new[] { " [end of text]\n" });
             }

             if (args.RemainedTokens <= 0 && inferenceParams.MaxTokens != -1)
@@ -147,11 +148,12 @@ namespace LLama
                 args.RemainedTokens = inferenceParams.MaxTokens;
                 args.WaitForInput = true;
             }
-            return false;
+
+            return (false, Array.Empty<string>());
         }

         /// <inheritdoc />
-        protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
         {
             if (_embeds.Count > 0)
             {
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 5c496037..5b1c4250 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -55,7 +55,7 @@ namespace LLama
         }

         /// <inheritdoc />
-        public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+        public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             using var context = _weights.CreateContext(_params);
             Context = context;
@@ -140,14 +140,5 @@ namespace LLama
         {
             return tokens.TokensEndsWithAnyString(antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding);
         }
-
-        /// <inheritdoc />
-        public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
-        {
-            foreach (var result in Infer(text, inferenceParams, cancellationToken))
-            {
-                yield return result;
-            }
-        }
     }
 }
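
Usage note (not part of the patch): with the synchronous Infer/Chat overloads removed, every caller now streams output through IAsyncEnumerable<string> via await foreach, exactly as the updated examples and tests above do. The sketch below shows the minimal consuming pattern; it assumes the usual LLamaSharp model-loading calls of this era (ModelParams, LLamaWeights.LoadFromFile), and the model path and parameter values are placeholders. Buffering a whole response with ToListAsync, as the updated tests do, additionally requires the System.Linq.Async package referenced in LLama.Unittest.csproj.

    // Minimal sketch of consuming the async-only API (illustrative, not taken from the diff).
    using LLama;
    using LLama.Common;

    var parameters = new ModelParams("path/to/model.bin");      // placeholder model path
    using var weights = LLamaWeights.LoadFromFile(parameters);  // load weights once, reuse across executors
    var executor = new StatelessExecutor(weights, parameters);  // same constructor the tests above use

    var inferenceParams = new InferenceParams
    {
        MaxTokens = 32,
        AntiPrompts = new[] { "." },
    };

    // Tokens are written as soon as the executor yields them; cancellation flows
    // through the [EnumeratorCancellation] token added to InferAsync.
    await foreach (var token in executor.InferAsync("Question. what is a cat?\nAnswer: ", inferenceParams))
        Console.Write(token);

The same pattern applies to ChatSession.ChatAsync, which now transforms the input, streams the transformed executor output, and appends both sides of the exchange to History as it goes.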