fix: encoding error when tokenizing text in languages other than English (non-ASCII input).

This commit is contained in:
Yaohui Liu 2023-06-03 18:51:20 +08:00
parent 4d34f0b116
commit 3a62f087fe
No known key found for this signature in database
GPG Key ID: E86D01E1809BD23E
3 changed files with 28 additions and 7 deletions

View File

@ -7,6 +7,7 @@ using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
namespace LLama
{
@ -463,12 +464,12 @@ namespace LLama
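/// <param name="text">The string to tokenize. It is converted to bytes with the encoding named by <paramref name="encoding"/> before being passed to the native tokenizer.</param>
/// <param name="encoding">Name of the text encoding (as accepted by <see cref="Encoding.GetEncoding(string)"/>) used to encode <paramref name="text"/>. Defaults to "UTF-8".</param>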
/// <returns>A list of tokens.</returns>
/// <exception cref="RuntimeError">If the tokenization failed.</exception>
public List<llama_token> Tokenize(string text)
public List<llama_token> Tokenize(string text, string encoding = "UTF-8")
{
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
var n_ctx = NativeApi.llama_n_ctx(_ctx);
var tokens = new llama_token[n_ctx];
var n_tokens = NativeApi.llama_tokenize(_ctx, text, tokens, n_ctx, true);
var n_tokens = NativeApi.llama_tokenize(_ctx, text, Encoding.GetEncoding(encoding), tokens, n_ctx, true);
if (n_tokens < 0)
{
throw new RuntimeError($"Failed to tokenize: text=\"{text}\" n_tokens={n_tokens}");

View File

@ -176,8 +176,27 @@ namespace LLama.Native
/// <param name="n_max_tokens"></param>
/// <param name="add_bos"></param>
/// <returns></returns>
[DllImport(libraryName)]
public static extern int llama_tokenize(SafeLLamaContextHandle ctx, string text, llama_token[] tokens, int n_max_tokens, bool add_bos);
// Managed front-end for the native tokenizer: encodes `text` with the supplied
// `encoding`, hands the raw bytes to `llama_tokenize_native`, and returns whatever
// the native call returns (callers in this codebase treat a negative value as failure).
//   ctx          - handle of the llama context to tokenize with.
//   text         - the managed string to tokenize.
//   encoding     - encoding used to turn `text` into the byte sequence seen by native code.
//   tokens       - caller-allocated buffer that receives the produced tokens.
//   n_max_tokens - capacity of `tokens` that the native side may fill.
//   add_bos      - forwarded unchanged to the native call.
public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos)
{
    var bytes = encoding.GetBytes(text);

    // The native signature receives only a pointer for the text (no byte count),
    // so the buffer must carry its own terminator. Allocate one extra element and
    // leave it at the array default of 0; without it the native side would read
    // past the end of the pinned managed array.
    sbyte[] data = new sbyte[bytes.Length + 1];
    for (int i = 0; i < bytes.Length; i++)
    {
        // Reinterpret the bit pattern: bytes >= 0x80 (every non-ASCII UTF-8 byte)
        // become negative sbytes. `unchecked` keeps this from throwing when the
        // assembly is compiled with arithmetic overflow checking enabled.
        data[i] = unchecked((sbyte)bytes[i]);
    }

    return llama_tokenize_native(ctx, data, tokens, n_max_tokens, add_bos);
}
/// <summary>
/// Raw P/Invoke binding to the native export named <c>llama_tokenize</c>. It is exposed under a
/// different managed name (via <c>EntryPoint</c>) so that the managed wrapper taking a
/// <see cref="string"/> and an <see cref="Encoding"/> can keep the name <c>llama_tokenize</c>.
/// </summary>
/// <remarks>
/// The text is passed as an already-encoded byte buffer rather than a <see cref="string"/>, so no
/// string marshalling or character-set conversion is applied to it by the runtime.
/// NOTE(review): no length is passed for <c>text</c>, so the native side presumably relies on a
/// terminating 0 byte in the buffer — confirm against the native header and make sure callers supply it.
/// </remarks>
[DllImport(libraryName, EntryPoint = "llama_tokenize")]
public static extern int llama_tokenize_native(SafeLLamaContextHandle ctx, sbyte[] text, llama_token[] tokens, int n_max_tokens, bool add_bos);
/// <summary>
/// P/Invoke binding to the native export of the same name; takes a context handle and returns an integer.
/// NOTE(review): presumably the number of entries in the model vocabulary — confirm against the native header.
/// </summary>
[DllImport(libraryName)]
public static extern int llama_n_vocab(SafeLLamaContextHandle ctx);

View File

@ -52,11 +52,12 @@ namespace LLama
return ctx;
}
public static List<llama_token> llama_tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, string encoding)
public static List<llama_token> llama_tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, string encodingName)
{
var cnt = Encoding.GetEncoding(encoding).GetByteCount(text);
var encoding = Encoding.GetEncoding(encodingName);
var cnt = encoding.GetByteCount(text);
llama_token[] res = new llama_token[cnt + (add_bos ? 1 : 0)];
int n = NativeApi.llama_tokenize(ctx, text, res, res.Length, add_bos);
int n = NativeApi.llama_tokenize(ctx, text, encoding, res, res.Length, add_bos);
if(n < 0)
{
throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +