diff --git a/LLama/LLamaModel.cs b/LLama/LLamaModel.cs
index cced0987..6ca9ebd4 100644
--- a/LLama/LLamaModel.cs
+++ b/LLama/LLamaModel.cs
@@ -7,6 +7,7 @@ using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
+using System.Text;
namespace LLama
{
@@ -463,12 +464,12 @@ namespace LLama
/// The utf-8 encoded string to tokenize.
/// A list of tokens.
/// If the tokenization failed.
- public List Tokenize(string text)
+ public List Tokenize(string text, string encoding = "UTF-8")
{
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
var n_ctx = NativeApi.llama_n_ctx(_ctx);
var tokens = new llama_token[n_ctx];
- var n_tokens = NativeApi.llama_tokenize(_ctx, text, tokens, n_ctx, true);
+ var n_tokens = NativeApi.llama_tokenize(_ctx, text, Encoding.GetEncoding(encoding), tokens, n_ctx, true);
if (n_tokens < 0)
{
throw new RuntimeError($"Failed to tokenize: text=\"{text}\" n_tokens={n_tokens}");
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 71bfcc6b..07ac4efe 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -176,8 +176,27 @@ namespace LLama.Native
///
///
///
- [DllImport(libraryName)]
- public static extern int llama_tokenize(SafeLLamaContextHandle ctx, string text, llama_token[] tokens, int n_max_tokens, bool add_bos);
+ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos)
+ {
+     // Managed front-end for the native tokenizer: encodes `text` with `encoding` and
+     // forwards the raw bytes to llama_tokenize_native, returning its result unchanged
+     // (callers treat a negative result as failure).
+     var bytes = encoding.GetBytes(text);
+
+     // The native entry point is handed only the text buffer (no length argument), so
+     // the buffer must end with a NUL byte or the native side can read past its end.
+     // Allocate one extra element: a freshly created array is zero-initialised, so
+     // data[bytes.Length] already is the terminator.
+     // NOTE(review): an encoding that emits interior zero bytes (e.g. UTF-16) would be
+     // truncated at the first zero by a C-string consumer - confirm only UTF-8-like
+     // encodings are expected here.
+     sbyte[] data = new sbyte[bytes.Length + 1];
+     for (int i = 0; i < bytes.Length; i++)
+     {
+         // Reinterpret the bit pattern. `unchecked` keeps bytes >= 0x80 from throwing
+         // an OverflowException when the assembly is compiled with overflow checking.
+         data[i] = unchecked((sbyte)bytes[i]);
+     }
+
+     return llama_tokenize_native(ctx, data, tokens, n_max_tokens, add_bos);
+ }
+
+ /// <summary>
+ /// Raw P/Invoke binding to the native library's <c>llama_tokenize</c> export.
+ /// The text is passed as an already-encoded <see cref="sbyte"/> array; no length is
+ /// passed alongside it, so the caller is responsible for how the buffer is terminated.
+ /// </summary>
+ /// <param name="ctx">Handle of the llama context to tokenize with.</param>
+ /// <param name="text">Encoded bytes of the text to tokenize.</param>
+ /// <param name="tokens">Caller-allocated buffer passed to the native function for the resulting tokens.</param>
+ /// <param name="n_max_tokens">Token limit forwarded to the native function; callers pass the size of <paramref name="tokens"/>.</param>
+ /// <param name="add_bos">Forwarded unchanged to the native function (presumably "add beginning-of-sequence token" - confirm against llama.h).</param>
+ /// <returns>The native function's integer result; managed callers treat a negative value as failure.</returns>
+ // NOTE(review): `bool` is marshalled as a 4-byte BOOL by default; if the native parameter
+ // is a 1-byte C `bool`, consider [MarshalAs(UnmanagedType.U1)] - confirm against llama.h.
+ [DllImport(libraryName, EntryPoint = "llama_tokenize")]
+ public static extern int llama_tokenize_native(SafeLLamaContextHandle ctx, sbyte[] text, llama_token[] tokens, int n_max_tokens, bool add_bos);
[DllImport(libraryName)]
public static extern int llama_n_vocab(SafeLLamaContextHandle ctx);
diff --git a/LLama/Utils.cs b/LLama/Utils.cs
index 5c04a30f..911ad61e 100644
--- a/LLama/Utils.cs
+++ b/LLama/Utils.cs
@@ -52,11 +52,12 @@ namespace LLama
return ctx;
}
- public static List llama_tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, string encoding)
+ public static List llama_tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, string encodingName)
{
- var cnt = Encoding.GetEncoding(encoding).GetByteCount(text);
+ var encoding = Encoding.GetEncoding(encodingName);
+ var cnt = encoding.GetByteCount(text);
llama_token[] res = new llama_token[cnt + (add_bos ? 1 : 0)];
- int n = NativeApi.llama_tokenize(ctx, text, res, res.Length, add_bos);
+ int n = NativeApi.llama_tokenize(ctx, text, encoding, res, res.Length, add_bos);
if(n < 0)
{
throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +