Added two new classes for detokenization tasks:

- `AntipromptProcessor` accepts chunks of text and returns a value indicating if any antiprompt has been detected.
- `StreamingTokenDecoder` decodes tokens into text, maintaining some internal state to handle single characters that are encoded as multiple tokens.

Added tests for these classes and updated StatelessExecutor to use them.

Removed most DeTokenize methods, marked the rest as obsolete (should always use a `StreamingTokenDecoder`).
This commit is contained in:
Martin Evans 2023-10-23 00:33:50 +01:00
parent efdf3d630c
commit 51d4411a58
10 changed files with 387 additions and 105 deletions

View File

@ -32,13 +32,14 @@ public sealed class BeamTests
{
const int num_beams = 2;
const int n_predict = 3;
const string prompt = "The cat sat on";
var context = _model.CreateContext(_params);
var result = new StringBuilder();
var initial_tokens = context.Tokenize("The cat sat on");
result.Append(context.DeTokenize(initial_tokens.ToArray()));
var initial_tokens = context.Tokenize(prompt);
result.Append(prompt);
context.Eval(initial_tokens, 0);
NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>

View File

@ -73,6 +73,72 @@ public sealed class TokenTests
Assert.False(result);
}
[Fact]
public void TokensEndWith2()
{
    // One of the antiprompts ("the mat") is exactly how the decoded text ends
    var antiprompts = new[]
    {
        "a fish",
        "the mat",
        "this is an improbably long query to be using for this method"
    };
    var processor = new AntipromptProcessor(antiprompts);

    var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model);
    decoder.AddRange(_model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8));

    Assert.True(processor.Add(decoder.Read()));
}
[Fact]
public void TokensEndSubstring2()
{
    // "at" is a suffix of the final word "mat", so the check should still trigger
    var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model);
    decoder.AddRange(_model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8));

    var processor = new AntipromptProcessor(new[] { "at" });

    Assert.True(processor.Add(decoder.Read()));
}
[Fact]
public void TokensNotEndWith2()
{
    // None of these antiprompts are a suffix of the decoded text (one is a strict prefix of it)
    var antiprompts = new[]
    {
        "a fish",
        "The cat sat on the edge of the ma",
        "this is an improbably long query to be using for this method"
    };
    var processor = new AntipromptProcessor(antiprompts);

    var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model);
    decoder.AddRange(_model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8));

    Assert.False(processor.Add(decoder.Read()));
}
[Fact]
public void TokensNotEndWithNothing2()
{
    // With no antiprompts configured the processor can never report a match
    var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model);
    decoder.AddRange(_model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8));

    var processor = new AntipromptProcessor();

    Assert.False(processor.Add(decoder.Read()));
}
[Fact]
public void RoundTrip()
{
@ -80,7 +146,7 @@ public sealed class TokenTests
{
"Hello world",
"철수",
"😀 😃 😄 😁 😆 😅 😂 😊 😇 🙂 ",
"😀 😃 😄 😁 😆철수😅 😂 😊 😇 🙂 ",
};
var charsArr = new char[1024];
@ -99,7 +165,36 @@ public sealed class TokenTests
// Check that the input equals the output
Assert.Equal(input, output);
}
}
[Fact]
public void StreamingDecoderRoundTrip()
{
    var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model);

    // Inputs deliberately include multi-byte characters and emoji, which may each
    // be split across several tokens
    var inputs = new[]
    {
        "Hello world",
        "철수",
        "😀 😃 😄 😁 😆철수😅 😂 😊 😇 🙂 ",
    };

    for (var i = 0; i < inputs.Length; i++)
    {
        var text = inputs[i];
        decoder.Reset();

        // Convert into llama tokens and feed each one into the decoder
        var tokenized = _model.NativeHandle.Tokenize(text, false, false, Encoding.UTF8);
        foreach (var token in tokenized)
            decoder.Add(token);

        // llama.cpp adds a space to the start of strings, remove that
        var roundTripped = decoder.Read().TrimStart(' ');

        // Check that the input equals the output
        Assert.Equal(text, roundTripped);
    }
}
}

View File

@ -0,0 +1,66 @@
using System;
using System.Collections.Generic;
namespace LLama;
/// <summary>
/// Watches a growing stream of text and reports when it ends with any of a set of antiprompt strings.
/// </summary>
internal sealed class AntipromptProcessor
{
    // All antiprompts currently being watched for
    private readonly List<string> _antiprompts = new();

    // Length of the longest antiprompt, used to decide how much text must be retained
    private int _longestAntiprompt;

    // Accumulated text, periodically trimmed so it never grows without bound
    private string? _string;

    /// <summary>
    /// Create a processor, optionally seeded with an initial set of antiprompts
    /// </summary>
    /// <param name="antiprompts">Initial antiprompts to watch for (may be null)</param>
    public AntipromptProcessor(IEnumerable<string>? antiprompts = null)
    {
        if (antiprompts is not null)
            SetAntiprompts(antiprompts);
    }

    /// <summary>
    /// Add an antiprompt to the collection
    /// </summary>
    /// <param name="antiprompt"></param>
    public void AddAntiprompt(string antiprompt)
    {
        _antiprompts.Add(antiprompt);
        if (antiprompt.Length > _longestAntiprompt)
            _longestAntiprompt = antiprompt.Length;
    }

    /// <summary>
    /// Overwrite all current antiprompts with a new set
    /// </summary>
    /// <param name="antiprompts"></param>
    public void SetAntiprompts(IEnumerable<string> antiprompts)
    {
        _antiprompts.Clear();
        _longestAntiprompt = 0;

        foreach (var antiprompt in antiprompts)
            AddAntiprompt(antiprompt);
    }

    /// <summary>
    /// Add some text and check if the buffer now ends with any antiprompt
    /// </summary>
    /// <param name="text"></param>
    /// <returns>true if the text buffer ends with any antiprompt</returns>
    public bool Add(string text)
    {
        _string += text;

        // When the buffer gets very long (4x antiprompt length) trim it down (to 2x antiprompt length).
        // A generous margin is kept because two sequences can be considered "equal" in unicode even
        // with different numbers of characters, so more characters are retained than strictly needed.
        var maxLength = Math.Max(32, _longestAntiprompt * 4);
        if (_string.Length > maxLength)
        {
            var trimLength = Math.Max(16, _longestAntiprompt * 2);
            _string = _string.Substring(_string.Length - trimLength);
        }

        foreach (var antiprompt in _antiprompts)
        {
            if (_string.EndsWith(antiprompt, StringComparison.CurrentCulture))
                return true;
        }

        return false;
    }
}

View File

@ -36,6 +36,7 @@ namespace LLama.Extensions
/// <param name="model">Model to use to convert tokens into bytes</param>
/// <param name="encoding">Encoding to use to convert bytes into characters</param>
/// <returns></returns>
[Obsolete("Use an Antiprompt processor instead")]
internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tokens, TQueries? queries, SafeLlamaModelHandle model, Encoding encoding)
where TTokens : IReadOnlyList<int>
where TQueries : IReadOnlyList<string>
@ -68,13 +69,6 @@ namespace LLama.Extensions
}
}
/// <summary>
/// Check if the given set of tokens ends with any of the given strings, taking the
/// model weights and text encoding from the given context.
/// </summary>
internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tokens, TQueries? queries, LLamaContext context)
    where TTokens : IReadOnlyList<int>
    where TQueries : IReadOnlyList<string>
    => TokensEndsWithAnyString(tokens, queries, context.NativeHandle.ModelHandle, context.Encoding);
/// <summary>
/// Check if the given set of tokens ends with any of the given strings
/// </summary>
@ -83,6 +77,7 @@ namespace LLama.Extensions
/// <param name="model">Model to use to convert tokens into bytes</param>
/// <param name="encoding">Encoding to use to convert bytes into characters</param>
/// <returns></returns>
[Obsolete("Use an Antiprompt processor instead")]
internal static bool TokensEndsWithAnyString<TTokens>(this TTokens tokens, IList<string>? queries, SafeLlamaModelHandle model, Encoding encoding)
where TTokens : IReadOnlyList<int>
{

View File

@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;
namespace LLama.Extensions
{
/// <summary>
/// Extension methods for <see cref="List{T}"/>
/// </summary>
internal static class ListExtensions
{
#if NETSTANDARD2_0
    /// <summary>
    /// Grow the capacity of a list to at least the given value.
    /// Compiled only for netstandard2.0, where List&lt;T&gt; has no built-in EnsureCapacity.
    /// </summary>
    public static void EnsureCapacity<T>(this List<T> list, int capacity)
    {
        if (capacity > list.Capacity)
            list.Capacity = capacity;
    }
#endif

    /// <summary>
    /// Append every item in a span to the end of a list.
    /// </summary>
    public static void AddSpan<T>(this List<T> list, ReadOnlySpan<T> items)
    {
        // Reserve all the space up front so the list grows at most once
        list.EnsureCapacity(list.Count + items.Length);

        foreach (var item in items)
            list.Add(item);
    }
}
}

View File

@ -102,9 +102,15 @@ namespace LLama
/// </summary>
/// <param name="tokens"></param>
/// <returns></returns>
[Obsolete("Use a `StreamingTokenDecoder` instead")]
public string DeTokenize(IReadOnlyList<llama_token> tokens)
{
return NativeHandle.DeTokenize(tokens, Encoding);
// Do **not** use this method as an example of how to correctly use the StreamingTokenDecoder!
// It should be kept around for the entire time you are decoding one stream of tokens.
var decoder = new StreamingTokenDecoder(this);
decoder.AddRange(tokens);
return decoder.ToString();
}
/// <summary>

View File

@ -56,6 +56,9 @@ namespace LLama
Context.Dispose();
Context = _weights.CreateContext(Context.Params, _logger);
var decoder = new StreamingTokenDecoder(Context);
var antiprocessor = new AntipromptProcessor(inferenceParams?.AntiPrompts ?? Array.Empty<string>());
if (inferenceParams != null)
{
if (inferenceParams.TokensKeep > Context.ContextSize)
@ -64,7 +67,6 @@ namespace LLama
cancellationToken.ThrowIfCancellationRequested();
var antiprompts = inferenceParams?.AntiPrompts.ToArray() ?? Array.Empty<string>();
inferenceParams ??= new InferenceParams();
var lastTokens = new List<llama_token>(inferenceParams.RepeatLastTokensCount);
@ -95,13 +97,16 @@ namespace LLama
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar);
lastTokens.Add(id);
yield return Context.DeTokenize(new [] { id }); //todo: not correct to return tokens one by one like this!
decoder.Add(id);
var decoded = decoder.Read();
yield return decoder.Read();
tokens.Clear();
tokens.Add(id);
// Check if any of the antiprompts have been generated
if (lastTokens.TokensEndsWithAnyString(antiprompts, Context))
if (antiprocessor.Add(decoded))
break;
// when run out of context

View File

@ -169,33 +169,6 @@ namespace LLama.Native
{
return ThrowIfDisposed().TokenToSpan(token, dest);
}
/// <summary>
/// Convert a set of tokens into a string
/// </summary>
/// <param name="tokens">Tokens to decode</param>
/// <param name="encoding">Encoding to use when converting token bytes into characters</param>
/// <returns>The decoded string, or an empty string if the tokens produced no characters</returns>
public string DeTokenize(IReadOnlyList<int> tokens, Encoding encoding)
{
// Rent a scratch buffer sized on the assumption that each token decodes to at most
// 2 characters. If that estimate is too small, TokensToSpan fills the buffer with as
// many characters as possible starting from the LAST token, so the start of the text
// may be truncated.
var chars = ArrayPool<char>.Shared.Rent(tokens.Count * 2);
try
{
var span = ThrowIfDisposed().TokensToSpan(tokens, chars.AsSpan(), encoding);
if (span.Length == 0)
return "";
// The span aliases the rented buffer, so the characters must be copied into a new
// string before the buffer is returned to the pool (in the finally block below).
unsafe
{
fixed (char* ptr = &span[0])
return new string(ptr, 0, span.Length);
}
}
finally
{
ArrayPool<char>.Shared.Return(chars);
}
}
#endregion
/// <summary>

View File

@ -129,77 +129,25 @@ namespace LLama.Native
/// If there was insufficient space in the output span this will be
/// filled with as many characters as possible, starting from the _last_ token.
/// </returns>
[Obsolete("Use a StreamingTokenDecoder instead")]
internal Span<char> TokensToSpan(IReadOnlyList<int> tokens, Span<char> dest, Encoding encoding)
{
// Rent an array to detokenize into
var tokenBytesArr = ArrayPool<byte>.Shared.Rent(16);
var decoder = new StreamingTokenDecoder(encoding, this);
// Convert all of the tokens into bytes
var bytes = new List<byte>();
foreach (var token in tokens)
decoder.Add(token);
var str = decoder.Read();
if (str.Length < dest.Length)
{
var tokenBytes = TokenToBytes(ref tokenBytesArr, token, this);
foreach (var tokenByte in tokenBytes)
bytes.Add(tokenByte);
}
// Extract a span from the list
var bytesSpan =
#if NETSTANDARD2_0
bytes.ToArray().AsSpan();
#else
CollectionsMarshal.AsSpan(bytes);
#endif
// Check how many characters these bytes represent. If there's not enough space in the
// output array we need to handle that.
var characterCount = encoding.GetCharCount(bytesSpan);
if (characterCount > dest.Length)
{
var bigChars = ArrayPool<char>.Shared.Rent(characterCount);
try
{
encoding.GetChars(bytesSpan, bigChars);
var charSlice = bigChars
.AsSpan(0, characterCount)
.Slice(characterCount - dest.Length);
charSlice.CopyTo(dest);
return dest;
}
finally
{
ArrayPool<char>.Shared.Return(bigChars);
}
//todo: handle dest span too small
throw new NotImplementedException();
str.AsSpan().CopyTo(dest);
return dest.Slice(0, str.Length);
}
else
{
var charCount = encoding.GetChars(bytes.ToArray(), dest);
return dest.Slice(0, charCount);
}
// vvv Local Functions vvv
static Span<byte> TokenToBytes(ref byte[] bytes, int token, SafeLlamaModelHandle model)
{
// Try to get bytes, if that fails we known the length
var l = model.TokenToSpan(token, bytes);
// Array was too small, get a bigger one
if (l < 0)
{
ArrayPool<byte>.Shared.Return(bytes);
bytes = ArrayPool<byte>.Shared.Rent(-l * 2);
// Get bytes, this time it can't fail
l = model.TokenToSpan(token, bytes);
}
Debug.Assert(l >= 0);
return new Span<byte>(bytes, 0, l);
str.AsSpan().Slice(str.Length - dest.Length).CopyTo(dest);
return dest;
}
}

View File

@ -0,0 +1,169 @@
using System.Buffers;
using System.Diagnostics;
using System;
using System.Collections.Generic;
using System.Text;
using LLama.Extensions;
using LLama.Native;
namespace LLama;
/// <summary>
/// Decodes a stream of tokens into a stream of characters
/// </summary>
public sealed class StreamingTokenDecoder
{
    // Model used to convert each token into its raw byte sequence
    private readonly SafeLlamaModelHandle _weights;

    // Stateful text decoder: carries partial multi-byte characters across Add calls,
    // which is what allows a single character split over several tokens to decode correctly
    private readonly Decoder _decoder;

    // Characters decoded so far but not yet consumed by a call to Read
    private readonly List<char> _characters = new();

    /// <summary>
    /// The number of decoded characters waiting to be read
    /// </summary>
    public int AvailableCharacters => _characters.Count;

    #region constructors
    /// <summary>
    /// Create a new decoder
    /// </summary>
    /// <param name="encoding">Text encoding to use</param>
    /// <param name="weights">Model weights</param>
    public StreamingTokenDecoder(Encoding encoding, LLamaWeights weights)
        : this(encoding, weights.NativeHandle)
    {
    }

    /// <summary>
    /// Create a new decoder
    /// </summary>
    /// <param name="context">Context to retrieve encoding and model weights from</param>
    public StreamingTokenDecoder(LLamaContext context)
        : this(context.Encoding, context.NativeHandle)
    {
    }

    /// <summary>
    /// Create a new decoder
    /// </summary>
    /// <param name="encoding">Text encoding to use</param>
    /// <param name="context">Context to retrieve model weights from</param>
    public StreamingTokenDecoder(Encoding encoding, SafeLLamaContextHandle context)
        : this(encoding, context.ModelHandle)
    {
    }

    /// <summary>
    /// Create a new decoder
    /// </summary>
    /// <param name="encoding">Text encoding to use</param>
    /// <param name="weights">Models weights to use</param>
    public StreamingTokenDecoder(Encoding encoding, SafeLlamaModelHandle weights)
    {
        _weights = weights;
        _decoder = encoding.GetDecoder();
    }
    #endregion

    /// <summary>
    /// Add a single token to the decoder
    /// </summary>
    /// <param name="token"></param>
    public void Add(int token)
    {
        var charsArr = ArrayPool<char>.Shared.Rent(16);
        var bytesArr = ArrayPool<byte>.Shared.Rent(16);
        try
        {
            // Convert this token into bytes (TokenToBytes may replace bytesArr with a larger
            // rented array, which is why it takes the array by ref)
            var bytesAvailable = TokenToBytes(ref bytesArr, token, _weights).Length;

            // Convert those bytes into characters
            var bytesOffset = 0;
            var completed = false;
            while (!completed)
            {
                // Decode some of the bytes into the temp char buffer. Keep doing this
                // until all bytes have been consumed. flush:false keeps trailing partial
                // characters buffered inside _decoder for the next Add call.
                _decoder.Convert(
                    bytesArr, bytesOffset, bytesAvailable,
                    charsArr, 0, charsArr.Length,
                    false,
                    out var bytesUsed, out var charsUsed, out completed
                );
                bytesOffset += bytesUsed;
                bytesAvailable -= bytesUsed;

                // Add the decoded characters to the output buffer
                _characters.AddSpan(charsArr.AsSpan(0, charsUsed));
            }
        }
        finally
        {
            ArrayPool<char>.Shared.Return(charsArr);
            ArrayPool<byte>.Shared.Return(bytesArr);
        }

        return;

        // Converts a single token into bytes, using the `bytes` array as temporary storage.
        // If the `bytes` array is too small it will get a larger one from the ArrayPool.
        static Span<byte> TokenToBytes(ref byte[] bytes, int token, SafeLlamaModelHandle model)
        {
            // Try to get bytes
            var l = model.TokenToSpan(token, bytes);

            // Negative length indicates that the output was too small. Expand it to twice that size and try again.
            if (l < 0)
            {
                // Return the old array to the pool and get a new one
                ArrayPool<byte>.Shared.Return(bytes);
                bytes = ArrayPool<byte>.Shared.Rent(-l * 2);

                // Get bytes, this time it can't fail
                l = model.TokenToSpan(token, bytes);
            }

            Debug.Assert(l >= 0);
            return new Span<byte>(bytes, 0, l);
        }
    }

    /// <summary>
    /// Add all tokens in the given enumerable
    /// </summary>
    /// <param name="tokens"></param>
    public void AddRange(IEnumerable<int> tokens)
    {
        foreach (var item in tokens)
            Add(item);
    }

    /// <summary>
    /// Read all decoded characters into the given list and clear the buffer
    /// </summary>
    /// <param name="dest"></param>
    public void Read(List<char> dest)
    {
        dest.AddRange(_characters);
        _characters.Clear();
    }

    /// <summary>
    /// Read all decoded characters as a string and clear the buffer
    /// </summary>
    /// <returns></returns>
    public string Read()
    {
        if (_characters.Count == 0)
            return "";

        var str = string.Join("", _characters);

        // BUGFIX: clear the buffer after reading, as documented and as the List<char>
        // overload of Read already does. Without this, every call to Read() returned
        // ALL text decoded so far, duplicating output for streaming consumers.
        _characters.Clear();

        return str;
    }

    /// <summary>
    /// Set the decoder back to its initial state
    /// </summary>
    public void Reset()
    {
        _decoder.Reset();
        _characters.Clear();
    }
}