From 2ea2048b784a2ac4877dd9ce6686e5530db44889 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Fri, 12 Jan 2024 15:32:59 +0000 Subject: [PATCH] - Added a test for tokenizing just a new line (reproduce issue https://github.com/SciSharp/LLamaSharp/issues/430) - Properly displaying `LLamaToken` - Removed all tokenisation code in `SafeLLamaContextHandle` - just pass it all through to the `SafeLlamaModelHandle` - Improved `SafeLlamaModelHandle` tokenisation: - Renting an array, for one less allocation - Not using `&tokens[0]` to take a pointer to an array, this is redundant and doesn't work on empty arrays --- LLama.Unittest/LLamaContextTests.cs | 8 +++++ LLama/Native/LLamaToken.cs | 10 +++++- LLama/Native/SafeLLamaContextHandle.cs | 32 +----------------- LLama/Native/SafeLlamaModelHandle.cs | 47 +++++++++++++++----------- 4 files changed, 45 insertions(+), 52 deletions(-) diff --git a/LLama.Unittest/LLamaContextTests.cs b/LLama.Unittest/LLamaContextTests.cs index 7f1c9496..7d774f46 100644 --- a/LLama.Unittest/LLamaContextTests.cs +++ b/LLama.Unittest/LLamaContextTests.cs @@ -41,6 +41,14 @@ namespace LLama.Unittest Assert.Equal(new LLamaToken[] { 1, 450, 4996, 17354, 1701, 29916 }, tokens); } + [Fact] + public void TokenizeNewline() + { + var tokens = _context.Tokenize("\n"); + + Assert.Equal(new LLamaToken[] { 1, 29871, 13 }, tokens); + } + [Fact] public void TokenizeWithoutBOS() { diff --git a/LLama/Native/LLamaToken.cs b/LLama/Native/LLamaToken.cs index 0bc48585..128d9f58 100644 --- a/LLama/Native/LLamaToken.cs +++ b/LLama/Native/LLamaToken.cs @@ -1,4 +1,5 @@ -using System.Runtime.InteropServices; +using System.Diagnostics; +using System.Runtime.InteropServices; namespace LLama.Native; @@ -6,6 +7,7 @@ namespace LLama.Native; /// A single token /// [StructLayout(LayoutKind.Sequential)] +[DebuggerDisplay("Value")] public readonly record struct LLamaToken { /// @@ -35,4 +37,10 @@ public readonly record struct LLamaToken /// /// public static implicit operator 
LLamaToken(int value) => new(value); + + /// <inheritdoc /> + public override string ToString() + { + return Value.ToString(); + } } \ No newline at end of file diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index 17fa13cf..3f303123 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -155,37 +155,7 @@ namespace LLama.Native /// public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding) { - ThrowIfDisposed(); - - if (string.IsNullOrEmpty(text) && !add_bos) - return Array.Empty<LLamaToken>(); - - // Calculate number of bytes in string, this is a pessimistic estimate of token count. It can't - // possibly be more than this. - var count = encoding.GetByteCount(text) + (add_bos ? 1 : 0); - - // "Rent" an array to write results into (avoiding an allocation of a large array) - var temporaryArray = ArrayPool<LLamaToken>.Shared.Rent(count); - try - { - // Do the actual conversion - var n = NativeApi.llama_tokenize(this, text, encoding, temporaryArray, count, add_bos, special); - if (n < 0) - { - throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. 
Please try to " + - "specify the encoding."); - } - - // Copy the results from the rented into an array which is exactly the right size - var result = new LLamaToken[n]; - Array.ConstrainedCopy(temporaryArray, 0, result, 0, n); - - return result; - } - finally - { - ArrayPool<LLamaToken>.Shared.Return(temporaryArray); - } + return ThrowIfDisposed().Tokenize(text, add_bos, special, encoding); } /// diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index 47ffb4dd..8ffa2be3 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -1,4 +1,5 @@ using System; +using System.Buffers; using System.Collections.Generic; using System.Diagnostics; using System.Runtime.InteropServices; @@ -172,34 +173,40 @@ namespace LLama.Native /// public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding) { + // Early exit if there's no work to do + if (text == "" && !add_bos) + return Array.Empty<LLamaToken>(); + // Convert string to bytes, adding one extra byte to the end (null terminator) var bytesCount = encoding.GetByteCount(text); - var bytes = new byte[bytesCount + 1]; unsafe + var bytes = ArrayPool<byte>.Shared.Rent(bytesCount + 1); + try { - fixed (char* charPtr = text) - fixed (byte* bytePtr = &bytes[0]) + unsafe { - encoding.GetBytes(charPtr, text.Length, bytePtr, bytes.Length); - } - } - - unsafe - { - fixed (byte* bytesPtr = &bytes[0]) - { - // Tokenize once with no output, to get the token count. 
Output will be negative (indicating that there was insufficient space) - var count = -NativeApi.llama_tokenize(this, bytesPtr, bytesCount, (LLamaToken*)IntPtr.Zero, 0, add_bos, special); - - // Tokenize again, this time outputting into an array of exactly the right size - var tokens = new LLamaToken[count]; - fixed (LLamaToken* tokensPtr = &tokens[0]) + fixed (char* textPtr = text) + fixed (byte* bytesPtr = bytes) { - NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos, special); - return tokens; + // Convert text into bytes + encoding.GetBytes(textPtr, text.Length, bytesPtr, bytes.Length); + + // Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space) + var count = -NativeApi.llama_tokenize(this, bytesPtr, bytesCount, (LLamaToken*)IntPtr.Zero, 0, add_bos, special); + + // Tokenize again, this time outputting into an array of exactly the right size + var tokens = new LLamaToken[count]; + fixed (LLamaToken* tokensPtr = tokens) + { + NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos, special); + return tokens; + } } } } + finally + { + ArrayPool<byte>.Shared.Return(bytes, true); + } } #endregion