From 98635a0d5a761dc4d88b3074d38cb41fefec9cb1 Mon Sep 17 00:00:00 2001
From: Martin Evans
Date: Tue, 9 Jan 2024 17:18:27 +0000
Subject: [PATCH] Fixed decoding of large tokens (over 16 bytes) in streaming text decoder

---
 LLama.Unittest/StreamingTextDecoderTests.cs | 53 +++++++++++++++++++++
 LLama/Native/SafeLLamaContextHandle.cs      |  2 +-
 LLama/Native/SafeLlamaModelHandle.cs        |  4 +-
 LLama/StreamingTokenDecoder.cs              | 10 ++--
 4 files changed, 61 insertions(+), 8 deletions(-)
 create mode 100644 LLama.Unittest/StreamingTextDecoderTests.cs

diff --git a/LLama.Unittest/StreamingTextDecoderTests.cs b/LLama.Unittest/StreamingTextDecoderTests.cs
new file mode 100644
index 00000000..680ca076
--- /dev/null
+++ b/LLama.Unittest/StreamingTextDecoderTests.cs
@@ -0,0 +1,53 @@
+using System.Text;
+using LLama.Common;
+using Xunit.Abstractions;
+
+namespace LLama.Unittest;
+
+public class StreamingTextDecoderTests
+    : IDisposable
+{
+    private readonly LLamaWeights _model;
+    private readonly ITestOutputHelper _testOutputHelper;
+    private readonly ModelParams _params;
+
+    public StreamingTextDecoderTests(ITestOutputHelper testOutputHelper)
+    {
+        _testOutputHelper = testOutputHelper;
+        _params = new ModelParams(Constants.ModelPath);
+        _model = LLamaWeights.LoadFromFile(_params);
+    }
+
+    public void Dispose()
+    {
+        _model.Dispose();
+    }
+
+    [Fact]
+    public void DecodesSimpleText()
+    {
+        var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model);
+
+        const string text = "The cat sat on the mat";
+        var tokens = _model.NativeHandle.Tokenize(text, false, false, Encoding.UTF8);
+
+        foreach (var lLamaToken in tokens)
+            decoder.Add(lLamaToken);
+
+        Assert.Equal(text, decoder.Read().Trim());
+    }
+
+    [Fact]
+    public void DecodesComplexText()
+    {
+        var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model);
+
+        const string text = "猫坐在垫子上 😀🤨🤐😏";
+        var tokens = _model.NativeHandle.Tokenize(text, false, false, Encoding.UTF8);
+
+        foreach (var lLamaToken in tokens)
+            decoder.Add(lLamaToken);
+
+        Assert.Equal(text, decoder.Read().Trim());
+    }
+}
\ No newline at end of file
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index c6957721..17fa13cf 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -194,7 +194,7 @@ namespace LLama.Native
         /// <param name="token">Token to decode</param>
         /// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
         /// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
-        public int TokenToSpan(LLamaToken token, Span<byte> dest)
+        public uint TokenToSpan(LLamaToken token, Span<byte> dest)
         {
             return ThrowIfDisposed().TokenToSpan(token, dest);
         }
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index d4fe2d71..47ffb4dd 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -126,10 +126,10 @@ namespace LLama.Native
         /// <param name="token">Token to decode</param>
         /// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
         /// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
-        public int TokenToSpan(LLamaToken token, Span<byte> dest)
+        public uint TokenToSpan(LLamaToken token, Span<byte> dest)
         {
             var length = NativeApi.llama_token_to_piece(this, token, dest);
-            return Math.Abs(length);
+            return (uint)Math.Abs(length);
         }
 
         /// <summary>
diff --git a/LLama/StreamingTokenDecoder.cs b/LLama/StreamingTokenDecoder.cs
index a66d5c9e..4c1ea58d 100644
--- a/LLama/StreamingTokenDecoder.cs
+++ b/LLama/StreamingTokenDecoder.cs
@@ -113,19 +113,19 @@ namespace LLama
             // Try to get bytes
             var l = model.TokenToSpan(token, bytes);
 
-            // Negative length indicates that the output was too small. Expand it to twice that size and try again.
-            if (l < 0)
+            // Check if the length was larger than the buffer. If so expand the buffer and try again
+            if (l > bytes.Length)
             {
                 // Return the old array to the pool and get a new one
                 ArrayPool<byte>.Shared.Return(bytes);
-                bytes = ArrayPool<byte>.Shared.Rent(-l * 2);
+                bytes = ArrayPool<byte>.Shared.Rent((int)(l * 2));
 
                 // Get bytes, this time it can't fail
                 l = model.TokenToSpan(token, bytes);
             }
 
-            Debug.Assert(l >= 0);
-            return new Span<byte>(bytes, 0, l);
+            Debug.Assert(l <= bytes.Length);
+            return new Span<byte>(bytes, 0, (int)l);
         }
     }
 }
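For reference, the grow-and-retry pattern this patch adopts in StreamingTokenDecoder can be sketched in isolation as below. This is a minimal illustrative sketch, not code from the patch: `TokenDecodeSketch`, `Decode` and `WritePiece` are hypothetical names, and `WritePiece` merely mimics the new `TokenToSpan` contract assumed here (write only if the destination is large enough, always report the token's full size).

using System;
using System.Buffers;

internal static class TokenDecodeSketch
{
    // Hypothetical stand-in for the patched TokenToSpan: writes the piece into
    // `dest` only when it fits, but always reports the piece's full length.
    private static uint WritePiece(byte[] piece, Span<byte> dest)
    {
        if (piece.Length <= dest.Length)
            piece.AsSpan().CopyTo(dest);
        return (uint)piece.Length;
    }

    // Grow-and-retry: if the reported length exceeds the rented buffer, return
    // the buffer to the pool, rent a larger one and decode again.
    public static byte[] Decode(byte[] piece)
    {
        var bytes = ArrayPool<byte>.Shared.Rent(16);
        try
        {
            var l = WritePiece(piece, bytes);
            if (l > bytes.Length)
            {
                ArrayPool<byte>.Shared.Return(bytes);
                bytes = ArrayPool<byte>.Shared.Rent((int)(l * 2));
                l = WritePiece(piece, bytes);
            }

            // Copy out only the bytes that belong to this token.
            return bytes.AsSpan(0, (int)l).ToArray();
        }
        finally
        {
            ArrayPool<byte>.Shared.Return(bytes);
        }
    }
}

Renting twice the reported size mirrors the patch: it guarantees the single retry succeeds and leaves headroom for subsequent large tokens that reuse the same pooled buffer.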