Merge pull request #173 from martindevans/async_only

Remove non-async
Commit 9f9903c711 by Haiping, 2023-09-17 10:19:16 -05:00, committed via GitHub
20 changed files with 108 additions and 181 deletions
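
At a glance, this PR deletes the synchronous Chat/Infer entry points and keeps only their IAsyncEnumerable counterparts. A minimal caller-side sketch of the migration (not code from this diff; the session and prompt setup are assumed):

    // Hypothetical caller sketch: the sync Chat()/Infer() overloads are gone,
    // so consumers iterate the async streams with C# 8 "await foreach".
    using System;
    using System.Collections.Generic;
    using System.Threading.Tasks;
    using LLama;
    using LLama.Common;

    public static class MigrationSketch
    {
        public static async Task ChatLoopAsync(ChatSession session, string prompt)
        {
            var inferenceParams = new InferenceParams
            {
                Temperature = 0.6f,
                AntiPrompts = new List<string> { "User:" }
            };

            // Before: foreach (var text in session.Chat(prompt, inferenceParams))
            // After:  only the IAsyncEnumerable<string> stream remains.
            await foreach (var text in session.ChatAsync(prompt, inferenceParams))
            {
                Console.Write(text);
            }
        }
    }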

View File

@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class ChatSessionStripRoleName
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
             Console.Write(prompt);
             while (true)
             {
-                foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
+                await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
                 {
                     Console.Write(text);
                 }

View File

@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class ChatSessionWithRoleName
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
             Console.Write(prompt);
             while (true)
             {
-                foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
+                await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
                 {
                     Console.Write(text);
                 }

View File

@@ -5,9 +5,9 @@ namespace LLama.Examples.NewVersion
 {
     public class GrammarJsonResponse
     {
-        public static void Run()
+        public static async Task Run()
         {
-            var gbnf = File.ReadAllText("Assets/json.gbnf").Trim();
+            var gbnf = (await File.ReadAllTextAsync("Assets/json.gbnf")).Trim();
             var grammar = Grammar.Parse(gbnf, "root");
             Console.Write("Please input your model path: ");
@@ -43,7 +43,7 @@ namespace LLama.Examples.NewVersion
             Console.ForegroundColor = ConsoleColor.White;
             Console.Write("Answer: ");
             prompt = $"Question: {prompt?.Trim()} Answer: ";
-            foreach (var text in ex.Infer(prompt, inferenceParams))
+            await foreach (var text in ex.InferAsync(prompt, inferenceParams))
             {
                 Console.Write(text);
             }

View File

@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class InstructModeExecute
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -29,7 +29,7 @@ namespace LLama.Examples.NewVersion
             while (true)
             {
-                foreach (var text in executor.Infer(prompt, inferenceParams))
+                await foreach (var text in executor.InferAsync(prompt, inferenceParams))
                 {
                     Console.Write(text);
                 }

View File

@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class SaveAndLoadSession
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
             Console.Write(prompt);
             while (true)
             {
-                foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
+                await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
                 {
                     Console.Write(text);
                 }

View File

@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class LoadAndSaveState
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
             while (true)
             {
-                foreach (var text in ex.Infer(prompt, inferenceParams))
+                await foreach (var text in ex.InferAsync(prompt, inferenceParams))
                 {
                     Console.Write(text);
                 }

View File

@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class StatelessModeExecute
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -35,7 +35,7 @@ namespace LLama.Examples.NewVersion
             Console.ForegroundColor = ConsoleColor.White;
             Console.Write("Answer: ");
             prompt = $"Question: {prompt?.Trim()} Answer: ";
-            foreach (var text in ex.Infer(prompt, inferenceParams))
+            await foreach (var text in ex.InferAsync(prompt, inferenceParams))
             {
                 Console.Write(text);
             }

View File

@@ -30,11 +30,11 @@
             if (choice == 0)
             {
-                ChatSessionWithRoleName.Run();
+                await ChatSessionWithRoleName.Run();
             }
             else if (choice == 1)
             {
-                ChatSessionStripRoleName.Run();
+                await ChatSessionStripRoleName.Run();
             }
             else if(choice == 2)
             {
@@ -42,19 +42,19 @@
             }
             else if(choice == 3)
             {
-                InstructModeExecute.Run();
+                await InstructModeExecute.Run();
             }
             else if(choice == 4)
             {
-                StatelessModeExecute.Run();
+                await StatelessModeExecute.Run();
             }
             else if(choice == 5)
             {
-                SaveAndLoadSession.Run();
+                await SaveAndLoadSession.Run();
             }
             else if(choice == 6)
             {
-                LoadAndSaveState.Run();
+                await LoadAndSaveState.Run();
             }
             else if(choice == 7)
             {
@@ -70,7 +70,7 @@
             }
             else if (choice == 10)
             {
-                GrammarJsonResponse.Run();
+                await GrammarJsonResponse.Run();
             }
             else if (choice == 11)
             {

View File

@@ -41,7 +41,7 @@ namespace LLama.Unittest
         }
         [Fact]
-        public void SampleWithTrivialGrammar()
+        public async Task SampleWithTrivialGrammar()
         {
             // Create a grammar that constrains the output to be "cat" and nothing else. This is a nonsense answer, so
             // we can be confident it's not what the LLM would say if not constrained by the grammar!
@@ -66,7 +66,7 @@ namespace LLama.Unittest
                 Grammar = grammar,
             };
-            var result = executor.Infer("Q. 7 + 12\nA. ", inferenceParams).ToList();
+            var result = await executor.InferAsync("Q. 7 + 12\nA. ", inferenceParams).ToListAsync();
             Assert.Equal("cat", result[0]);
         }

View File

@@ -12,6 +12,7 @@
   <ItemGroup>
     <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.7.1" />
+    <PackageReference Include="System.Linq.Async" Version="6.0.1" />
     <PackageReference Include="xunit" Version="2.5.0" />
     <PackageReference Include="xunit.runner.visualstudio" Version="2.5.0">
       <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
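
The new System.Linq.Async reference supplies LINQ operators over IAsyncEnumerable<T>, notably ToListAsync(), which the updated tests use to buffer InferAsync output. A hedged sketch of that pattern (the question text and parameter values are illustrative):

    // Sketch: System.Linq.Async adds async LINQ extensions; ToListAsync()
    // collects the streamed tokens so they can be joined into one string.
    using System.Linq;
    using System.Threading.Tasks;
    using LLama;
    using LLama.Common;

    public static class AsyncLinqSketch
    {
        public static async Task<string> CollectAsync(StatelessExecutor executor, string question)
        {
            var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };
            var pieces = await executor.InferAsync(question, @params).ToListAsync();
            return string.Join("", pieces);
        }
    }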

View File

@@ -27,15 +27,15 @@ namespace LLama.Unittest
         }
         [Fact]
-        public void Stateless()
+        public async Task Stateless()
         {
             var executor = new StatelessExecutor(_weights, _params);
             const string question = "Question. what is a cat?\nAnswer: ";
             var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };
-            var result1 = string.Join("", executor.Infer(question, @params));
-            var result2 = string.Join("", executor.Infer(question, @params));
+            var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
+            var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
             _testOutputHelper.WriteLine(result1);
@@ -44,7 +44,7 @@ namespace LLama.Unittest
         }
         [Fact]
-        public void OutOfContext()
+        public async Task OutOfContext()
         {
             var executor = new StatelessExecutor(_weights, _params);
@@ -58,8 +58,8 @@ namespace LLama.Unittest
                 TokensKeep = question.Length,
             };
-            var result1 = string.Join("", executor.Infer(question, @params));
-            var result2 = string.Join("", executor.Infer(question, @params));
+            var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
+            var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
             _testOutputHelper.WriteLine(result1);

View File

@@ -18,7 +18,7 @@ namespace LLama.WebAPI.Controllers
         }
         [HttpPost("Send")]
-        public string SendMessage([FromBody] SendMessageInput input, [FromServices] StatefulChatService _service)
+        public Task<string> SendMessage([FromBody] SendMessageInput input, [FromServices] StatefulChatService _service)
         {
             return _service.Send(input);
         }

View File

@@ -28,7 +28,7 @@ public class StatefulChatService : IDisposable
         _context?.Dispose();
     }
-    public string Send(SendMessageInput input)
+    public async Task<string> Send(SendMessageInput input)
     {
         var userInput = input.Text;
         if (!_continue)
@@ -42,13 +42,13 @@ public class StatefulChatService : IDisposable
         Console.Write(input.Text);
         Console.ForegroundColor = ConsoleColor.White;
-        var outputs = _session.Chat(userInput, new Common.InferenceParams()
+        var outputs = _session.ChatAsync(userInput, new Common.InferenceParams()
         {
             RepeatPenalty = 1.0f,
             AntiPrompts = new string[] { "User:" },
         });
         var result = "";
-        foreach (var output in outputs)
+        await foreach (var output in outputs)
         {
             Console.Write(output);
             result += output;

View File

@@ -13,15 +13,6 @@ namespace LLama.Abstractions
         /// </summary>
         public LLamaContext Context { get; }
-        /// <summary>
-        /// Infers a response from the model.
-        /// </summary>
-        /// <param name="text">Your prompt</param>
-        /// <param name="inferenceParams">Any additional parameters</param>
-        /// <param name="token">A cancellation token.</param>
-        /// <returns></returns>
-        IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
         /// <summary>
         /// Asynchronously infers a response from the model.
         /// </summary>
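
The interface now exposes only the asynchronous stream, so consumers iterate with await foreach and can pass a CancellationToken. A minimal consumer sketch, assuming any concrete ILLamaExecutor instance:

    // Sketch: consuming the async-only executor interface with cooperative
    // cancellation via CancellationToken.
    using System;
    using System.Threading;
    using System.Threading.Tasks;
    using LLama.Abstractions;
    using LLama.Common;

    public static class ExecutorConsumerSketch
    {
        public static async Task RunAsync(ILLamaExecutor executor, string prompt, CancellationToken token)
        {
            var inferenceParams = new InferenceParams { MaxTokens = 64 };
            await foreach (var text in executor.InferAsync(prompt, inferenceParams, token))
            {
                Console.Write(text);
            }
        }
    }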

View File

@@ -7,13 +7,6 @@ namespace LLama.Abstractions
     /// </summary>
     public interface ITextStreamTransform
     {
-        /// <summary>
-        /// Takes a stream of tokens and transforms them, returning a new stream of tokens.
-        /// </summary>
-        /// <param name="tokens"></param>
-        /// <returns></returns>
-        IEnumerable<string> Transform(IEnumerable<string> tokens);
         /// <summary>
         /// Takes a stream of tokens and transforms them, returning a new stream of tokens asynchronously.
         /// </summary>

View File

@@ -134,26 +134,6 @@ namespace LLama
             }
         }
-        /// <summary>
-        /// Get the response from the LLama model with chat histories.
-        /// </summary>
-        /// <param name="history"></param>
-        /// <param name="inferenceParams"></param>
-        /// <param name="cancellationToken"></param>
-        /// <returns></returns>
-        public IEnumerable<string> Chat(ChatHistory history, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
-        {
-            var prompt = HistoryTransform.HistoryToText(history);
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
-            StringBuilder sb = new();
-            foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
-            {
-                yield return result;
-                sb.Append(result);
-            }
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
-        }
         /// <summary>
         /// Get the response from the LLama model. Note that prompt could not only be the preset words,
         /// but also the question you want to ask.
@@ -162,15 +142,14 @@ namespace LLama
         /// <param name="inferenceParams"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public IEnumerable<string> Chat(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+        public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             foreach(var inputTransform in InputTransformPipeline)
-            {
                 prompt = inputTransform.Transform(prompt);
-            }
             History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
             StringBuilder sb = new();
-            foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
+            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
             {
                 yield return result;
                 sb.Append(result);
@@ -198,35 +177,6 @@ namespace LLama
             History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
         }
-        /// <summary>
-        /// Get the response from the LLama model with chat histories asynchronously.
-        /// </summary>
-        /// <param name="prompt"></param>
-        /// <param name="inferenceParams"></param>
-        /// <param name="cancellationToken"></param>
-        /// <returns></returns>
-        public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
-        {
-            foreach (var inputTransform in InputTransformPipeline)
-            {
-                prompt = inputTransform.Transform(prompt);
-            }
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
-            StringBuilder sb = new();
-            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
-            {
-                yield return result;
-                sb.Append(result);
-            }
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
-        }
-        private IEnumerable<string> ChatInternal(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
-        {
-            var results = _executor.Infer(prompt, inferenceParams, cancellationToken);
-            return OutputTransform.Transform(results);
-        }
         private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
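
ChatSession.ChatAsync is now the single chat entry point and still records the user prompt and the assistant reply in History around the streamed output. A hedged sketch of buffering the whole reply (the helper name is hypothetical):

    // Sketch: collecting the streamed ChatAsync output into one string.
    using System.Text;
    using System.Threading.Tasks;
    using LLama;
    using LLama.Common;

    public static class ChatSessionSketch
    {
        public static async Task<string> AskAsync(ChatSession session, string prompt)
        {
            var sb = new StringBuilder();
            await foreach (var piece in session.ChatAsync(prompt, new InferenceParams { AntiPrompts = new[] { "User:" } }))
                sb.Append(piece);
            return sb.ToString();
        }
    }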

View File

@@ -10,6 +10,7 @@ using System.Linq;
 using System.Runtime.CompilerServices;
 using System.Text.Json.Serialization;
 using System.Threading;
+using System.Threading.Tasks;
 namespace LLama
 {
@@ -212,47 +213,53 @@ namespace LLama
         /// </summary>
         /// <param name="args"></param>
         /// <returns></returns>
-        protected abstract bool GetLoopCondition(InferStateArgs args);
+        protected abstract Task<bool> GetLoopCondition(InferStateArgs args);
         /// <summary>
         /// Preprocess the inputs before the inference.
         /// </summary>
         /// <param name="text"></param>
         /// <param name="args"></param>
-        protected abstract void PreprocessInputs(string text, InferStateArgs args);
+        protected abstract Task PreprocessInputs(string text, InferStateArgs args);
         /// <summary>
         /// Do some post processing after the inference.
         /// </summary>
         /// <param name="inferenceParams"></param>
         /// <param name="args"></param>
-        /// <param name="extraOutputs"></param>
         /// <returns></returns>
-        protected abstract bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs);
+        protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args);
         /// <summary>
         /// The core inference logic.
         /// </summary>
         /// <param name="inferenceParams"></param>
         /// <param name="args"></param>
-        protected abstract void InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
+        protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
         /// <summary>
         /// Save the current state to a file.
         /// </summary>
         /// <param name="filename"></param>
-        public abstract void SaveState(string filename);
+        public abstract Task SaveState(string filename);
         /// <summary>
         /// Get the current state data.
         /// </summary>
         /// <returns></returns>
         public abstract ExecutorBaseState GetStateData();
         /// <summary>
         /// Load the state from data.
         /// </summary>
         /// <param name="data"></param>
-        public abstract void LoadState(ExecutorBaseState data);
+        public abstract Task LoadState(ExecutorBaseState data);
         /// <summary>
         /// Load the state from a file.
         /// </summary>
         /// <param name="filename"></param>
-        public abstract void LoadState(string filename);
+        public abstract Task LoadState(string filename);
         /// <summary>
@@ -262,12 +269,12 @@ namespace LLama
         /// <param name="inferenceParams"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public virtual IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+        public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             cancellationToken.ThrowIfCancellationRequested();
             inferenceParams ??= new InferenceParams();
-            InferStateArgs args = new InferStateArgs()
+            var args = new InferStateArgs
             {
                 Antiprompts = inferenceParams.AntiPrompts.ToList(),
                 RemainedTokens = inferenceParams.MaxTokens,
@@ -276,15 +283,15 @@ namespace LLama
                 NeedToSaveSession = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count
             };
-            PreprocessInputs(text, args);
+            await PreprocessInputs(text, args);
-            while (GetLoopCondition(args))
+            while (await GetLoopCondition(args))
             {
                 if (cancellationToken.IsCancellationRequested)
                 {
                     break;
                 }
-                InferInternal(inferenceParams, args);
+                await InferInternal(inferenceParams, args);
                 if (args.ReturnValue)
                 {
@@ -292,8 +299,8 @@ namespace LLama
                     yield return Context.TokenToString(id);
                 }
-                var breakGeneration = PostProcess(inferenceParams, args, out var extraOutputs);
-                if (extraOutputs is not null)
+                var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
+                if (extraOutputs is { Count: > 0 })
                 {
                     foreach (var item in extraOutputs)
                     {
@@ -307,21 +314,6 @@ namespace LLama
             }
         }
-        /// <summary>
-        /// Execute the inference asynchronously.
-        /// </summary>
-        /// <param name="text"></param>
-        /// <param name="inferenceParams"></param>
-        /// <param name="cancellationToken"></param>
-        /// <returns></returns>
-        public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
-        {
-            foreach (var result in Infer(text, inferenceParams, cancellationToken))
-            {
-                yield return result;
-            }
-        }
         /// <summary>
         /// State arguments that are used in single inference
         /// </summary>
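
All abstract members of the executor base class are now Task-returning, even where implementations remain synchronous; the derived executors below satisfy them with Task.CompletedTask / Task.FromResult and move real I/O to async BCL calls. A standalone sketch of that pattern (the SimpleState type, file name, and helper names are hypothetical, not from this diff):

    // Sketch of the conversion pattern used throughout this commit: members
    // that stay synchronous internally still expose Task-returning signatures,
    // while real I/O switches to async BCL calls (JsonSerializer.SerializeAsync).
    using System.IO;
    using System.Text.Json;
    using System.Threading.Tasks;

    public class SimpleState
    {
        public int PastTokensCount { get; set; }
    }

    public static class StatePersistenceSketch
    {
        public static async Task SaveStateAsync(SimpleState state, string filename)
        {
            using var fs = new FileStream(filename, FileMode.Create, FileAccess.Write);
            await JsonSerializer.SerializeAsync(fs, state);
        }

        public static Task<bool> GetLoopCondition(int remainedTokens)
        {
            // Synchronous check wrapped for the Task-based signature.
            return Task.FromResult(remainedTokens != 0);
        }
    }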

View File

@@ -5,9 +5,9 @@ using System;
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
-using System.Text;
 using System.Text.Json;
 using System.Text.Json.Serialization;
+using System.Threading.Tasks;
 using LLama.Extensions;
 namespace LLama
@@ -60,7 +60,7 @@ namespace LLama
             return state;
         }
         /// <inheritdoc />
-        public override void LoadState(ExecutorBaseState data)
+        public override Task LoadState(ExecutorBaseState data)
         {
             if(data is InstructExecutorState state)
             {
@@ -81,34 +81,37 @@ namespace LLama
             {
                 throw new ArgumentException("Invalid state data type.");
             }
+            return Task.CompletedTask;
         }
         /// <inheritdoc />
-        public override void SaveState(string filename)
+        public override async Task SaveState(string filename)
         {
             var state = (InstructExecutorState)GetStateData();
             using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
             {
-                JsonSerializer.Serialize(fs, state);
+                await JsonSerializer.SerializeAsync(fs, state);
             }
         }
         /// <inheritdoc />
-        public override void LoadState(string filename)
+        public override async Task LoadState(string filename)
        {
             using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
             {
-                var state = JsonSerializer.Deserialize<InstructExecutorState>(fs);
-                LoadState(state);
+                var state = await JsonSerializer.DeserializeAsync<InstructExecutorState>(fs);
+                await LoadState(state);
             }
         }
         /// <inheritdoc />
-        protected override bool GetLoopCondition(InferStateArgs args)
+        protected override Task<bool> GetLoopCondition(InferStateArgs args)
         {
-            return args.RemainedTokens != 0 || _is_prompt_run;
+            return Task.FromResult(args.RemainedTokens != 0 || _is_prompt_run);
         }
         /// <inheritdoc />
-        protected override void PreprocessInputs(string text, InferStateArgs args)
+        protected override Task PreprocessInputs(string text, InferStateArgs args)
         {
             args.Antiprompts ??= new List<string>();
             args.Antiprompts.Add(_instructionPrefix);
@@ -133,23 +136,24 @@ namespace LLama
                 args.RemainedTokens -= line_inp.Length;
             }
+            return Task.CompletedTask;
         }
         /// <inheritdoc />
-        protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
+        protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
         {
-            extraOutputs = null;
             if (_embed_inps.Count <= _consumedTokensCount)
             {
                 if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
                 {
                     args.WaitForInput = true;
-                    return true;
+                    return (true, Array.Empty<string>());
                 }
                 if (_pastTokensCount > 0 && args.WaitForInput)
                 {
-                    extraOutputs = new[] { "\n> " };
-                    return true;
+                    return (true, new[] { "\n> " });
                 }
             }
@@ -163,10 +167,11 @@ namespace LLama
                 args.RemainedTokens = inferenceParams.MaxTokens;
                 args.WaitForInput = true;
             }
-            return false;
+            return (false, Array.Empty<string>());
         }
         /// <inheritdoc />
-        protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
         {
             if (_embeds.Count > 0)
             {
@@ -230,6 +235,8 @@ namespace LLama
                 }
             }
         }
+            return Task.CompletedTask;
         }
         /// <summary>
         /// The desciptor of the state of the instruct executor.

View File

@@ -7,7 +7,7 @@ using System.IO;
 using System.Linq;
 using System.Text.Json;
 using System.Text.Json.Serialization;
-using System.Text;
+using System.Threading.Tasks;
 using LLama.Extensions;
 namespace LLama
@@ -51,7 +51,7 @@ namespace LLama
             return state;
         }
         /// <inheritdoc />
-        public override void LoadState(ExecutorBaseState data)
+        public override Task LoadState(ExecutorBaseState data)
         {
             if (data is InteractiveExecutorState state)
             {
@@ -68,23 +68,25 @@ namespace LLama
             }
             else
                 throw new ArgumentException("Invalid state data type.");
+            return Task.CompletedTask;
         }
         /// <inheritdoc />
-        public override void SaveState(string filename)
+        public override async Task SaveState(string filename)
         {
-            InteractiveExecutorState state = (InteractiveExecutorState)GetStateData();
-            using(FileStream fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
+            var state = (InteractiveExecutorState)GetStateData();
+            using(var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
             {
-                JsonSerializer.Serialize(fs, state);
+                await JsonSerializer.SerializeAsync(fs, state);
             }
         }
         /// <inheritdoc />
-        public override void LoadState(string filename)
+        public override async Task LoadState(string filename)
         {
-            using (FileStream fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
+            using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
             {
-                var state = JsonSerializer.Deserialize<InteractiveExecutorState>(fs);
-                LoadState(state);
+                var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
+                await LoadState(state);
             }
         }
@@ -92,13 +94,13 @@
         /// Define whether to continue the loop to generate responses.
         /// </summary>
         /// <returns></returns>
-        protected override bool GetLoopCondition(InferStateArgs args)
+        protected override Task<bool> GetLoopCondition(InferStateArgs args)
         {
-            return args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run;
+            return Task.FromResult(args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run);
         }
         /// <inheritdoc />
-        protected override void PreprocessInputs(string text, InferStateArgs args)
+        protected override Task PreprocessInputs(string text, InferStateArgs args)
         {
             if (_is_prompt_run)
             {
@@ -115,6 +117,8 @@
                 _embed_inps.AddRange(line_inp);
                 args.RemainedTokens -= line_inp.Length;
             }
+            return Task.CompletedTask;
         }
         /// <summary>
@@ -122,24 +126,21 @@
         /// </summary>
         /// <param name="inferenceParams"></param>
         /// <param name="args"></param>
-        /// <param name="extraOutputs"></param>
         /// <returns></returns>
-        protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
+        protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
         {
-            extraOutputs = null;
             if (_embed_inps.Count <= _consumedTokensCount)
            {
                 if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
                     args.WaitForInput = true;
                 if (_pastTokensCount > 0 && args.WaitForInput)
-                    return true;
+                    return (true, Array.Empty<string>());
             }
             if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
             {
-                extraOutputs = new[] { " [end of text]\n" };
-                return true;
+                return (true, new[] { " [end of text]\n" });
             }
             if (args.RemainedTokens <= 0 && inferenceParams.MaxTokens != -1)
@@ -147,11 +148,12 @@
                 args.RemainedTokens = inferenceParams.MaxTokens;
                 args.WaitForInput = true;
             }
-            return false;
+            return (false, Array.Empty<string>());
         }
         /// <inheritdoc />
-        protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
         {
             if (_embeds.Count > 0)
             {

View File

@@ -55,7 +55,7 @@ namespace LLama
         }
         /// <inheritdoc />
-        public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+        public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             using var context = _weights.CreateContext(_params);
             Context = context;
@@ -140,14 +140,5 @@ namespace LLama
         {
             return tokens.TokensEndsWithAnyString(antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding);
         }
-        /// <inheritdoc />
-        public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
-        {
-            foreach (var result in Infer(text, inferenceParams, cancellationToken))
-            {
-                yield return result;
-            }
-        }
     }
 }