Merge pull request #165 from martindevans/better_instruct_antiprompt_checking
better_instruct_antiprompt_checking
This commit is contained in:
commit
466722dcff
|
@ -9,6 +9,13 @@ namespace LLama.Extensions
|
|||
{
|
||||
internal static class IReadOnlyListExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Find the index of `item` in `list`
|
||||
/// </summary>
|
||||
/// <typeparam name="T"></typeparam>
|
||||
/// <param name="list">list to search</param>
|
||||
/// <param name="item">item to search for</param>
|
||||
/// <returns></returns>
|
||||
public static int? IndexOf<T>(this IReadOnlyList<T> list, T item)
|
||||
where T : IEquatable<T>
|
||||
{
|
||||
|
@ -61,6 +68,14 @@ namespace LLama.Extensions
|
|||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Check if the given set of tokens ends with any of the given strings
|
||||
/// </summary>
|
||||
/// <param name="tokens">Tokens to check</param>
|
||||
/// <param name="queries">Strings to search for</param>
|
||||
/// <param name="model">Model to use to convert tokens into bytes</param>
|
||||
/// <param name="encoding">Encoding to use to convert bytes into characters</param>
|
||||
/// <returns></returns>
|
||||
internal static bool TokensEndsWithAnyString<TTokens>(this TTokens tokens, IList<string>? queries, SafeLlamaModelHandle model, Encoding encoding)
|
||||
where TTokens : IReadOnlyList<int>
|
||||
{
|
||||
|
|
|
@ -489,6 +489,16 @@ namespace LLama
|
|||
return NativeHandle.TokenToString(token, Encoding);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Append a single token to a string builder
|
||||
/// </summary>
|
||||
/// <param name="token">Token to decode</param>
|
||||
/// <param name="dest">string builder to append the result to</param>
|
||||
public void TokenToString(llama_token token, StringBuilder dest)
|
||||
{
|
||||
NativeHandle.TokenToString(token, Encoding, dest);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public void Dispose()
|
||||
{
|
||||
|
|
|
@ -8,6 +8,7 @@ using System.Linq;
|
|||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Serialization;
|
||||
using LLama.Extensions;
|
||||
|
||||
namespace LLama
|
||||
{
|
||||
|
@ -139,21 +140,10 @@ namespace LLama
|
|||
extraOutputs = null;
|
||||
if (_embed_inps.Count <= _consumedTokensCount)
|
||||
{
|
||||
if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
|
||||
if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
|
||||
{
|
||||
var last_output_builder = new StringBuilder();
|
||||
foreach (var token in _last_n_tokens)
|
||||
Context.NativeHandle.TokenToString(token, Context.Encoding, last_output_builder);
|
||||
var last_output = last_output_builder.ToString();
|
||||
|
||||
foreach (var antiprompt in args.Antiprompts)
|
||||
{
|
||||
if (last_output.EndsWith(antiprompt))
|
||||
{
|
||||
args.WaitForInput = true;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
args.WaitForInput = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (_pastTokensCount > 0 && args.WaitForInput)
|
||||
|
|
Loading…
Reference in New Issue