Fixed decoding of large tokens (over 16 bytes) in streaming text decoder
This commit is contained in:
parent
54dffe7e03
commit
98635a0d5a
|
@ -0,0 +1,53 @@
|
|||
using System.Text;
|
||||
using LLama.Common;
|
||||
using Xunit.Abstractions;
|
||||
|
||||
namespace LLama.Unittest;
|
||||
|
||||
/// <summary>
/// Round-trip tests for <see cref="StreamingTokenDecoder"/>: a string is tokenized with the
/// model loaded from <c>Constants.ModelPath</c>, every token is pushed into the decoder, and
/// the decoded text is compared with the original string.
/// </summary>
public class StreamingTextDecoderTests
    : IDisposable
{
    private readonly LLamaWeights _model;
    private readonly ITestOutputHelper _testOutputHelper;
    private readonly ModelParams _params;

    /// <summary>
    /// Per-test setup: builds the model parameters and loads the model weights.
    /// </summary>
    /// <param name="testOutputHelper">Output helper supplied to the test class; stored but not otherwise used here.</param>
    public StreamingTextDecoderTests(ITestOutputHelper testOutputHelper)
    {
        _testOutputHelper = testOutputHelper;

        // The weights are loaded from the parameters, so the parameters must exist first.
        _params = new ModelParams(Constants.ModelPath);
        _model = LLamaWeights.LoadFromFile(_params);
    }

    /// <summary>
    /// Per-test teardown: releases the loaded model weights.
    /// </summary>
    public void Dispose() => _model.Dispose();

    /// <summary>
    /// A short plain ASCII sentence must survive the tokenize/decode round trip.
    /// </summary>
    [Fact]
    public void DecodesSimpleText()
    {
        const string text = "The cat sat on the mat";

        Assert.Equal(text, RoundTrip(text));
    }

    /// <summary>
    /// Multi-byte text (CJK characters and emoji) must survive the tokenize/decode round trip.
    /// </summary>
    [Fact]
    public void DecodesComplexText()
    {
        const string text = "猫坐在垫子上 😀🤨🤐😏";

        Assert.Equal(text, RoundTrip(text));
    }

    /// <summary>
    /// Tokenizes <paramref name="input"/> as UTF-8, feeds each token to a fresh decoder one at a
    /// time, and returns everything the decoder produced.
    /// </summary>
    /// <param name="input">Text to tokenize and decode again.</param>
    /// <returns>The decoded text with leading and trailing whitespace removed.</returns>
    private string RoundTrip(string input)
    {
        var streamingDecoder = new StreamingTokenDecoder(Encoding.UTF8, _model);

        foreach (var token in _model.NativeHandle.Tokenize(input, false, false, Encoding.UTF8))
        {
            streamingDecoder.Add(token);
        }

        // NOTE(review): the trim presumably hides whitespace introduced by the tokenizer - confirm.
        return streamingDecoder.Read().Trim();
    }
}
|
|
@ -194,7 +194,7 @@ namespace LLama.Native
|
|||
/// <param name="token">Token to decode</param>
|
||||
/// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
|
||||
/// <returns>The size of this token in bytes. **Nothing will be written** if this is larger than `dest`.</returns>
|
||||
public int TokenToSpan(LLamaToken token, Span<byte> dest)
|
||||
public uint TokenToSpan(LLamaToken token, Span<byte> dest)
|
||||
{
|
||||
return ThrowIfDisposed().TokenToSpan(token, dest);
|
||||
}
|
||||
|
|
|
@ -126,10 +126,10 @@ namespace LLama.Native
|
|||
/// <param name="token">Token to decode</param>
|
||||
/// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
|
||||
/// <returns>The size of this token in bytes. **Nothing will be written** if this is larger than `dest`.</returns>
|
||||
public int TokenToSpan(LLamaToken token, Span<byte> dest)
|
||||
public uint TokenToSpan(LLamaToken token, Span<byte> dest)
|
||||
{
|
||||
var length = NativeApi.llama_token_to_piece(this, token, dest);
|
||||
return Math.Abs(length);
|
||||
return (uint)Math.Abs(length);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
|
@ -113,19 +113,19 @@ namespace LLama
|
|||
// Try to get bytes
|
||||
var l = model.TokenToSpan(token, bytes);
|
||||
|
||||
// Negative length indicates that the output was too small. Expand it to twice that size and try again.
|
||||
if (l < 0)
|
||||
// Check if the length was larger than the buffer. If so expand the buffer and try again
|
||||
if (l > bytes.Length)
|
||||
{
|
||||
// Return the old array to the pool and get a new one
|
||||
ArrayPool<byte>.Shared.Return(bytes);
|
||||
bytes = ArrayPool<byte>.Shared.Rent(-l * 2);
|
||||
bytes = ArrayPool<byte>.Shared.Rent((int)(l * 2));
|
||||
|
||||
// Get bytes, this time it can't fail
|
||||
l = model.TokenToSpan(token, bytes);
|
||||
}
|
||||
|
||||
Debug.Assert(l >= 0);
|
||||
return new Span<byte>(bytes, 0, l);
|
||||
Debug.Assert(l <= bytes.Length);
|
||||
return new Span<byte>(bytes, 0, (int)l);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue