Merge pull request #165 from martindevans/better_instruct_antiprompt_checking

better_instruct_antiprompt_checking
This commit is contained in:
Martin Evans 2023-09-11 00:32:43 +01:00 committed by GitHub
commit 466722dcff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 29 additions and 14 deletions

View File

@ -9,6 +9,13 @@ namespace LLama.Extensions
{
internal static class IReadOnlyListExtensions
{
/// <summary>
/// Find the index of `item` in `list`
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="list">list to search</param>
/// <param name="item">item to search for</param>
/// <returns></returns>
public static int? IndexOf<T>(this IReadOnlyList<T> list, T item)
where T : IEquatable<T>
{
@ -61,6 +68,14 @@ namespace LLama.Extensions
}
}
/// <summary>
/// Check if the given set of tokens ends with any of the given strings
/// </summary>
/// <param name="tokens">Tokens to check</param>
/// <param name="queries">Strings to search for</param>
/// <param name="model">Model to use to convert tokens into bytes</param>
/// <param name="encoding">Encoding to use to convert bytes into characters</param>
/// <returns></returns>
internal static bool TokensEndsWithAnyString<TTokens>(this TTokens tokens, IList<string>? queries, SafeLlamaModelHandle model, Encoding encoding)
where TTokens : IReadOnlyList<int>
{

View File

@ -489,6 +489,16 @@ namespace LLama
return NativeHandle.TokenToString(token, Encoding);
}
/// <summary>
/// Decode one token and append the result to an existing string builder.
/// Forwards to <c>NativeHandle.TokenToString</c>, passing this instance's <c>Encoding</c>.
/// </summary>
/// <param name="token">The token to decode</param>
/// <param name="dest">The string builder that receives the decoded result</param>
public void TokenToString(llama_token token, StringBuilder dest)
    => NativeHandle.TokenToString(token, Encoding, dest);
/// <inheritdoc />
public void Dispose()
{

View File

@ -8,6 +8,7 @@ using System.Linq;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using LLama.Extensions;
namespace LLama
{
@ -139,21 +140,10 @@ namespace LLama
extraOutputs = null;
if (_embed_inps.Count <= _consumedTokensCount)
{
if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
{
var last_output_builder = new StringBuilder();
foreach (var token in _last_n_tokens)
Context.NativeHandle.TokenToString(token, Context.Encoding, last_output_builder);
var last_output = last_output_builder.ToString();
foreach (var antiprompt in args.Antiprompts)
{
if (last_output.EndsWith(antiprompt))
{
args.WaitForInput = true;
return true;
}
}
args.WaitForInput = true;
return true;
}
if (_pastTokensCount > 0 && args.WaitForInput)