From 51d4411a58a0bc76b80e4dfa41dd72aea72164e5 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Mon, 23 Oct 2023 00:33:50 +0100 Subject: [PATCH] Added two new classes for detokenization tasks: - `AntipromptProcessor` accepts chunks of text and returns a value indicating if any antiprompt has been detected. - `StreamingTokenDecoder` decodes tokens into text, maintaining some internal state to handle single characters which are encoded as multiple tokens. Added tests for these classes and updated StatelessExecutor to use them. Removed most DeTokenize methods, marked the rest as obsolete (should always use a `StreamingTokenDecoder`). --- LLama.Unittest/BeamTests.cs | 5 +- LLama.Unittest/TokenTests.cs | 99 +++++++++++- LLama/AntipromptProcessor.cs | 66 ++++++++ LLama/Extensions/IReadOnlyListExtensions.cs | 9 +- LLama/Extensions/ListExtensions.cs | 24 +++ LLama/LLamaContext.cs | 8 +- LLama/LLamaStatelessExecutor.cs | 11 +- LLama/Native/SafeLLamaContextHandle.cs | 27 ---- LLama/Native/SafeLlamaModelHandle.cs | 74 ++------- LLama/StreamingTokenDecoder.cs | 169 ++++++++++++++++++++ 10 files changed, 387 insertions(+), 105 deletions(-) create mode 100644 LLama/AntipromptProcessor.cs create mode 100644 LLama/Extensions/ListExtensions.cs create mode 100644 LLama/StreamingTokenDecoder.cs diff --git a/LLama.Unittest/BeamTests.cs b/LLama.Unittest/BeamTests.cs index f8d5cf01..3014894e 100644 --- a/LLama.Unittest/BeamTests.cs +++ b/LLama.Unittest/BeamTests.cs @@ -32,13 +32,14 @@ public sealed class BeamTests { const int num_beams = 2; const int n_predict = 3; + const string prompt = "The cat sat on"; var context = _model.CreateContext(_params); var result = new StringBuilder(); - var initial_tokens = context.Tokenize("The cat sat on"); - result.Append(context.DeTokenize(initial_tokens.ToArray())); + var initial_tokens = context.Tokenize(prompt); + result.Append(prompt); context.Eval(initial_tokens, 0); NativeApi.llama_beam_search(context.NativeHandle, (data, state) => diff --git a/LLama.Unittest/TokenTests.cs b/LLama.Unittest/TokenTests.cs index 383428af..e39df5f4 100644 --- a/LLama.Unittest/TokenTests.cs +++ b/LLama.Unittest/TokenTests.cs @@ -73,6 +73,72 @@ public sealed class TokenTests Assert.False(result); } + [Fact] + public void TokensEndWith2() + { + var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8); + + var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model); + decoder.AddRange(tokens); + + var processor = new AntipromptProcessor(new[] + { + "a fish", + "the mat", + "this is an improbably long query to be using for this method" + }); + var result = processor.Add(decoder.Read()); + + Assert.True(result); + } + + [Fact] + public void TokensEndSubstring2() + { + var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8); + + var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model); + decoder.AddRange(tokens); + + var processor = new AntipromptProcessor(new[] { "at" }); + var result = processor.Add(decoder.Read()); + + Assert.True(result); + } + + [Fact] + public void TokensNotEndWith2() + { + var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8); + + var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model); + decoder.AddRange(tokens); + + var processor = new AntipromptProcessor(new[] + { + "a fish", + "The cat sat on the edge of the ma", + "this is an improbably long query to be using for this method" + }); + var result = processor.Add(decoder.Read()); + + Assert.False(result); + } + + [Fact] + public void TokensNotEndWithNothing2() + { + var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8); + + var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model); + decoder.AddRange(tokens); + + var processor = new AntipromptProcessor(); + var result = processor.Add(decoder.Read()); + + Assert.False(result); + } + [Fact] public void RoundTrip() { @@ -80,7 +146,7 @@ public sealed class TokenTests { "Hello world", "철수", - "πŸ˜€ πŸ˜ƒ πŸ˜„ 😁 πŸ˜† πŸ˜… πŸ˜‚ 😊 πŸ˜‡ πŸ™‚ ", + "πŸ˜€ πŸ˜ƒ πŸ˜„ 😁 πŸ˜†μ² μˆ˜πŸ˜… πŸ˜‚ 😊 πŸ˜‡ πŸ™‚ ", }; var charsArr = new char[1024]; @@ -99,7 +165,36 @@ public sealed class TokenTests // Check that the input equals the output Assert.Equal(input, output); } + } - + [Fact] + public void StreamingDecoderRoundTrip() + { + var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model); + + var strings = new[] + { + "Hello world", + "철수", + "πŸ˜€ πŸ˜ƒ πŸ˜„ 😁 πŸ˜†μ² μˆ˜πŸ˜… πŸ˜‚ 😊 πŸ˜‡ πŸ™‚ ", + }; + + foreach (var input in strings) + { + decoder.Reset(); + + // Convert into llama tokens + var tokens = _model.NativeHandle.Tokenize(input, false, false, Encoding.UTF8); + + // Add tokens to decoder + foreach (var token in tokens) + decoder.Add(token); + + // llama.cpp adds a space to the start of strings, remove that + var output = decoder.Read().TrimStart(' '); + + // Check that the input equals the output + Assert.Equal(input, output); + } } } \ No newline at end of file diff --git a/LLama/AntipromptProcessor.cs b/LLama/AntipromptProcessor.cs new file mode 100644 index 00000000..4d969cea --- /dev/null +++ b/LLama/AntipromptProcessor.cs @@ -0,0 +1,66 @@ +ο»Ώusing System; +using System.Collections.Generic; + +namespace LLama; + +internal sealed class AntipromptProcessor +{ + private int _longestAntiprompt; + private readonly List _antiprompts = new(); + + private string? _string; + + public AntipromptProcessor(IEnumerable? antiprompts = null) + { + if (antiprompts != null) + SetAntiprompts(antiprompts); + } + + /// + /// Add an antiprompt to the collection + /// + /// + public void AddAntiprompt(string antiprompt) + { + _antiprompts.Add(antiprompt); + _longestAntiprompt = Math.Max(_longestAntiprompt, antiprompt.Length); + } + + /// + /// Overwrite all current antiprompts with a new set + /// + /// + public void SetAntiprompts(IEnumerable antiprompts) + { + _antiprompts.Clear(); + _antiprompts.AddRange(antiprompts); + + _longestAntiprompt = 0; + foreach (var antiprompt in _antiprompts) + _longestAntiprompt = Math.Max(_longestAntiprompt, antiprompt.Length); + } + + /// + /// Add some text and check if the buffer now ends with any antiprompt + /// + /// + /// true if the text buffer ends with any antiprompt + public bool Add(string text) + { + _string += text; + + // When the string gets very long (4x antiprompt length) trim it down (to 2x antiprompt length). + // This trimming leaves a lot of extra characters because two sequences can be considered "equal" in unicode + // even with different numbers of characters. Hopefully there are enough characters here to handle all those weird circumstances! + var maxLength = Math.Max(32, _longestAntiprompt * 4); + var trimLength = Math.Max(16, _longestAntiprompt * 2); + if (_string.Length > maxLength) + _string = _string.Substring(_string.Length - trimLength); + + foreach (var antiprompt in _antiprompts) + if (_string.EndsWith(antiprompt, StringComparison.CurrentCulture)) + return true; + + return false; + } +} \ No newline at end of file diff --git a/LLama/Extensions/IReadOnlyListExtensions.cs b/LLama/Extensions/IReadOnlyListExtensions.cs index 4d1c6f09..7a3473b7 100644 --- a/LLama/Extensions/IReadOnlyListExtensions.cs +++ b/LLama/Extensions/IReadOnlyListExtensions.cs @@ -36,6 +36,7 @@ namespace LLama.Extensions /// Model to use to convert tokens into bytes /// Encoding to use to convert bytes into characters /// + [Obsolete("Use an Antiprompt processor instead")] internal static bool TokensEndsWithAnyString(this TTokens tokens, TQueries? queries, SafeLlamaModelHandle model, Encoding encoding) where TTokens : IReadOnlyList where TQueries : IReadOnlyList @@ -68,13 +69,6 @@ namespace LLama.Extensions } } - internal static bool TokensEndsWithAnyString(this TTokens tokens, TQueries? queries, LLamaContext context) - where TTokens : IReadOnlyList - where TQueries : IReadOnlyList - { - return TokensEndsWithAnyString(tokens, queries, context.NativeHandle.ModelHandle, context.Encoding); - } - /// /// Check if the given set of tokens ends with any of the given strings /// @@ -83,6 +77,7 @@ namespace LLama.Extensions /// Model to use to convert tokens into bytes /// Encoding to use to convert bytes into characters /// + [Obsolete("Use an Antiprompt processor instead")] internal static bool TokensEndsWithAnyString(this TTokens tokens, IList? queries, SafeLlamaModelHandle model, Encoding encoding) where TTokens : IReadOnlyList { diff --git a/LLama/Extensions/ListExtensions.cs b/LLama/Extensions/ListExtensions.cs new file mode 100644 index 00000000..11a1d4f0 --- /dev/null +++ b/LLama/Extensions/ListExtensions.cs @@ -0,0 +1,24 @@ +ο»Ώusing System; +using System.Collections.Generic; + +namespace LLama.Extensions +{ + internal static class ListExtensions + { +#if NETSTANDARD2_0 + public static void EnsureCapacity(this List list, int capacity) + { + if (list.Capacity < capacity) + list.Capacity = capacity; + } +#endif + + public static void AddSpan(this List list, ReadOnlySpan items) + { + list.EnsureCapacity(list.Count + items.Length); + + for (var i = 0; i < items.Length; i++) + list.Add(items[i]); + } + } +} diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 7edf62c4..46b0ae3f 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -102,9 +102,15 @@ namespace LLama /// /// /// + [Obsolete("Use a `StreamingTokenDecoder` instead")] public string DeTokenize(IReadOnlyList tokens) { - return NativeHandle.DeTokenize(tokens, Encoding); + // Do **not** use this method as an example of how to correctly use the StreamingTokenDecoder! + // It should be kept around for the entire time you are decoding one stream of tokens. + + var decoder = new StreamingTokenDecoder(this); + decoder.AddRange(tokens); + return decoder.ToString(); } /// diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index 457b3894..f5cdfb2e 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -56,6 +56,9 @@ namespace LLama Context.Dispose(); Context = _weights.CreateContext(Context.Params, _logger); + var decoder = new StreamingTokenDecoder(Context); + var antiprocessor = new AntipromptProcessor(inferenceParams?.AntiPrompts ?? Array.Empty()); + if (inferenceParams != null) { if (inferenceParams.TokensKeep > Context.ContextSize) @@ -64,7 +67,6 @@ namespace LLama cancellationToken.ThrowIfCancellationRequested(); - var antiprompts = inferenceParams?.AntiPrompts.ToArray() ?? Array.Empty(); inferenceParams ??= new InferenceParams(); var lastTokens = new List(inferenceParams.RepeatLastTokensCount); @@ -95,13 +97,16 @@ namespace LLama inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar); lastTokens.Add(id); - yield return Context.DeTokenize(new [] { id }); //todo: not correct to return tokens one by one like this! + + decoder.Add(id); + var decoded = decoder.Read(); + yield return decoder.Read(); tokens.Clear(); tokens.Add(id); // Check if any of the antiprompts have been generated - if (lastTokens.TokensEndsWithAnyString(antiprompts, Context)) + if (antiprocessor.Add(decoded)) break; // when run out of context diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index 160b8cc8..7fb5edf7 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -169,33 +169,6 @@ namespace LLama.Native { return ThrowIfDisposed().TokenToSpan(token, dest); } - - /// - /// Convert a set of tokens into a string - /// - /// - /// - /// - public string DeTokenize(IReadOnlyList tokens, Encoding encoding) - { - var chars = ArrayPool.Shared.Rent(tokens.Count * 2); - try - { - var span = ThrowIfDisposed().TokensToSpan(tokens, chars.AsSpan(), encoding); - if (span.Length == 0) - return ""; - - unsafe - { - fixed (char* ptr = &span[0]) - return new string(ptr, 0, span.Length); - } - } - finally - { - ArrayPool.Shared.Return(chars); - } - } #endregion /// diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index 7afcc3af..94ddca2d 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -129,77 +129,25 @@ namespace LLama.Native /// If there was insufficient space in the output span this will be /// filled with as many characters as possible, starting from the _last_ token. /// + [Obsolete("Use a StreamingTokenDecoder instead")] internal Span TokensToSpan(IReadOnlyList tokens, Span dest, Encoding encoding) { - // Rent an array to detokenize into - var tokenBytesArr = ArrayPool.Shared.Rent(16); + var decoder = new StreamingTokenDecoder(encoding, this); - // Convert all of the tokens into bytes - var bytes = new List(); foreach (var token in tokens) + decoder.Add(token); + + var str = decoder.Read(); + + if (str.Length < dest.Length) { - var tokenBytes = TokenToBytes(ref tokenBytesArr, token, this); - foreach (var tokenByte in tokenBytes) - bytes.Add(tokenByte); - } - - // Extract a span from the list - var bytesSpan = -#if NETSTANDARD2_0 - bytes.ToArray().AsSpan(); -#else - CollectionsMarshal.AsSpan(bytes); -#endif - - // Check how many characters these bytes represent. If there's not enough space in the - // output array we need to handle that. - var characterCount = encoding.GetCharCount(bytesSpan); - if (characterCount > dest.Length) - { - var bigChars = ArrayPool.Shared.Rent(characterCount); - try - { - encoding.GetChars(bytesSpan, bigChars); - var charSlice = bigChars - .AsSpan(0, characterCount) - .Slice(characterCount - dest.Length); - - charSlice.CopyTo(dest); - return dest; - } - finally - { - ArrayPool.Shared.Return(bigChars); - } - - //todo: handle dest span too small - throw new NotImplementedException(); + str.AsSpan().CopyTo(dest); + return dest.Slice(0, str.Length); } else { - var charCount = encoding.GetChars(bytes.ToArray(), dest); - return dest.Slice(0, charCount); - } - - // vvv Local Functions vvv - - static Span TokenToBytes(ref byte[] bytes, int token, SafeLlamaModelHandle model) - { - // Try to get bytes, if that fails we known the length - var l = model.TokenToSpan(token, bytes); - - // Array was too small, get a bigger one - if (l < 0) - { - ArrayPool.Shared.Return(bytes); - bytes = ArrayPool.Shared.Rent(-l * 2); - - // Get bytes, this time it can't fail - l = model.TokenToSpan(token, bytes); - } - - Debug.Assert(l >= 0); - return new Span(bytes, 0, l); + str.AsSpan().Slice(str.Length - dest.Length).CopyTo(dest); + return dest; } } diff --git a/LLama/StreamingTokenDecoder.cs b/LLama/StreamingTokenDecoder.cs new file mode 100644 index 00000000..fc459199 --- /dev/null +++ b/LLama/StreamingTokenDecoder.cs @@ -0,0 +1,169 @@ +ο»Ώusing System.Buffers; +using System.Diagnostics; +using System; +using System.Collections.Generic; +using System.Text; +using LLama.Extensions; +using LLama.Native; + +namespace LLama; + +/// +/// Decodes a stream of tokens into a stream of characters +/// +public sealed class StreamingTokenDecoder +{ + private readonly SafeLlamaModelHandle _weights; + private readonly Decoder _decoder; + + private readonly List _characters = new(); + + /// + /// The number of decoded characters waiting to be read + /// + public int AvailableCharacters => _characters.Count; + + #region constructors + /// + /// Create a new decoder + /// + /// Text encoding to use + /// Model weights + public StreamingTokenDecoder(Encoding encoding, LLamaWeights weights) + : this(encoding, weights.NativeHandle) + { + } + + /// + /// Create a new decoder + /// + /// Context to retrieve encoding and model weights from + public StreamingTokenDecoder(LLamaContext context) + : this(context.Encoding, context.NativeHandle) + { + } + + /// + /// Create a new decoder + /// + /// Text encoding to use + /// Context to retrieve model weights from + public StreamingTokenDecoder(Encoding encoding, SafeLLamaContextHandle context) + : this(encoding, context.ModelHandle) + { + } + + /// + /// Create a new decoder + /// + /// Text encoding to use + /// Models weights to use + public StreamingTokenDecoder(Encoding encoding, SafeLlamaModelHandle weights) + { + _weights = weights; + _decoder = encoding.GetDecoder(); + } + #endregion + + /// + /// Add a single token to the decoder + /// + /// + public void Add(int token) + { + var charsArr = ArrayPool.Shared.Rent(16); + var bytesArr = ArrayPool.Shared.Rent(16); + try + { + // Convert this token into bytes + var bytesAvailable = TokenToBytes(ref bytesArr, token, _weights).Length; + + // Convert those bytes into characters + var bytesOffset = 0; + var completed = false; + while (!completed) + { + // Decode some of the bytes into the temp char buffer. Keep doing this + // until all bytes have been consumed + _decoder.Convert( + bytesArr, bytesOffset, bytesAvailable, + charsArr, 0, charsArr.Length, + false, + out var bytesUsed, out var charsUsed, out completed + ); + bytesOffset += bytesUsed; + bytesAvailable -= bytesUsed; + + // Add the decoded characters to the output buffer + _characters.AddSpan(charsArr.AsSpan(0, charsUsed)); + } + } + finally + { + ArrayPool.Shared.Return(charsArr); + ArrayPool.Shared.Return(bytesArr); + } + + return; + + // Converts a single token into bytes, using the `bytes` array as temporary storage. + // If the `bytes` array is too small it will get a larger one from the ArrayPool. + static Span TokenToBytes(ref byte[] bytes, int token, SafeLlamaModelHandle model) + { + // Try to get bytes + var l = model.TokenToSpan(token, bytes); + + // Negative length indicates that the output was too small. Expand it to twice that size and try again. + if (l < 0) + { + // Return the old array to the pool and get a new one + ArrayPool.Shared.Return(bytes); + bytes = ArrayPool.Shared.Rent(-l * 2); + + // Get bytes, this time it can't fail + l = model.TokenToSpan(token, bytes); + } + + Debug.Assert(l >= 0); + return new Span(bytes, 0, l); + } + } + + /// + /// Add all tokens in the given enumerable + /// + /// + public void AddRange(IEnumerable tokens) + { + foreach (var item in tokens) + Add(item); + } + + /// + /// Read all decoded characters and clear the buffer + /// + /// + public void Read(List dest) + { + dest.AddRange(_characters); + _characters.Clear(); + } + + /// + /// Read all decoded characters as a string and clear the buffer + /// + /// + public string Read() + { + return string.Join("", _characters); + } + + /// + /// Set the decoder back to its initial state + /// + public void Reset() + { + _decoder.Reset(); + _characters.Clear(); + } +} \ No newline at end of file