Fixed decoding of large tokens (over 16 bytes) in streaming text decoder

This commit is contained in:
Martin Evans 2024-01-09 17:18:27 +00:00
parent 54dffe7e03
commit 98635a0d5a
4 changed files with 61 additions and 8 deletions

View File

@ -0,0 +1,53 @@
using System.Text;
using LLama.Common;
using Xunit.Abstractions;
namespace LLama.Unittest;
public class StreamingTextDecoderTests
: IDisposable
{
private readonly LLamaWeights _model;
private readonly ITestOutputHelper _testOutputHelper;
private readonly ModelParams _params;
public StreamingTextDecoderTests(ITestOutputHelper testOutputHelper)
{
_testOutputHelper = testOutputHelper;
_params = new ModelParams(Constants.ModelPath);
_model = LLamaWeights.LoadFromFile(_params);
}
public void Dispose()
{
_model.Dispose();
}
[Fact]
public void DecodesSimpleText()
{
var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model);
const string text = "The cat sat on the mat";
var tokens = _model.NativeHandle.Tokenize(text, false, false, Encoding.UTF8);
foreach (var lLamaToken in tokens)
decoder.Add(lLamaToken);
Assert.Equal(text, decoder.Read().Trim());
}
[Fact]
public void DecodesComplexText()
{
var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model);
const string text = "猫坐在垫子上 😀🤨🤐😏";
var tokens = _model.NativeHandle.Tokenize(text, false, false, Encoding.UTF8);
foreach (var lLamaToken in tokens)
decoder.Add(lLamaToken);
Assert.Equal(text, decoder.Read().Trim());
}
}

View File

@ -194,7 +194,7 @@ namespace LLama.Native
/// <param name="token">Token to decode</param>
/// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
/// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
public int TokenToSpan(LLamaToken token, Span<byte> dest)
public uint TokenToSpan(LLamaToken token, Span<byte> dest)
{
return ThrowIfDisposed().TokenToSpan(token, dest);
}

View File

@ -126,10 +126,10 @@ namespace LLama.Native
/// <param name="token">Token to decode</param>
/// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
/// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
public int TokenToSpan(LLamaToken token, Span<byte> dest)
public uint TokenToSpan(LLamaToken token, Span<byte> dest)
{
var length = NativeApi.llama_token_to_piece(this, token, dest);
return Math.Abs(length);
return (uint)Math.Abs(length);
}
/// <summary>

View File

@ -113,19 +113,19 @@ namespace LLama
// Try to get bytes
var l = model.TokenToSpan(token, bytes);
// Negative length indicates that the output was too small. Expand it to twice that size and try again.
if (l < 0)
// Check if the length was larger than the buffer. If so expand the buffer and try again
if (l > bytes.Length)
{
// Return the old array to the pool and get a new one
ArrayPool<byte>.Shared.Return(bytes);
bytes = ArrayPool<byte>.Shared.Rent(-l * 2);
bytes = ArrayPool<byte>.Shared.Rent((int)(l * 2));
// Get bytes, this time it can't fail
l = model.TokenToSpan(token, bytes);
}
Debug.Assert(l >= 0);
return new Span<byte>(bytes, 0, l);
Debug.Assert(l <= bytes.Length);
return new Span<byte>(bytes, 0, (int)l);
}
}