commit
9f9903c711
|
@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
|
|||
{
|
||||
public class ChatSessionStripRoleName
|
||||
{
|
||||
public static void Run()
|
||||
public static async Task Run()
|
||||
{
|
||||
Console.Write("Please input your model path: ");
|
||||
var modelPath = Console.ReadLine();
|
||||
|
@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
|
|||
Console.Write(prompt);
|
||||
while (true)
|
||||
{
|
||||
foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
|
||||
await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
|
||||
{
|
||||
Console.Write(text);
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
|
|||
{
|
||||
public class ChatSessionWithRoleName
|
||||
{
|
||||
public static void Run()
|
||||
public static async Task Run()
|
||||
{
|
||||
Console.Write("Please input your model path: ");
|
||||
var modelPath = Console.ReadLine();
|
||||
|
@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
|
|||
Console.Write(prompt);
|
||||
while (true)
|
||||
{
|
||||
foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
|
||||
await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
|
||||
{
|
||||
Console.Write(text);
|
||||
}
|
||||
|
|
|
@ -5,9 +5,9 @@ namespace LLama.Examples.NewVersion
|
|||
{
|
||||
public class GrammarJsonResponse
|
||||
{
|
||||
public static void Run()
|
||||
public static async Task Run()
|
||||
{
|
||||
var gbnf = File.ReadAllText("Assets/json.gbnf").Trim();
|
||||
var gbnf = (await File.ReadAllTextAsync("Assets/json.gbnf")).Trim();
|
||||
var grammar = Grammar.Parse(gbnf, "root");
|
||||
|
||||
Console.Write("Please input your model path: ");
|
||||
|
@ -43,7 +43,7 @@ namespace LLama.Examples.NewVersion
|
|||
Console.ForegroundColor = ConsoleColor.White;
|
||||
Console.Write("Answer: ");
|
||||
prompt = $"Question: {prompt?.Trim()} Answer: ";
|
||||
foreach (var text in ex.Infer(prompt, inferenceParams))
|
||||
await foreach (var text in ex.InferAsync(prompt, inferenceParams))
|
||||
{
|
||||
Console.Write(text);
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
|
|||
{
|
||||
public class InstructModeExecute
|
||||
{
|
||||
public static void Run()
|
||||
public static async Task Run()
|
||||
{
|
||||
Console.Write("Please input your model path: ");
|
||||
var modelPath = Console.ReadLine();
|
||||
|
@ -29,7 +29,7 @@ namespace LLama.Examples.NewVersion
|
|||
|
||||
while (true)
|
||||
{
|
||||
foreach (var text in executor.Infer(prompt, inferenceParams))
|
||||
await foreach (var text in executor.InferAsync(prompt, inferenceParams))
|
||||
{
|
||||
Console.Write(text);
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
|
|||
{
|
||||
public class SaveAndLoadSession
|
||||
{
|
||||
public static void Run()
|
||||
public static async Task Run()
|
||||
{
|
||||
Console.Write("Please input your model path: ");
|
||||
var modelPath = Console.ReadLine();
|
||||
|
@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
|
|||
Console.Write(prompt);
|
||||
while (true)
|
||||
{
|
||||
foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
|
||||
await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
|
||||
{
|
||||
Console.Write(text);
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
|
|||
{
|
||||
public class LoadAndSaveState
|
||||
{
|
||||
public static void Run()
|
||||
public static async Task Run()
|
||||
{
|
||||
Console.Write("Please input your model path: ");
|
||||
var modelPath = Console.ReadLine();
|
||||
|
@ -30,7 +30,7 @@ namespace LLama.Examples.NewVersion
|
|||
|
||||
while (true)
|
||||
{
|
||||
foreach (var text in ex.Infer(prompt, inferenceParams))
|
||||
await foreach (var text in ex.InferAsync(prompt, inferenceParams))
|
||||
{
|
||||
Console.Write(text);
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
|
|||
{
|
||||
public class StatelessModeExecute
|
||||
{
|
||||
public static void Run()
|
||||
public static async Task Run()
|
||||
{
|
||||
Console.Write("Please input your model path: ");
|
||||
var modelPath = Console.ReadLine();
|
||||
|
@ -35,7 +35,7 @@ namespace LLama.Examples.NewVersion
|
|||
Console.ForegroundColor = ConsoleColor.White;
|
||||
Console.Write("Answer: ");
|
||||
prompt = $"Question: {prompt?.Trim()} Answer: ";
|
||||
foreach (var text in ex.Infer(prompt, inferenceParams))
|
||||
await foreach (var text in ex.InferAsync(prompt, inferenceParams))
|
||||
{
|
||||
Console.Write(text);
|
||||
}
|
||||
|
|
|
@ -30,11 +30,11 @@
|
|||
|
||||
if (choice == 0)
|
||||
{
|
||||
ChatSessionWithRoleName.Run();
|
||||
await ChatSessionWithRoleName.Run();
|
||||
}
|
||||
else if (choice == 1)
|
||||
{
|
||||
ChatSessionStripRoleName.Run();
|
||||
await ChatSessionStripRoleName.Run();
|
||||
}
|
||||
else if(choice == 2)
|
||||
{
|
||||
|
@ -42,19 +42,19 @@
|
|||
}
|
||||
else if(choice == 3)
|
||||
{
|
||||
InstructModeExecute.Run();
|
||||
await InstructModeExecute.Run();
|
||||
}
|
||||
else if(choice == 4)
|
||||
{
|
||||
StatelessModeExecute.Run();
|
||||
await StatelessModeExecute.Run();
|
||||
}
|
||||
else if(choice == 5)
|
||||
{
|
||||
SaveAndLoadSession.Run();
|
||||
await SaveAndLoadSession.Run();
|
||||
}
|
||||
else if(choice == 6)
|
||||
{
|
||||
LoadAndSaveState.Run();
|
||||
await LoadAndSaveState.Run();
|
||||
}
|
||||
else if(choice == 7)
|
||||
{
|
||||
|
@ -70,7 +70,7 @@
|
|||
}
|
||||
else if (choice == 10)
|
||||
{
|
||||
GrammarJsonResponse.Run();
|
||||
await GrammarJsonResponse.Run();
|
||||
}
|
||||
else if (choice == 11)
|
||||
{
|
||||
|
|
|
@ -41,7 +41,7 @@ namespace LLama.Unittest
|
|||
}
|
||||
|
||||
[Fact]
|
||||
public void SampleWithTrivialGrammar()
|
||||
public async Task SampleWithTrivialGrammar()
|
||||
{
|
||||
// Create a grammar that constrains the output to be "cat" and nothing else. This is a nonsense answer, so
|
||||
// we can be confident it's not what the LLM would say if not constrained by the grammar!
|
||||
|
@ -66,7 +66,7 @@ namespace LLama.Unittest
|
|||
Grammar = grammar,
|
||||
};
|
||||
|
||||
var result = executor.Infer("Q. 7 + 12\nA. ", inferenceParams).ToList();
|
||||
var result = await executor.InferAsync("Q. 7 + 12\nA. ", inferenceParams).ToListAsync();
|
||||
|
||||
Assert.Equal("cat", result[0]);
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.7.1" />
|
||||
<PackageReference Include="System.Linq.Async" Version="6.0.1" />
|
||||
<PackageReference Include="xunit" Version="2.5.0" />
|
||||
<PackageReference Include="xunit.runner.visualstudio" Version="2.5.0">
|
||||
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
|
||||
|
|
|
@ -27,15 +27,15 @@ namespace LLama.Unittest
|
|||
}
|
||||
|
||||
[Fact]
|
||||
public void Stateless()
|
||||
public async Task Stateless()
|
||||
{
|
||||
var executor = new StatelessExecutor(_weights, _params);
|
||||
|
||||
const string question = "Question. what is a cat?\nAnswer: ";
|
||||
var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };
|
||||
|
||||
var result1 = string.Join("", executor.Infer(question, @params));
|
||||
var result2 = string.Join("", executor.Infer(question, @params));
|
||||
var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
|
||||
var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
|
||||
|
||||
_testOutputHelper.WriteLine(result1);
|
||||
|
||||
|
@ -44,7 +44,7 @@ namespace LLama.Unittest
|
|||
}
|
||||
|
||||
[Fact]
|
||||
public void OutOfContext()
|
||||
public async Task OutOfContext()
|
||||
{
|
||||
var executor = new StatelessExecutor(_weights, _params);
|
||||
|
||||
|
@ -58,8 +58,8 @@ namespace LLama.Unittest
|
|||
TokensKeep = question.Length,
|
||||
};
|
||||
|
||||
var result1 = string.Join("", executor.Infer(question, @params));
|
||||
var result2 = string.Join("", executor.Infer(question, @params));
|
||||
var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
|
||||
var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
|
||||
|
||||
_testOutputHelper.WriteLine(result1);
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ namespace LLama.WebAPI.Controllers
|
|||
}
|
||||
|
||||
[HttpPost("Send")]
|
||||
public string SendMessage([FromBody] SendMessageInput input, [FromServices] StatefulChatService _service)
|
||||
public Task<string> SendMessage([FromBody] SendMessageInput input, [FromServices] StatefulChatService _service)
|
||||
{
|
||||
return _service.Send(input);
|
||||
}
|
||||
|
|
|
@ -28,7 +28,7 @@ public class StatefulChatService : IDisposable
|
|||
_context?.Dispose();
|
||||
}
|
||||
|
||||
public string Send(SendMessageInput input)
|
||||
public async Task<string> Send(SendMessageInput input)
|
||||
{
|
||||
var userInput = input.Text;
|
||||
if (!_continue)
|
||||
|
@ -42,13 +42,13 @@ public class StatefulChatService : IDisposable
|
|||
Console.Write(input.Text);
|
||||
|
||||
Console.ForegroundColor = ConsoleColor.White;
|
||||
var outputs = _session.Chat(userInput, new Common.InferenceParams()
|
||||
var outputs = _session.ChatAsync(userInput, new Common.InferenceParams()
|
||||
{
|
||||
RepeatPenalty = 1.0f,
|
||||
AntiPrompts = new string[] { "User:" },
|
||||
});
|
||||
var result = "";
|
||||
foreach (var output in outputs)
|
||||
await foreach (var output in outputs)
|
||||
{
|
||||
Console.Write(output);
|
||||
result += output;
|
||||
|
|
|
@ -13,15 +13,6 @@ namespace LLama.Abstractions
|
|||
/// </summary>
|
||||
public LLamaContext Context { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Infers a response from the model.
|
||||
/// </summary>
|
||||
/// <param name="text">Your prompt</param>
|
||||
/// <param name="inferenceParams">Any additional parameters</param>
|
||||
/// <param name="token">A cancellation token.</param>
|
||||
/// <returns></returns>
|
||||
IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
|
||||
|
||||
/// <summary>
|
||||
/// Asynchronously infers a response from the model.
|
||||
/// </summary>
|
||||
|
|
|
@ -7,13 +7,6 @@ namespace LLama.Abstractions
|
|||
/// </summary>
|
||||
public interface ITextStreamTransform
|
||||
{
|
||||
/// <summary>
|
||||
/// Takes a stream of tokens and transforms them, returning a new stream of tokens.
|
||||
/// </summary>
|
||||
/// <param name="tokens"></param>
|
||||
/// <returns></returns>
|
||||
IEnumerable<string> Transform(IEnumerable<string> tokens);
|
||||
|
||||
/// <summary>
|
||||
/// Takes a stream of tokens and transforms them, returning a new stream of tokens asynchronously.
|
||||
/// </summary>
|
||||
|
|
|
@ -134,26 +134,6 @@ namespace LLama
|
|||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get the response from the LLama model with chat histories.
|
||||
/// </summary>
|
||||
/// <param name="history"></param>
|
||||
/// <param name="inferenceParams"></param>
|
||||
/// <param name="cancellationToken"></param>
|
||||
/// <returns></returns>
|
||||
public IEnumerable<string> Chat(ChatHistory history, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var prompt = HistoryTransform.HistoryToText(history);
|
||||
History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
|
||||
StringBuilder sb = new();
|
||||
foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
|
||||
{
|
||||
yield return result;
|
||||
sb.Append(result);
|
||||
}
|
||||
History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get the response from the LLama model. Note that prompt could not only be the preset words,
|
||||
/// but also the question you want to ask.
|
||||
|
@ -162,15 +142,14 @@ namespace LLama
|
|||
/// <param name="inferenceParams"></param>
|
||||
/// <param name="cancellationToken"></param>
|
||||
/// <returns></returns>
|
||||
public IEnumerable<string> Chat(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
|
||||
public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
foreach(var inputTransform in InputTransformPipeline)
|
||||
{
|
||||
prompt = inputTransform.Transform(prompt);
|
||||
}
|
||||
|
||||
History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
|
||||
StringBuilder sb = new();
|
||||
foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
|
||||
await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
|
||||
{
|
||||
yield return result;
|
||||
sb.Append(result);
|
||||
|
@ -198,35 +177,6 @@ namespace LLama
|
|||
History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get the response from the LLama model with chat histories asynchronously.
|
||||
/// </summary>
|
||||
/// <param name="prompt"></param>
|
||||
/// <param name="inferenceParams"></param>
|
||||
/// <param name="cancellationToken"></param>
|
||||
/// <returns></returns>
|
||||
public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
foreach (var inputTransform in InputTransformPipeline)
|
||||
{
|
||||
prompt = inputTransform.Transform(prompt);
|
||||
}
|
||||
History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
|
||||
StringBuilder sb = new();
|
||||
await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
|
||||
{
|
||||
yield return result;
|
||||
sb.Append(result);
|
||||
}
|
||||
History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
|
||||
}
|
||||
|
||||
private IEnumerable<string> ChatInternal(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var results = _executor.Infer(prompt, inferenceParams, cancellationToken);
|
||||
return OutputTransform.Transform(results);
|
||||
}
|
||||
|
||||
private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
|
||||
|
|
|
@ -10,6 +10,7 @@ using System.Linq;
|
|||
using System.Runtime.CompilerServices;
|
||||
using System.Text.Json.Serialization;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace LLama
|
||||
{
|
||||
|
@ -212,47 +213,53 @@ namespace LLama
|
|||
/// </summary>
|
||||
/// <param name="args"></param>
|
||||
/// <returns></returns>
|
||||
protected abstract bool GetLoopCondition(InferStateArgs args);
|
||||
protected abstract Task<bool> GetLoopCondition(InferStateArgs args);
|
||||
|
||||
/// <summary>
|
||||
/// Preprocess the inputs before the inference.
|
||||
/// </summary>
|
||||
/// <param name="text"></param>
|
||||
/// <param name="args"></param>
|
||||
protected abstract void PreprocessInputs(string text, InferStateArgs args);
|
||||
protected abstract Task PreprocessInputs(string text, InferStateArgs args);
|
||||
|
||||
/// <summary>
|
||||
/// Do some post processing after the inference.
|
||||
/// </summary>
|
||||
/// <param name="inferenceParams"></param>
|
||||
/// <param name="args"></param>
|
||||
/// <param name="extraOutputs"></param>
|
||||
/// <returns></returns>
|
||||
protected abstract bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs);
|
||||
protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args);
|
||||
|
||||
/// <summary>
|
||||
/// The core inference logic.
|
||||
/// </summary>
|
||||
/// <param name="inferenceParams"></param>
|
||||
/// <param name="args"></param>
|
||||
protected abstract void InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
|
||||
protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
|
||||
|
||||
/// <summary>
|
||||
/// Save the current state to a file.
|
||||
/// </summary>
|
||||
/// <param name="filename"></param>
|
||||
public abstract void SaveState(string filename);
|
||||
public abstract Task SaveState(string filename);
|
||||
|
||||
/// <summary>
|
||||
/// Get the current state data.
|
||||
/// </summary>
|
||||
/// <returns></returns>
|
||||
public abstract ExecutorBaseState GetStateData();
|
||||
|
||||
/// <summary>
|
||||
/// Load the state from data.
|
||||
/// </summary>
|
||||
/// <param name="data"></param>
|
||||
public abstract void LoadState(ExecutorBaseState data);
|
||||
public abstract Task LoadState(ExecutorBaseState data);
|
||||
|
||||
/// <summary>
|
||||
/// Load the state from a file.
|
||||
/// </summary>
|
||||
/// <param name="filename"></param>
|
||||
public abstract void LoadState(string filename);
|
||||
public abstract Task LoadState(string filename);
|
||||
|
||||
|
||||
/// <summary>
|
||||
|
@ -262,12 +269,12 @@ namespace LLama
|
|||
/// <param name="inferenceParams"></param>
|
||||
/// <param name="cancellationToken"></param>
|
||||
/// <returns></returns>
|
||||
public virtual IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
|
||||
public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
cancellationToken.ThrowIfCancellationRequested();
|
||||
inferenceParams ??= new InferenceParams();
|
||||
|
||||
InferStateArgs args = new InferStateArgs()
|
||||
var args = new InferStateArgs
|
||||
{
|
||||
Antiprompts = inferenceParams.AntiPrompts.ToList(),
|
||||
RemainedTokens = inferenceParams.MaxTokens,
|
||||
|
@ -276,15 +283,15 @@ namespace LLama
|
|||
NeedToSaveSession = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count
|
||||
};
|
||||
|
||||
PreprocessInputs(text, args);
|
||||
await PreprocessInputs(text, args);
|
||||
|
||||
while (GetLoopCondition(args))
|
||||
while (await GetLoopCondition(args))
|
||||
{
|
||||
if (cancellationToken.IsCancellationRequested)
|
||||
{
|
||||
break;
|
||||
}
|
||||
InferInternal(inferenceParams, args);
|
||||
await InferInternal(inferenceParams, args);
|
||||
|
||||
if (args.ReturnValue)
|
||||
{
|
||||
|
@ -292,8 +299,8 @@ namespace LLama
|
|||
yield return Context.TokenToString(id);
|
||||
}
|
||||
|
||||
var breakGeneration = PostProcess(inferenceParams, args, out var extraOutputs);
|
||||
if (extraOutputs is not null)
|
||||
var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
|
||||
if (extraOutputs is { Count: > 0 })
|
||||
{
|
||||
foreach (var item in extraOutputs)
|
||||
{
|
||||
|
@ -307,21 +314,6 @@ namespace LLama
|
|||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Execute the inference asynchronously.
|
||||
/// </summary>
|
||||
/// <param name="text"></param>
|
||||
/// <param name="inferenceParams"></param>
|
||||
/// <param name="cancellationToken"></param>
|
||||
/// <returns></returns>
|
||||
public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
foreach (var result in Infer(text, inferenceParams, cancellationToken))
|
||||
{
|
||||
yield return result;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// State arguments that are used in single inference
|
||||
/// </summary>
|
||||
|
|
|
@ -5,9 +5,9 @@ using System;
|
|||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Serialization;
|
||||
using System.Threading.Tasks;
|
||||
using LLama.Extensions;
|
||||
|
||||
namespace LLama
|
||||
|
@ -60,7 +60,7 @@ namespace LLama
|
|||
return state;
|
||||
}
|
||||
/// <inheritdoc />
|
||||
public override void LoadState(ExecutorBaseState data)
|
||||
public override Task LoadState(ExecutorBaseState data)
|
||||
{
|
||||
if(data is InstructExecutorState state)
|
||||
{
|
||||
|
@ -81,34 +81,37 @@ namespace LLama
|
|||
{
|
||||
throw new ArgumentException("Invalid state data type.");
|
||||
}
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public override void SaveState(string filename)
|
||||
public override async Task SaveState(string filename)
|
||||
{
|
||||
var state = (InstructExecutorState)GetStateData();
|
||||
using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
|
||||
{
|
||||
JsonSerializer.Serialize(fs, state);
|
||||
await JsonSerializer.SerializeAsync(fs, state);
|
||||
}
|
||||
}
|
||||
/// <inheritdoc />
|
||||
public override void LoadState(string filename)
|
||||
public override async Task LoadState(string filename)
|
||||
{
|
||||
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
|
||||
{
|
||||
var state = JsonSerializer.Deserialize<InstructExecutorState>(fs);
|
||||
LoadState(state);
|
||||
var state = await JsonSerializer.DeserializeAsync<InstructExecutorState>(fs);
|
||||
await LoadState(state);
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override bool GetLoopCondition(InferStateArgs args)
|
||||
protected override Task<bool> GetLoopCondition(InferStateArgs args)
|
||||
{
|
||||
return args.RemainedTokens != 0 || _is_prompt_run;
|
||||
return Task.FromResult(args.RemainedTokens != 0 || _is_prompt_run);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override void PreprocessInputs(string text, InferStateArgs args)
|
||||
protected override Task PreprocessInputs(string text, InferStateArgs args)
|
||||
{
|
||||
args.Antiprompts ??= new List<string>();
|
||||
args.Antiprompts.Add(_instructionPrefix);
|
||||
|
@ -133,23 +136,24 @@ namespace LLama
|
|||
|
||||
args.RemainedTokens -= line_inp.Length;
|
||||
}
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
|
||||
protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
|
||||
{
|
||||
extraOutputs = null;
|
||||
if (_embed_inps.Count <= _consumedTokensCount)
|
||||
{
|
||||
if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
|
||||
{
|
||||
args.WaitForInput = true;
|
||||
return true;
|
||||
return (true, Array.Empty<string>());
|
||||
}
|
||||
|
||||
if (_pastTokensCount > 0 && args.WaitForInput)
|
||||
{
|
||||
extraOutputs = new[] { "\n> " };
|
||||
return true;
|
||||
return (true, new[] { "\n> " });
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -163,10 +167,11 @@ namespace LLama
|
|||
args.RemainedTokens = inferenceParams.MaxTokens;
|
||||
args.WaitForInput = true;
|
||||
}
|
||||
return false;
|
||||
return (false, Array.Empty<string>());
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
|
||||
protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
|
||||
{
|
||||
if (_embeds.Count > 0)
|
||||
{
|
||||
|
@ -230,6 +235,8 @@ namespace LLama
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
/// <summary>
|
||||
/// The desciptor of the state of the instruct executor.
|
||||
|
|
|
@ -7,7 +7,7 @@ using System.IO;
|
|||
using System.Linq;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Serialization;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using LLama.Extensions;
|
||||
|
||||
namespace LLama
|
||||
|
@ -51,7 +51,7 @@ namespace LLama
|
|||
return state;
|
||||
}
|
||||
/// <inheritdoc />
|
||||
public override void LoadState(ExecutorBaseState data)
|
||||
public override Task LoadState(ExecutorBaseState data)
|
||||
{
|
||||
if (data is InteractiveExecutorState state)
|
||||
{
|
||||
|
@ -68,23 +68,25 @@ namespace LLama
|
|||
}
|
||||
else
|
||||
throw new ArgumentException("Invalid state data type.");
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
/// <inheritdoc />
|
||||
public override void SaveState(string filename)
|
||||
public override async Task SaveState(string filename)
|
||||
{
|
||||
InteractiveExecutorState state = (InteractiveExecutorState)GetStateData();
|
||||
using(FileStream fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
|
||||
var state = (InteractiveExecutorState)GetStateData();
|
||||
using(var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
|
||||
{
|
||||
JsonSerializer.Serialize(fs, state);
|
||||
await JsonSerializer.SerializeAsync(fs, state);
|
||||
}
|
||||
}
|
||||
/// <inheritdoc />
|
||||
public override void LoadState(string filename)
|
||||
public override async Task LoadState(string filename)
|
||||
{
|
||||
using (FileStream fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
|
||||
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
|
||||
{
|
||||
var state = JsonSerializer.Deserialize<InteractiveExecutorState>(fs);
|
||||
LoadState(state);
|
||||
var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
|
||||
await LoadState(state);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -92,13 +94,13 @@ namespace LLama
|
|||
/// Define whether to continue the loop to generate responses.
|
||||
/// </summary>
|
||||
/// <returns></returns>
|
||||
protected override bool GetLoopCondition(InferStateArgs args)
|
||||
protected override Task<bool> GetLoopCondition(InferStateArgs args)
|
||||
{
|
||||
return args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run;
|
||||
return Task.FromResult(args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override void PreprocessInputs(string text, InferStateArgs args)
|
||||
protected override Task PreprocessInputs(string text, InferStateArgs args)
|
||||
{
|
||||
if (_is_prompt_run)
|
||||
{
|
||||
|
@ -115,6 +117,8 @@ namespace LLama
|
|||
_embed_inps.AddRange(line_inp);
|
||||
args.RemainedTokens -= line_inp.Length;
|
||||
}
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -122,24 +126,21 @@ namespace LLama
|
|||
/// </summary>
|
||||
/// <param name="inferenceParams"></param>
|
||||
/// <param name="args"></param>
|
||||
/// <param name="extraOutputs"></param>
|
||||
/// <returns></returns>
|
||||
protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
|
||||
protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
|
||||
{
|
||||
extraOutputs = null;
|
||||
if (_embed_inps.Count <= _consumedTokensCount)
|
||||
{
|
||||
if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
|
||||
args.WaitForInput = true;
|
||||
|
||||
if (_pastTokensCount > 0 && args.WaitForInput)
|
||||
return true;
|
||||
return (true, Array.Empty<string>());
|
||||
}
|
||||
|
||||
if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
|
||||
{
|
||||
extraOutputs = new[] { " [end of text]\n" };
|
||||
return true;
|
||||
return (true, new[] { " [end of text]\n" });
|
||||
}
|
||||
|
||||
if (args.RemainedTokens <= 0 && inferenceParams.MaxTokens != -1)
|
||||
|
@ -147,11 +148,12 @@ namespace LLama
|
|||
args.RemainedTokens = inferenceParams.MaxTokens;
|
||||
args.WaitForInput = true;
|
||||
}
|
||||
return false;
|
||||
|
||||
return (false, Array.Empty<string>());
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
|
||||
protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
|
||||
{
|
||||
if (_embeds.Count > 0)
|
||||
{
|
||||
|
|
|
@ -55,7 +55,7 @@ namespace LLama
|
|||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
|
||||
public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
using var context = _weights.CreateContext(_params);
|
||||
Context = context;
|
||||
|
@ -140,14 +140,5 @@ namespace LLama
|
|||
{
|
||||
return tokens.TokensEndsWithAnyString(antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
foreach (var result in Infer(text, inferenceParams, cancellationToken))
|
||||
{
|
||||
yield return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue