Minimal changes required to remove non-async inference.

Martin Evans 2023-09-14 21:04:14 +01:00
parent b1e9d8240d
commit 3f80190f85
20 changed files with 108 additions and 181 deletions
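The pattern repeated across the examples, unit tests, and web service below is the same: public static void Run() entry points become public static async Task Run(), and each synchronous Infer/Chat loop becomes an await foreach over InferAsync/ChatAsync. A minimal caller-side sketch, assuming the usual LLamaSharp example setup (the model path and prompt are placeholders, not taken from this commit):

using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using LLama;
using LLama.Common;

public class AsyncInferenceExample
{
    public static async Task Run()
    {
        // Load the model and build an executor as the examples do (path is illustrative).
        var parameters = new ModelParams("path/to/model.gguf");
        using var weights = LLamaWeights.LoadFromFile(parameters);
        using var context = weights.CreateContext(parameters);
        var executor = new InteractiveExecutor(context);

        var inferenceParams = new InferenceParams
        {
            Temperature = 0.6f,
            AntiPrompts = new List<string> { "User:" }
        };

        // Before this commit: foreach (var text in executor.Infer(prompt, inferenceParams))
        // Now the only entry point streams asynchronously:
        await foreach (var text in executor.InferAsync("User: what is a cat?\nBot:", inferenceParams))
        {
            Console.Write(text);
        }
    }
}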

View File

@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
{
public class ChatSessionStripRoleName
{
public static void Run()
public static async Task Run()
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
Console.Write(prompt);
while (true)
{
foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
{
Console.Write(text);
}

View File

@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
{
public class ChatSessionWithRoleName
{
public static void Run()
public static async Task Run()
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
Console.Write(prompt);
while (true)
{
foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
{
Console.Write(text);
}

View File

@ -5,9 +5,9 @@ namespace LLama.Examples.NewVersion
{
public class GrammarJsonResponse
{
public static void Run()
public static async Task Run()
{
var gbnf = File.ReadAllText("Assets/json.gbnf").Trim();
var gbnf = (await File.ReadAllTextAsync("Assets/json.gbnf")).Trim();
var grammar = Grammar.Parse(gbnf, "root");
Console.Write("Please input your model path: ");
@ -43,7 +43,7 @@ namespace LLama.Examples.NewVersion
Console.ForegroundColor = ConsoleColor.White;
Console.Write("Answer: ");
prompt = $"Question: {prompt?.Trim()} Answer: ";
foreach (var text in ex.Infer(prompt, inferenceParams))
await foreach (var text in ex.InferAsync(prompt, inferenceParams))
{
Console.Write(text);
}

View File

@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
{
public class InstructModeExecute
{
public static void Run()
public static async Task Run()
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
@ -29,7 +29,7 @@ namespace LLama.Examples.NewVersion
while (true)
{
foreach (var text in executor.Infer(prompt, inferenceParams))
await foreach (var text in executor.InferAsync(prompt, inferenceParams))
{
Console.Write(text);
}

View File

@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
{
public class SaveAndLoadSession
{
public static void Run()
public static async Task Run()
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
Console.Write(prompt);
while (true)
{
foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
{
Console.Write(text);
}

View File

@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
{
public class LoadAndSaveState
{
public static void Run()
public static async Task Run()
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
while (true)
{
foreach (var text in ex.Infer(prompt, inferenceParams))
await foreach (var text in ex.InferAsync(prompt, inferenceParams))
{
Console.Write(text);
}

View File

@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
{
public class StatelessModeExecute
{
public static void Run()
public static async Task Run()
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
@ -35,7 +35,7 @@ namespace LLama.Examples.NewVersion
Console.ForegroundColor = ConsoleColor.White;
Console.Write("Answer: ");
prompt = $"Question: {prompt?.Trim()} Answer: ";
foreach (var text in ex.Infer(prompt, inferenceParams))
await foreach (var text in ex.InferAsync(prompt, inferenceParams))
{
Console.Write(text);
}

View File

@ -29,11 +29,11 @@
if (choice == 0)
{
ChatSessionWithRoleName.Run();
await ChatSessionWithRoleName.Run();
}
else if (choice == 1)
{
ChatSessionStripRoleName.Run();
await ChatSessionStripRoleName.Run();
}
else if(choice == 2)
{
@ -41,19 +41,19 @@
}
else if(choice == 3)
{
InstructModeExecute.Run();
await InstructModeExecute.Run();
}
else if(choice == 4)
{
StatelessModeExecute.Run();
await StatelessModeExecute.Run();
}
else if(choice == 5)
{
SaveAndLoadSession.Run();
await SaveAndLoadSession.Run();
}
else if(choice == 6)
{
LoadAndSaveState.Run();
await LoadAndSaveState.Run();
}
else if(choice == 7)
{
@ -69,7 +69,7 @@
}
else if (choice == 10)
{
GrammarJsonResponse.Run();
await GrammarJsonResponse.Run();
}
else if (choice == 11)
{

View File

@ -41,7 +41,7 @@ namespace LLama.Unittest
}
[Fact]
public void SampleWithTrivialGrammar()
public async Task SampleWithTrivialGrammar()
{
// Create a grammar that constrains the output to be "cat" and nothing else. This is a nonsense answer, so
// we can be confident it's not what the LLM would say if not constrained by the grammar!
@ -66,7 +66,7 @@ namespace LLama.Unittest
Grammar = grammar,
};
var result = executor.Infer("Q. 7 + 12\nA. ", inferenceParams).ToList();
var result = await executor.InferAsync("Q. 7 + 12\nA. ", inferenceParams).ToListAsync();
Assert.Equal("cat", result[0]);
}

View File

@ -12,6 +12,7 @@
<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.7.1" />
<PackageReference Include="System.Linq.Async" Version="6.0.1" />
<PackageReference Include="xunit" Version="2.5.0" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.5.0">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
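The System.Linq.Async package referenced above supplies the ToListAsync extension (in the System.Linq namespace), which is presumably why it is added here: it lets the tests materialise the IAsyncEnumerable<string> returned by InferAsync, as in

    var result = await executor.InferAsync("Q. 7 + 12\nA. ", inferenceParams).ToListAsync();
    Assert.Equal("cat", result[0]);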

View File

@ -27,15 +27,15 @@ namespace LLama.Unittest
}
[Fact]
public void Stateless()
public async Task Stateless()
{
var executor = new StatelessExecutor(_weights, _params);
const string question = "Question. what is a cat?\nAnswer: ";
var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };
var result1 = string.Join("", executor.Infer(question, @params));
var result2 = string.Join("", executor.Infer(question, @params));
var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
_testOutputHelper.WriteLine(result1);
@ -44,7 +44,7 @@ namespace LLama.Unittest
}
[Fact]
public void OutOfContext()
public async Task OutOfContext()
{
var executor = new StatelessExecutor(_weights, _params);
@ -58,8 +58,8 @@ namespace LLama.Unittest
TokensKeep = question.Length,
};
var result1 = string.Join("", executor.Infer(question, @params));
var result2 = string.Join("", executor.Infer(question, @params));
var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
_testOutputHelper.WriteLine(result1);

View File

@ -18,7 +18,7 @@ namespace LLama.WebAPI.Controllers
}
[HttpPost("Send")]
public string SendMessage([FromBody] SendMessageInput input, [FromServices] StatefulChatService _service)
public Task<string> SendMessage([FromBody] SendMessageInput input, [FromServices] StatefulChatService _service)
{
return _service.Send(input);
}

View File

@ -28,7 +28,7 @@ public class StatefulChatService : IDisposable
_context?.Dispose();
}
public string Send(SendMessageInput input)
public async Task<string> Send(SendMessageInput input)
{
var userInput = input.Text;
if (!_continue)
@ -42,13 +42,13 @@ public class StatefulChatService : IDisposable
Console.Write(input.Text);
Console.ForegroundColor = ConsoleColor.White;
var outputs = _session.Chat(userInput, new Common.InferenceParams()
var outputs = _session.ChatAsync(userInput, new Common.InferenceParams()
{
RepeatPenalty = 1.0f,
AntiPrompts = new string[] { "User:" },
});
var result = "";
foreach (var output in outputs)
await foreach (var output in outputs)
{
Console.Write(output);
result += output;

View File

@ -13,15 +13,6 @@ namespace LLama.Abstractions
/// </summary>
public LLamaContext Context { get; }
/// <summary>
/// Infers a response from the model.
/// </summary>
/// <param name="text">Your prompt</param>
/// <param name="inferenceParams">Any additional parameters</param>
/// <param name="token">A cancellation token.</param>
/// <returns></returns>
IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
/// <summary>
/// Asynchronously infers a response from the model.
/// </summary>
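With the synchronous Infer member gone, the executor abstraction keeps only the context property and the asynchronous streaming method. Roughly (a reconstruction from this hunk, not a verbatim quote; parameter names are assumed):

using System.Collections.Generic;
using System.Threading;

namespace LLama.Abstractions
{
    public interface ILLamaExecutor
    {
        /// <summary>The context used by the executor.</summary>
        LLamaContext Context { get; }

        /// <summary>Asynchronously infers a response from the model.</summary>
        IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
    }
}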

View File

@ -7,13 +7,6 @@ namespace LLama.Abstractions
/// </summary>
public interface ITextStreamTransform
{
/// <summary>
/// Takes a stream of tokens and transforms them, returning a new stream of tokens.
/// </summary>
/// <param name="tokens"></param>
/// <returns></returns>
IEnumerable<string> Transform(IEnumerable<string> tokens);
/// <summary>
/// Takes a stream of tokens and transforms them, returning a new stream of tokens asynchronously.
/// </summary>

View File

@ -134,26 +134,6 @@ namespace LLama
}
}
/// <summary>
/// Get the response from the LLama model with chat histories.
/// </summary>
/// <param name="history"></param>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public IEnumerable<string> Chat(ChatHistory history, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
{
var prompt = HistoryTransform.HistoryToText(history);
History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
StringBuilder sb = new();
foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
{
yield return result;
sb.Append(result);
}
History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
}
/// <summary>
/// Get the response from the LLama model. Note that the prompt need not be only the preset words;
/// it can also be the question you want to ask.
@ -162,15 +142,14 @@ namespace LLama
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public IEnumerable<string> Chat(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
foreach(var inputTransform in InputTransformPipeline)
{
prompt = inputTransform.Transform(prompt);
}
History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
StringBuilder sb = new();
foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
{
yield return result;
sb.Append(result);
@ -198,35 +177,6 @@ namespace LLama
History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
}
/// <summary>
/// Get the response from the LLama model with chat histories asynchronously.
/// </summary>
/// <param name="prompt"></param>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
foreach (var inputTransform in InputTransformPipeline)
{
prompt = inputTransform.Transform(prompt);
}
History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
StringBuilder sb = new();
await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
{
yield return result;
sb.Append(result);
}
History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
}
private IEnumerable<string> ChatInternal(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
{
var results = _executor.Infer(prompt, inferenceParams, cancellationToken);
return OutputTransform.Transform(results);
}
private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
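For consumers of ChatSession the net effect of this file is that the remaining chat entry points are asynchronous; a usage sketch (executor construction as in the examples, other session setup omitted):

    var session = new ChatSession(executor);

    await foreach (var token in session.ChatAsync(
        "User: hello\nBot:",
        new InferenceParams { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
    {
        Console.Write(token);
    }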

View File

@ -10,6 +10,7 @@ using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
namespace LLama
{
@ -212,47 +213,53 @@ namespace LLama
/// </summary>
/// <param name="args"></param>
/// <returns></returns>
protected abstract bool GetLoopCondition(InferStateArgs args);
protected abstract Task<bool> GetLoopCondition(InferStateArgs args);
/// <summary>
/// Preprocess the inputs before the inference.
/// </summary>
/// <param name="text"></param>
/// <param name="args"></param>
protected abstract void PreprocessInputs(string text, InferStateArgs args);
protected abstract Task PreprocessInputs(string text, InferStateArgs args);
/// <summary>
/// Do some post processing after the inference.
/// </summary>
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
/// <param name="extraOutputs"></param>
/// <returns></returns>
protected abstract bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs);
protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args);
/// <summary>
/// The core inference logic.
/// </summary>
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
protected abstract void InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
/// <summary>
/// Save the current state to a file.
/// </summary>
/// <param name="filename"></param>
public abstract void SaveState(string filename);
public abstract Task SaveState(string filename);
/// <summary>
/// Get the current state data.
/// </summary>
/// <returns></returns>
public abstract ExecutorBaseState GetStateData();
/// <summary>
/// Load the state from data.
/// </summary>
/// <param name="data"></param>
public abstract void LoadState(ExecutorBaseState data);
public abstract Task LoadState(ExecutorBaseState data);
/// <summary>
/// Load the state from a file.
/// </summary>
/// <param name="filename"></param>
public abstract void LoadState(string filename);
public abstract Task LoadState(string filename);
/// <summary>
@ -262,12 +269,12 @@ namespace LLama
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public virtual IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();
inferenceParams ??= new InferenceParams();
InferStateArgs args = new InferStateArgs()
var args = new InferStateArgs
{
Antiprompts = inferenceParams.AntiPrompts.ToList(),
RemainedTokens = inferenceParams.MaxTokens,
@ -276,15 +283,15 @@ namespace LLama
NeedToSaveSession = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count
};
PreprocessInputs(text, args);
await PreprocessInputs(text, args);
while (GetLoopCondition(args))
while (await GetLoopCondition(args))
{
if (cancellationToken.IsCancellationRequested)
{
break;
}
InferInternal(inferenceParams, args);
await InferInternal(inferenceParams, args);
if (args.ReturnValue)
{
@ -292,8 +299,8 @@ namespace LLama
yield return Context.TokenToString(id);
}
var breakGeneration = PostProcess(inferenceParams, args, out var extraOutputs);
if (extraOutputs is not null)
var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
if (extraOutputs is { Count: > 0 })
{
foreach (var item in extraOutputs)
{
@ -307,21 +314,6 @@ namespace LLama
}
}
/// <summary>
/// Execute the inference asynchronously.
/// </summary>
/// <param name="text"></param>
/// <param name="inferenceParams"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
foreach (var result in Infer(text, inferenceParams, cancellationToken))
{
yield return result;
}
}
/// <summary>
/// State arguments that are used in single inference
/// </summary>
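One detail in the hunk above is easy to miss: PostProcess can no longer report its extra output through an out parameter, presumably because the async overrides in the concrete executors cannot declare ref or out parameters, so the extra strings move into the returned tuple. A sketch of an override under the new shape (contents illustrative, not from this commit):

    protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
    {
        // Ask the caller to emit a prompt marker and stop this round of generation.
        if (args.WaitForInput)
            return Task.FromResult((true, (IReadOnlyList<string>)new[] { "\n> " }));

        // Keep generating; nothing extra to emit.
        return Task.FromResult((false, (IReadOnlyList<string>)Array.Empty<string>()));
    }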

View File

@ -5,9 +5,9 @@ using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading.Tasks;
using LLama.Extensions;
namespace LLama
@ -60,7 +60,7 @@ namespace LLama
return state;
}
/// <inheritdoc />
public override void LoadState(ExecutorBaseState data)
public override Task LoadState(ExecutorBaseState data)
{
if(data is InstructExecutorState state)
{
@ -81,34 +81,37 @@ namespace LLama
{
throw new ArgumentException("Invalid state data type.");
}
return Task.CompletedTask;
}
/// <inheritdoc />
public override void SaveState(string filename)
public override async Task SaveState(string filename)
{
var state = (InstructExecutorState)GetStateData();
using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
{
JsonSerializer.Serialize(fs, state);
await JsonSerializer.SerializeAsync(fs, state);
}
}
/// <inheritdoc />
public override void LoadState(string filename)
public override async Task LoadState(string filename)
{
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = JsonSerializer.Deserialize<InstructExecutorState>(fs);
LoadState(state);
var state = await JsonSerializer.DeserializeAsync<InstructExecutorState>(fs);
await LoadState(state);
}
}
/// <inheritdoc />
protected override bool GetLoopCondition(InferStateArgs args)
protected override Task<bool> GetLoopCondition(InferStateArgs args)
{
return args.RemainedTokens != 0 || _is_prompt_run;
return Task.FromResult(args.RemainedTokens != 0 || _is_prompt_run);
}
/// <inheritdoc />
protected override void PreprocessInputs(string text, InferStateArgs args)
protected override Task PreprocessInputs(string text, InferStateArgs args)
{
args.Antiprompts ??= new List<string>();
args.Antiprompts.Add(_instructionPrefix);
@ -133,23 +136,24 @@ namespace LLama
args.RemainedTokens -= line_inp.Length;
}
return Task.CompletedTask;
}
/// <inheritdoc />
protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
{
extraOutputs = null;
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
{
args.WaitForInput = true;
return true;
return (true, Array.Empty<string>());
}
if (_pastTokensCount > 0 && args.WaitForInput)
{
extraOutputs = new[] { "\n> " };
return true;
return (true, new[] { "\n> " });
}
}
@ -163,10 +167,11 @@ namespace LLama
args.RemainedTokens = inferenceParams.MaxTokens;
args.WaitForInput = true;
}
return false;
return (false, Array.Empty<string>());
}
/// <inheritdoc />
protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
{
if (_embeds.Count > 0)
{
@ -230,6 +235,8 @@ namespace LLama
}
}
}
return Task.CompletedTask;
}
/// <summary>
/// The descriptor of the state of the instruct executor.

View File

@ -7,7 +7,7 @@ using System.IO;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text;
using System.Threading.Tasks;
using LLama.Extensions;
namespace LLama
@ -51,7 +51,7 @@ namespace LLama
return state;
}
/// <inheritdoc />
public override void LoadState(ExecutorBaseState data)
public override Task LoadState(ExecutorBaseState data)
{
if (data is InteractiveExecutorState state)
{
@ -68,23 +68,25 @@ namespace LLama
}
else
throw new ArgumentException("Invalid state data type.");
return Task.CompletedTask;
}
/// <inheritdoc />
public override void SaveState(string filename)
public override async Task SaveState(string filename)
{
InteractiveExecutorState state = (InteractiveExecutorState)GetStateData();
using(FileStream fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
var state = (InteractiveExecutorState)GetStateData();
using(var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
{
JsonSerializer.Serialize(fs, state);
await JsonSerializer.SerializeAsync(fs, state);
}
}
/// <inheritdoc />
public override void LoadState(string filename)
public override async Task LoadState(string filename)
{
using (FileStream fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = JsonSerializer.Deserialize<InteractiveExecutorState>(fs);
LoadState(state);
var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
await LoadState(state);
}
}
@ -92,13 +94,13 @@ namespace LLama
/// Define whether to continue the loop to generate responses.
/// </summary>
/// <returns></returns>
protected override bool GetLoopCondition(InferStateArgs args)
protected override Task<bool> GetLoopCondition(InferStateArgs args)
{
return args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run;
return Task.FromResult(args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run);
}
/// <inheritdoc />
protected override void PreprocessInputs(string text, InferStateArgs args)
protected override Task PreprocessInputs(string text, InferStateArgs args)
{
if (_is_prompt_run)
{
@ -115,6 +117,8 @@ namespace LLama
_embed_inps.AddRange(line_inp);
args.RemainedTokens -= line_inp.Length;
}
return Task.CompletedTask;
}
/// <summary>
@ -122,24 +126,21 @@ namespace LLama
/// </summary>
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
/// <param name="extraOutputs"></param>
/// <returns></returns>
protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
{
extraOutputs = null;
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
args.WaitForInput = true;
if (_pastTokensCount > 0 && args.WaitForInput)
return true;
return (true, Array.Empty<string>());
}
if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
{
extraOutputs = new[] { " [end of text]\n" };
return true;
return (true, new[] { " [end of text]\n" });
}
if (args.RemainedTokens <= 0 && inferenceParams.MaxTokens != -1)
@ -147,11 +148,12 @@ namespace LLama
args.RemainedTokens = inferenceParams.MaxTokens;
args.WaitForInput = true;
}
return false;
return (false, Array.Empty<string>());
}
/// <inheritdoc />
protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
{
if (_embeds.Count > 0)
{

View File

@ -55,7 +55,7 @@ namespace LLama
}
/// <inheritdoc />
public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
using var context = _weights.CreateContext(_params);
Context = context;
@ -140,14 +140,5 @@ namespace LLama
{
return tokens.TokensEndsWithAnyString(antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding);
}
/// <inheritdoc />
public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
foreach (var result in Infer(text, inferenceParams, cancellationToken))
{
yield return result;
}
}
}
}