fix: encoding error when using other languages.
This commit is contained in:
parent
4d34f0b116
commit
3a62f087fe
|
@ -7,6 +7,7 @@ using System.Collections.Generic;
|
|||
using System.Diagnostics;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
|
||||
namespace LLama
|
||||
{
|
||||
|
@ -463,12 +464,12 @@ namespace LLama
|
|||
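/// <param name="text">The string to tokenize. It is converted to bytes with the requested encoding (UTF-8 by default) before being passed to the native tokenizer.</param>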
|
||||
/// <returns>A list of tokens.</returns>
|
||||
/// <exception cref="RuntimeError">If the tokenization failed.</exception>
|
||||
public List<llama_token> Tokenize(string text)
|
||||
public List<llama_token> Tokenize(string text, string encoding = "UTF-8")
|
||||
{
|
||||
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
|
||||
var n_ctx = NativeApi.llama_n_ctx(_ctx);
|
||||
var tokens = new llama_token[n_ctx];
|
||||
var n_tokens = NativeApi.llama_tokenize(_ctx, text, tokens, n_ctx, true);
|
||||
var n_tokens = NativeApi.llama_tokenize(_ctx, text, Encoding.GetEncoding(encoding), tokens, n_ctx, true);
|
||||
if (n_tokens < 0)
|
||||
{
|
||||
throw new RuntimeError($"Failed to tokenize: text=\"{text}\" n_tokens={n_tokens}");
|
||||
|
|
|
@ -176,8 +176,27 @@ namespace LLama.Native
|
|||
/// <param name="n_max_tokens"></param>
|
||||
/// <param name="add_bos"></param>
|
||||
/// <returns></returns>
|
||||
[DllImport(libraryName)]
|
||||
public static extern int llama_tokenize(SafeLLamaContextHandle ctx, string text, llama_token[] tokens, int n_max_tokens, bool add_bos);
|
||||
public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos)
{
    // Managed wrapper around the native tokenizer: encodes `text` with `encoding`
    // and hands the raw bytes to llama_tokenize_native, returning its result unchanged.
    var bytes = encoding.GetBytes(text);

    // The native side receives a bare pointer to this array (blittable arrays are pinned,
    // not copied), so the buffer must carry its own terminator. Allocate one extra element;
    // new arrays are zero-initialised, so the last element is the trailing NUL.
    sbyte[] data = new sbyte[bytes.Length + 1];
    for (int i = 0; i < bytes.Length; i++)
    {
        // Reinterpret the bit pattern; `unchecked` keeps values >= 0x80 from throwing
        // when the assembly is compiled with overflow checking enabled.
        data[i] = unchecked((sbyte)bytes[i]);
    }

    return llama_tokenize_native(ctx, data, tokens, n_max_tokens, add_bos);
}
|
||||
|
||||
/// <summary>
/// Raw P/Invoke declaration bound to the native <c>llama_tokenize</c> export (via <c>EntryPoint</c>).
/// The text is passed as a signed-byte array instead of a managed string.
/// </summary>
/// <remarks>
/// NOTE(review): the native function presumably treats <paramref name="text"/> as a NUL-terminated
/// C string — confirm against llama.h and make sure callers supply the terminator.
/// </remarks>
[DllImport(libraryName, EntryPoint = "llama_tokenize")]
|
||||
public static extern int llama_tokenize_native(SafeLLamaContextHandle ctx, sbyte[] text, llama_token[] tokens, int n_max_tokens, bool add_bos);
|
||||
|
||||
/// <summary>
/// P/Invoke declaration for the native <c>llama_n_vocab</c> function, called with a context handle.
/// </summary>
/// <remarks>
/// NOTE(review): presumably returns the vocabulary size of the model behind <paramref name="ctx"/> —
/// confirm against llama.h.
/// </remarks>
[DllImport(libraryName)]
|
||||
public static extern int llama_n_vocab(SafeLLamaContextHandle ctx);
|
||||
|
|
|
@ -52,11 +52,12 @@ namespace LLama
|
|||
return ctx;
|
||||
}
|
||||
|
||||
public static List<llama_token> llama_tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, string encoding)
|
||||
public static List<llama_token> llama_tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, string encodingName)
|
||||
{
|
||||
var cnt = Encoding.GetEncoding(encoding).GetByteCount(text);
|
||||
var encoding = Encoding.GetEncoding(encodingName);
|
||||
var cnt = encoding.GetByteCount(text);
|
||||
llama_token[] res = new llama_token[cnt + (add_bos ? 1 : 0)];
|
||||
int n = NativeApi.llama_tokenize(ctx, text, res, res.Length, add_bos);
|
||||
int n = NativeApi.llama_tokenize(ctx, text, encoding, res, res.Length, add_bos);
|
||||
if(n < 0)
|
||||
{
|
||||
throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
|
||||
|
|
Loading…
Reference in New Issue