diff --git a/LLama/LLamaModel.cs b/LLama/LLamaModel.cs index cced0987..6ca9ebd4 100644 --- a/LLama/LLamaModel.cs +++ b/LLama/LLamaModel.cs @@ -7,6 +7,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.IO; using System.Linq; +using System.Text; namespace LLama { @@ -463,12 +464,12 @@ namespace LLama /// The utf-8 encoded string to tokenize. /// A list of tokens. /// If the tokenization failed. - public List Tokenize(string text) + public List Tokenize(string text, string encoding = "UTF-8") { Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero); var n_ctx = NativeApi.llama_n_ctx(_ctx); var tokens = new llama_token[n_ctx]; - var n_tokens = NativeApi.llama_tokenize(_ctx, text, tokens, n_ctx, true); + var n_tokens = NativeApi.llama_tokenize(_ctx, text, Encoding.GetEncoding(encoding), tokens, n_ctx, true); if (n_tokens < 0) { throw new RuntimeError($"Failed to tokenize: text=\"{text}\" n_tokens={n_tokens}"); diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 71bfcc6b..07ac4efe 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -176,8 +176,27 @@ namespace LLama.Native /// /// /// - [DllImport(libraryName)] - public static extern int llama_tokenize(SafeLLamaContextHandle ctx, string text, llama_token[] tokens, int n_max_tokens, bool add_bos); + public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos) + { + var bytes = encoding.GetBytes(text); + sbyte[] data = new sbyte[bytes.Length]; + for(int i = 0; i < bytes.Length; i++) + { + data[i] = (sbyte)bytes[i]; + //if (bytes[i] < 128) + //{ + // data[i] = (sbyte)bytes[i]; + //} + //else + //{ + // data[i] = (sbyte)(~((sbyte)(~bytes[i] + 1)) + 1); + //} + } + return llama_tokenize_native(ctx, data, tokens, n_max_tokens, add_bos); + } + + [DllImport(libraryName, EntryPoint = "llama_tokenize")] + public static extern int llama_tokenize_native(SafeLLamaContextHandle ctx, sbyte[] text, llama_token[] tokens, int n_max_tokens, bool add_bos); [DllImport(libraryName)] public static extern int llama_n_vocab(SafeLLamaContextHandle ctx); diff --git a/LLama/Utils.cs b/LLama/Utils.cs index 5c04a30f..911ad61e 100644 --- a/LLama/Utils.cs +++ b/LLama/Utils.cs @@ -52,11 +52,12 @@ namespace LLama return ctx; } - public static List llama_tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, string encoding) + public static List llama_tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, string encodingName) { - var cnt = Encoding.GetEncoding(encoding).GetByteCount(text); + var encoding = Encoding.GetEncoding(encodingName); + var cnt = encoding.GetByteCount(text); llama_token[] res = new llama_token[cnt + (add_bos ? 1 : 0)]; - int n = NativeApi.llama_tokenize(ctx, text, res, res.Length, add_bos); + int n = NativeApi.llama_tokenize(ctx, text, encoding, res, res.Length, add_bos); if(n < 0) { throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +