- Added a test for tokenizing just a newline (reproduces issue https://github.com/SciSharp/LLamaSharp/issues/430)

- Properly displaying `LLamaToken`
 - Removed all tokenisation code in `SafeLLamaContextHandle` - just pass it all through to the `SafeLlamaModelHandle`
 - Improved `SafeLlamaModelHandle` tokenisation:
   - Renting an array, for one less allocation
   - Not using `&tokens[0]` to take a pointer to an array; this is redundant and doesn't work on empty arrays
This commit is contained in:
Martin Evans 2024-01-12 15:32:59 +00:00
parent ba477b83a0
commit 2ea2048b78
4 changed files with 45 additions and 52 deletions

View File

@ -41,6 +41,14 @@ namespace LLama.Unittest
Assert.Equal(new LLamaToken[] { 1, 450, 4996, 17354, 1701, 29916 }, tokens);
}
[Fact]
public void TokenizeNewline()
{
    // Regression check: tokenizing a string that is only a newline must yield exactly these ids.
    LLamaToken[] expected = { 1, 29871, 13 };

    var actual = _context.Tokenize("\n");

    Assert.Equal(expected, actual);
}
[Fact]
public void TokenizeWithoutBOS()
{

View File

@ -1,4 +1,5 @@
using System.Runtime.InteropServices;
using System.Diagnostics;
using System.Runtime.InteropServices;
namespace LLama.Native;
@ -6,6 +7,7 @@ namespace LLama.Native;
/// A single token
/// </summary>
[StructLayout(LayoutKind.Sequential)]
// The braces make the debugger evaluate the Value member; without them the literal text "Value" is shown.
[DebuggerDisplay("{Value}")]
public readonly record struct LLamaToken
{
/// <summary>
@ -35,4 +37,10 @@ public readonly record struct LLamaToken
/// <param name="value"></param>
/// <returns></returns>
public static implicit operator LLamaToken(int value) => new(value);
/// <inheritdoc />
public override string ToString() => Value.ToString();
}

View File

@ -155,37 +155,7 @@ namespace LLama.Native
/// <exception cref="RuntimeError"></exception>
public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
{
    // No tokenisation logic lives here: the request is forwarded unchanged to the handle
    // returned by ThrowIfDisposed(), which also guards against use after disposal.
    // Keeping a single implementation avoids two copies of the buffer sizing and
    // native-call logic drifting apart.
    return ThrowIfDisposed().Tokenize(text, add_bos, special, encoding);
}
/// <summary>

View File

@ -1,4 +1,5 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.InteropServices;
@ -172,34 +173,40 @@ namespace LLama.Native
/// <returns></returns>
public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
{
    // Early exit if there's no work to do
    if (text == "" && !add_bos)
        return Array.Empty<LLamaToken>();

    // Convert string to bytes, reserving one extra byte at the end for a null terminator.
    // The buffer is rented rather than allocated: it may be longer than requested and
    // is not guaranteed to be zeroed.
    var bytesCount = encoding.GetByteCount(text);
    var bytes = ArrayPool<byte>.Shared.Rent(bytesCount + 1);
    try
    {
        unsafe
        {
            // Pinning the arrays directly (rather than `&array[0]`) also works when an array is empty.
            fixed (char* textPtr = text)
            fixed (byte* bytesPtr = bytes)
            {
                // Convert text into bytes
                encoding.GetBytes(textPtr, text.Length, bytesPtr, bytes.Length);

                // Write the terminator explicitly, since the rented buffer may hold stale data
                bytesPtr[bytesCount] = 0;

                // Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space)
                var count = -NativeApi.llama_tokenize(this, bytesPtr, bytesCount, (LLamaToken*)IntPtr.Zero, 0, add_bos, special);

                // Tokenize again, this time outputting into an array of exactly the right size
                var tokens = new LLamaToken[count];
                fixed (LLamaToken* tokensPtr = tokens)
                {
                    NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos, special);
                    return tokens;
                }
            }
        }
    }
    finally
    {
        // Clear on return so the encoded text does not linger in the shared pool
        ArrayPool<byte>.Shared.Return(bytes, true);
    }
}
#endregion