From d08a1250205398f41204fc1a29f5a22c9e1accb3 Mon Sep 17 00:00:00 2001
From: Martin Evans
Date: Mon, 11 Sep 2023 00:22:17 +0100
Subject: [PATCH] Using the `TokensEndsWithAnyString` extensions for antiprompt checking in instruct executor. Simpler and more efficient.

---
 LLama/Extensions/IReadOnlyListExtensions.cs | 15 +++++++++++++++
 LLama/LLamaContext.cs                       | 10 ++++++++++
 LLama/LLamaInstructExecutor.cs              | 18 ++++--------------
 3 files changed, 29 insertions(+), 14 deletions(-)

diff --git a/LLama/Extensions/IReadOnlyListExtensions.cs b/LLama/Extensions/IReadOnlyListExtensions.cs
index b07d90cf..131a8852 100644
--- a/LLama/Extensions/IReadOnlyListExtensions.cs
+++ b/LLama/Extensions/IReadOnlyListExtensions.cs
@@ -9,6 +9,13 @@ namespace LLama.Extensions
 {
     internal static class IReadOnlyListExtensions
     {
+        /// <summary>
+        /// Find the index of `item` in `list`
+        /// </summary>
+        /// <typeparam name="T"></typeparam>
+        /// <param name="list">list to search</param>
+        /// <param name="item">item to search for</param>
+        /// <returns></returns>
         public static int? IndexOf<T>(this IReadOnlyList<T> list, T item)
             where T : IEquatable<T>
         {
@@ -61,6 +68,14 @@ namespace LLama.Extensions
             }
         }
 
+        /// <summary>
+        /// Check if the given set of tokens ends with any of the given strings
+        /// </summary>
+        /// <param name="tokens">Tokens to check</param>
+        /// <param name="queries">Strings to search for</param>
+        /// <param name="model">Model to use to convert tokens into bytes</param>
+        /// <param name="encoding">Encoding to use to convert bytes into characters</param>
+        /// <returns></returns>
         internal static bool TokensEndsWithAnyString<TTokens>(this TTokens tokens, IList<string>? queries, SafeLlamaModelHandle model, Encoding encoding)
             where TTokens : IReadOnlyList<int>
         {
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 5b3853fc..7b7647e7 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -484,6 +484,16 @@ namespace LLama
             return NativeHandle.TokenToString(token, Encoding);
         }
 
+        /// <summary>
+        /// Append a single token to a string builder
+        /// </summary>
+        /// <param name="token">Token to decode</param>
+        /// <param name="dest">string builder to append the result to</param>
+        public void TokenToString(llama_token token, StringBuilder dest)
+        {
+            NativeHandle.TokenToString(token, Encoding, dest);
+        }
+
         /// <inheritdoc />
         public void Dispose()
         {
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 712c2c23..2d46728f 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -8,6 +8,7 @@ using System.Linq;
 using System.Text;
 using System.Text.Json;
 using System.Text.Json.Serialization;
+using LLama.Extensions;
 
 namespace LLama
 {
@@ -139,21 +140,10 @@ namespace LLama
             extraOutputs = null;
             if (_embed_inps.Count <= _consumedTokensCount)
             {
-                if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
+                if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
                 {
-                    var last_output_builder = new StringBuilder();
-                    foreach (var token in _last_n_tokens)
-                        Context.NativeHandle.TokenToString(token, Context.Encoding, last_output_builder);
-                    var last_output = last_output_builder.ToString();
-
-                    foreach (var antiprompt in args.Antiprompts)
-                    {
-                        if (last_output.EndsWith(antiprompt))
-                        {
-                            args.WaitForInput = true;
-                            return true;
-                        }
-                    }
+                    args.WaitForInput = true;
+                    return true;
                 }
 
                 if (_pastTokensCount > 0 && args.WaitForInput)