refactor: remove old version files.

Yaohui Liu 2023-09-02 22:24:07 +08:00
parent b82e9f8fb0
commit 18294a725e
GPG Key ID: E86D01E1809BD23E
15 changed files with 1 addition and 1571 deletions
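For orientation, a minimal sketch (not part of this diff) of how the removed LLama.OldVersion chat examples map onto the newer API exercised by LLama.Examples.NewVersion; the exact type and member names below are assumptions based on later LLamaSharp releases rather than anything this commit shows.

using System;
using System.Collections.Generic;
using LLama;
using LLama.Common;

// Assumed post-v0.4.0 equivalent of the removed ChatSession example: load the weights,
// create a context, wrap an InteractiveExecutor in a ChatSession and stream the reply.
var parameters = new ModelParams("path/to/model.bin") { ContextSize = 512 };
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);
var session = new ChatSession(new InteractiveExecutor(context));

var inferenceParams = new InferenceParams { AntiPrompts = new List<string> { "User:" } };
await foreach (var token in session.ChatAsync("User: Hello, Bob.\nBob:", inferenceParams))
{
    Console.Write(token);
}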

View File

@@ -1,39 +0,0 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using LLama.OldVersion;
namespace LLama.Examples.Old
{
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public class ChatSession
{
LLama.OldVersion.ChatSession<LLama.OldVersion.LLamaModel> _session;
public ChatSession(string modelPath, string promptFilePath, string[] antiprompt)
{
LLama.OldVersion.LLamaModel model = new(new LLamaParams(model: modelPath, n_ctx: 512, interactive: true, repeat_penalty: 1.0f, verbose_prompt: false));
_session = new ChatSession<LLama.OldVersion.LLamaModel>(model)
.WithPromptFile(promptFilePath)
.WithAntiprompt(antiprompt);
}
public void Run()
{
Console.Write("\nUser:");
while (true)
{
Console.ForegroundColor = ConsoleColor.Green;
var question = Console.ReadLine();
question += "\n";
Console.ForegroundColor = ConsoleColor.White;
var outputs = _session.Chat(question, encoding: "UTF-8");
foreach (var output in outputs)
{
Console.Write(output);
}
}
}
}
}

View File

@@ -1,37 +0,0 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using LLama.OldVersion;
namespace LLama.Examples.Old
{
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public class ChatWithLLamaModel
{
LLama.OldVersion.LLamaModel _model;
public ChatWithLLamaModel(string modelPath, string promptFilePath, string[] antiprompt)
{
_model = new LLama.OldVersion.LLamaModel(new LLamaParams(model: modelPath, n_ctx: 512, interactive: true, antiprompt: antiprompt.ToList(),
repeat_penalty: 1.0f)).WithPromptFile(promptFilePath);
}
public void Run()
{
Console.Write("\nUser:");
while (true)
{
Console.ForegroundColor = ConsoleColor.Green;
var question = Console.ReadLine();
question += "\n";
Console.ForegroundColor = ConsoleColor.White;
var outputs = _model.Call(question);
foreach (var output in outputs)
{
Console.Write(output);
}
}
}
}
}

View File

@@ -1,24 +0,0 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using LLama.OldVersion;
namespace LLama.Examples.Old
{
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public class GetEmbeddings
{
LLama.OldVersion.LLamaEmbedder _embedder;
public GetEmbeddings(string modelPath)
{
_embedder = new LLama.OldVersion.LLamaEmbedder(new LLamaParams(model: modelPath));
}
public void Run(string text)
{
Console.WriteLine(string.Join(", ", _embedder.GetEmbeddings(text)));
}
}
}

View File

@@ -1,37 +0,0 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using LLama.OldVersion;
namespace LLama.Examples.Old
{
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public class InstructMode
{
LLama.OldVersion.LLamaModel _model;
public InstructMode(string modelPath, string promptFile)
{
_model = new LLama.OldVersion.LLamaModel(new LLamaParams(model: modelPath, n_ctx: 2048, n_predict: -1, top_k: 10000, instruct: true,
repeat_penalty: 1.1f, n_batch: 256, temp: 0.2f)).WithPromptFile(promptFile);
}
public void Run()
{
Console.WriteLine("\n### Instruction:\n >");
while (true)
{
Console.ForegroundColor = ConsoleColor.Green;
var question = Console.ReadLine();
question += "\n";
Console.ForegroundColor = ConsoleColor.White;
var outputs = _model.Call(question);
foreach (var output in outputs)
{
Console.Write(output);
}
}
}
}
}

View File

@@ -1,28 +0,0 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace LLama.Examples.Old
{
public class Quantize
{
public Quantize()
{
}
public void Run(string srcFileName, string dstFilename, string ftype, int nthread = -1)
{
if(LLamaQuantizer.Quantize(srcFileName, dstFilename, ftype, nthread))
{
Console.WriteLine("Quantization succeed!");
}
else
{
Console.WriteLine("Quantization failed!");
}
}
}
}

View File

@@ -1,51 +0,0 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using LLama.OldVersion;
namespace LLama.Examples.Old
{
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public class SaveAndLoadState: IDisposable
{
LLama.OldVersion.LLamaModel _model;
public SaveAndLoadState(string modelPath, string prompt)
{
_model = new LLama.OldVersion.LLamaModel(new LLamaParams(model: modelPath, n_ctx: 2048, n_predict: -1, top_k: 10000, instruct: true,
repeat_penalty: 1.1f, n_batch: 256, temp: 0.2f)).WithPrompt(prompt);
}
public void Run(string question)
{
// Only run once here.
Console.Write("\nUser:");
Console.ForegroundColor = ConsoleColor.Green;
Console.WriteLine(question);
Console.ForegroundColor = ConsoleColor.White;
var outputs = _model.Call(question);
foreach (var output in outputs)
{
Console.Write(output);
}
}
public void SaveState(string filename)
{
_model.SaveState(filename);
Console.WriteLine("Saved state!");
}
public void LoadState(string filename)
{
_model.LoadState(filename);
Console.WriteLine("Loaded state!");
}
public void Dispose()
{
_model.Dispose();
}
}
}

View File

@@ -1,99 +0,0 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace LLama.Examples.Old
{
public class OldTestRunner
{
public static void Run()
{
Console.WriteLine("================LLamaSharp Examples (Old Version)==================\n");
Console.WriteLine("Please input a number to choose an example to run:");
Console.WriteLine("0: Run a chat session.");
Console.WriteLine("1: Run a LLamaModel to chat.");
Console.WriteLine("2: Quantize a model.");
Console.WriteLine("3: Get the embeddings of a message.");
Console.WriteLine("4: Run a LLamaModel with instruct mode.");
Console.WriteLine("5: Load and save state of LLamaModel.");
while (true)
{
Console.Write("\nYour choice: ");
int choice = int.Parse(Console.ReadLine());
if (choice == 0)
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
ChatSession chat = new(modelPath, "Assets/chat-with-bob.txt", new string[] { "User:" });
chat.Run();
}
else if (choice == 1)
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
ChatWithLLamaModel chat = new(modelPath, "Assets/chat-with-bob.txt", new string[] { "User:" });
chat.Run();
}
else if (choice == 2) // quantization
{
Console.Write("Please input your original model path: ");
var inputPath = Console.ReadLine();
Console.Write("Please input your output model path: ");
var outputPath = Console.ReadLine();
Console.Write("Please input the quantize type (one of q4_0, q4_1, q5_0, q5_1, q8_0): ");
var quantizeType = Console.ReadLine();
Quantize q = new Quantize();
q.Run(inputPath, outputPath, quantizeType);
}
else if (choice == 3) // get the embeddings only
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
GetEmbeddings em = new GetEmbeddings(modelPath);
Console.Write("Please input the text: ");
var text = Console.ReadLine();
em.Run(text);
}
else if (choice == 4) // instruct mode
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
InstructMode im = new InstructMode(modelPath, "Assets/alpaca.txt");
Console.WriteLine("Here's a simple example for using instruct mode. You can input some words and let AI " +
"complete it for you. For example: Write a story about a fox that wants to make friend with human. No less than 200 words.");
im.Run();
}
else if (choice == 5) // load and save state
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
Console.Write("Please input your state file path: ");
var statePath = Console.ReadLine();
SaveAndLoadState sals = new(modelPath, File.ReadAllText(@"D:\development\llama\llama.cpp\prompts\alpaca.txt"));
sals.Run("Write a story about a fox that wants to make friend with human. No less than 200 words.");
sals.SaveState(statePath);
sals.Dispose();
GC.Collect();
GC.WaitForPendingFinalizers();
// create a new model to load the state.
SaveAndLoadState sals2 = new(modelPath, "");
sals2.LoadState(statePath);
sals2.Run("Tell me more things about the fox in the story you told me.");
}
else
{
Console.WriteLine("Cannot parse your choice. Please select again.");
continue;
}
break;
}
}
}
}

View File

@@ -1,5 +1,4 @@
using LLama.Examples.NewVersion;
using LLama.Examples.Old;
Console.WriteLine("======================================================================================================");
@ -11,19 +10,4 @@ Console.WriteLine("=============================================================
Console.WriteLine();
Console.WriteLine("Please choose the version you want to test: ");
Console.WriteLine("0. old version (for v0.3.0 or earlier version)");
Console.WriteLine("1. new version (for versions after v0.4.0)");
Console.Write("\nYour Choice: ");
int version = int.Parse(Console.ReadLine());
Console.WriteLine();
if(version == 1)
{
await NewVersionTestRunner.Run();
}
else
{
OldTestRunner.Run();
}

View File

@@ -1,55 +0,0 @@
using System;
using System.Collections.Generic;
using System.IO;
#pragma warning disable
// ReSharper disable all
namespace LLama.OldVersion
{
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public class ChatSession<T> where T : IChatModel
{
IChatModel _model;
List<ChatMessageRecord> History { get; } = new List<ChatMessageRecord>();
public ChatSession(T model)
{
_model = model;
}
public IEnumerable<string> Chat(string text, string? prompt = null, string encoding = "UTF-8")
{
History.Add(new ChatMessageRecord(new ChatCompletionMessage(ChatRole.Human, text), DateTime.Now));
string totalResponse = "";
foreach (var response in _model.Chat(text, prompt, encoding))
{
totalResponse += response;
yield return response;
}
History.Add(new ChatMessageRecord(new ChatCompletionMessage(ChatRole.Assistant, totalResponse), DateTime.Now));
}
public ChatSession<T> WithPrompt(string prompt, string encoding = "UTF-8")
{
_model.InitChatPrompt(prompt, encoding);
return this;
}
public ChatSession<T> WithPromptFile(string promptFilename, string encoding = "UTF-8")
{
return WithPrompt(File.ReadAllText(promptFilename), encoding);
}
/// <summary>
/// Set the antiprompt keywords at which the chat AI's output is cut off and control returns to the user.
/// </summary>
/// <param name="antiprompt"></param>
/// <returns></returns>
public ChatSession<T> WithAntiprompt(string[] antiprompt)
{
_model.InitChatAntiprompt(antiprompt);
return this;
}
}
}

View File

@@ -1,21 +0,0 @@
using System;
using System.Collections.Generic;
#pragma warning disable
// ReSharper disable all
namespace LLama.OldVersion
{
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public interface IChatModel
{
string Name { get; }
IEnumerable<string> Chat(string text, string? prompt = null, string encoding = "UTF-8");
/// <summary>
/// Initialize a prompt for chat; subsequent prompts are produced automatically during the chat.
/// </summary>
/// <param name="prompt"></param>
void InitChatPrompt(string prompt, string encoding = "UTF-8");
void InitChatAntiprompt(string[] antiprompt);
}
}

View File

@@ -1,72 +0,0 @@
using LLama.Native;
using System;
using LLama.Exceptions;
#pragma warning disable
// ReSharper disable all
namespace LLama.OldVersion
{
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public class LLamaEmbedder
: IDisposable
{
SafeLLamaContextHandle _ctx;
/// <summary>
/// Warning: must ensure the original model has params.embedding = true;
/// </summary>
/// <param name="ctx"></param>
internal LLamaEmbedder(SafeLLamaContextHandle ctx)
{
_ctx = ctx;
}
public LLamaEmbedder(LLamaParams @params)
{
@params.embedding = true;
_ctx = Utils.llama_init_from_gpt_params(ref @params);
}
public unsafe float[] GetEmbeddings(string text, int n_thread = -1, bool add_bos = true, string encoding = "UTF-8")
{
if (n_thread == -1)
{
n_thread = Math.Max(Environment.ProcessorCount / 2, 1);
}
int n_past = 0;
if (add_bos)
{
text = text.Insert(0, " ");
}
var embed_inp = Utils.llama_tokenize(_ctx, text, add_bos, encoding);
// TODO(Rinne): deal with log of prompt
if (embed_inp.Count > 0)
{
var embed_inp_array = embed_inp.ToArray();
if (NativeApi.llama_eval(_ctx, embed_inp_array, embed_inp_array.Length, n_past, n_thread) != 0)
{
throw new RuntimeError("Failed to eval.");
}
}
int n_embed = NativeApi.llama_n_embd(_ctx);
var embeddings = NativeApi.llama_get_embeddings(_ctx);
if (embeddings == null)
{
return new float[0];
}
var span = new Span<float>(embeddings, n_embed);
float[] res = new float[n_embed];
span.CopyTo(res.AsSpan());
return res;
}
public void Dispose()
{
_ctx.Dispose();
}
}
}

View File

@@ -1,797 +0,0 @@
using LLama.Exceptions;
using LLama.Extensions;
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using LLama.Common;
#pragma warning disable
// ReSharper disable all
namespace LLama.OldVersion
{
using llama_token = Int32;
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public class LLamaModel
: IChatModel, IDisposable
{
LLamaParams _params;
SafeLLamaContextHandle _ctx;
string _path_session;
List<llama_token> _session_tokens;
List<llama_token> _embed_inp;
int _n_ctx;
List<llama_token> _inp_pfx;
List<llama_token> _inp_sfx;
List<llama_token> _llama_token_newline;
List<llama_token> _last_n_tokens;
bool _is_interacting;
bool _is_antiprompt;
bool _input_echo;
// HACK - because session saving incurs a non-negligible delay, for now skip re-saving session
// if we loaded a session with at least 75% similarity. It's currently just used to speed up the
// initial prompt so it doesn't need to be an exact match.
bool _need_to_save_session;
int _n_past;
int _n_remain;
int _n_consumed;
int _n_session_consumed;
List<llama_token> _embed;
public string Name { get; set; }
public bool Verbose { get; set; }
public SafeLLamaContextHandle NativeHandle => _ctx;
/// <summary>
/// Please refer to `LLamaParams` for the meaning of each argument. Be sure to set `n_gpu_layers`; otherwise
/// 20 layers will be loaded to the GPU by default.
/// </summary>
/// <param name="model_path">The model file path.</param>
/// <param name="model_name">The model name.</param>
/// <param name="verbose">Whether to print details when running the model.</param>
/// <param name="seed"></param>
/// <param name="n_threads"></param>
/// <param name="n_predict"></param>
/// <param name="n_ctx"></param>
/// <param name="n_batch"></param>
/// <param name="n_keep"></param>
/// <param name="n_gpu_layers"></param>
/// <param name="logit_bias"></param>
/// <param name="top_k"></param>
/// <param name="top_p"></param>
/// <param name="tfs_z"></param>
/// <param name="typical_p"></param>
/// <param name="temp"></param>
/// <param name="repeat_penalty"></param>
/// <param name="repeat_last_n"></param>
/// <param name="frequency_penalty"></param>
/// <param name="presence_penalty"></param>
/// <param name="mirostat"></param>
/// <param name="mirostat_tau"></param>
/// <param name="mirostat_eta"></param>
/// <param name="prompt"></param>
/// <param name="path_session"></param>
/// <param name="input_prefix"></param>
/// <param name="input_suffix"></param>
/// <param name="antiprompt"></param>
/// <param name="lora_adapter"></param>
/// <param name="lora_base"></param>
/// <param name="memory_f16"></param>
/// <param name="random_prompt"></param>
/// <param name="use_color"></param>
/// <param name="interactive"></param>
/// <param name="embedding"></param>
/// <param name="interactive_first"></param>
/// <param name="prompt_cache_all"></param>
/// <param name="instruct"></param>
/// <param name="penalize_nl"></param>
/// <param name="perplexity"></param>
/// <param name="use_mmap"></param>
/// <param name="use_mlock"></param>
/// <param name="mem_test"></param>
/// <param name="verbose_prompt"></param>
/// <param name="encoding"></param>
public LLamaModel(string model_path, string model_name, bool verbose = false, int seed = 0, int n_threads = -1, int n_predict = -1,
int n_ctx = 512, int n_batch = 512, int n_keep = 0, int n_gpu_layers = -1,
Dictionary<llama_token, float> logit_bias = null, int top_k = 40, float top_p = 0.95f,
float tfs_z = 1.00f, float typical_p = 1.00f, float temp = 0.80f, float repeat_penalty = 1.10f,
int repeat_last_n = 64, float frequency_penalty = 0.00f, float presence_penalty = 0.00f,
int mirostat = 0, float mirostat_tau = 5.00f, float mirostat_eta = 0.10f, string prompt = "",
string path_session = "", string input_prefix = "", string input_suffix = "",
List<string> antiprompt = null, string lora_adapter = "", string lora_base = "",
bool memory_f16 = true, bool random_prompt = false, bool use_color = false, bool interactive = false,
bool embedding = false, bool interactive_first = false, bool prompt_cache_all = false, bool instruct = false, bool penalize_nl = true,
bool perplexity = false, bool use_mmap = true, bool use_mlock = false, bool mem_test = false,
bool verbose_prompt = false, string encoding = "UTF-8") : this(new LLamaParams(seed: seed,
n_threads: n_threads,
n_predict: n_predict,
n_ctx: n_ctx,
n_batch: n_batch,
n_keep: n_keep,
n_gpu_layers: n_gpu_layers,
logit_bias: logit_bias,
top_k: top_k,
top_p: top_p,
tfs_z: tfs_z,
typical_p: typical_p,
temp: temp,
repeat_penalty: repeat_penalty,
repeat_last_n: repeat_last_n,
frequency_penalty: frequency_penalty,
presence_penalty: presence_penalty,
mirostat: mirostat,
mirostat_tau: mirostat_tau,
mirostat_eta: mirostat_eta,
model: model_path,
prompt: prompt,
path_session: path_session,
input_prefix: input_prefix,
input_suffix: input_suffix,
antiprompt: antiprompt,
lora_adapter: lora_adapter,
lora_base: lora_base,
memory_f16: memory_f16,
random_prompt: random_prompt,
use_color: use_color,
interactive: interactive,
embedding: embedding,
interactive_first: interactive_first,
prompt_cache_all: prompt_cache_all,
instruct: instruct,
penalize_nl: penalize_nl,
perplexity: perplexity,
use_mmap: use_mmap,
use_mlock: use_mlock,
mem_test: mem_test,
verbose_prompt: verbose_prompt),
model_name, verbose, encoding)
{
}
/// <summary>
/// Please refer to `LLamaParams` for the meaning of each argument. Be sure to set `n_gpu_layers`; otherwise
/// 20 layers will be loaded to the GPU by default.
/// </summary>
/// <param name="params">The LLamaModel params</param>
/// <param name="name">Model name</param>
/// <param name="verbose">Whether to output the detailed info.</param>
/// <param name="encoding"></param>
/// <exception cref="RuntimeError"></exception>
public unsafe LLamaModel(LLamaParams @params, string name = "", bool verbose = false, string encoding = "UTF-8")
{
Name = name;
_params = @params;
Verbose = verbose;
_ctx = Utils.llama_init_from_gpt_params(ref _params);
// Add a space in front of the first character to match OG llama tokenizer behavior
_session_tokens = new List<llama_token>();
_path_session = @params.path_session;
if (!string.IsNullOrEmpty(_path_session))
{
if (verbose)
{
LLamaDefaultLogger.Default.Info($"Attempting to load saved session from '{_path_session}'");
}
if (!File.Exists(_path_session))
{
LLamaDefaultLogger.Default.Warn("Session file does not exist, will create.");
}
llama_token[] session_tokens = new llama_token[@params.n_ctx];
ulong n_token_count_out = 0;
if (!NativeApi.llama_load_session_file(_ctx, _path_session, session_tokens, (ulong)@params.n_ctx, &n_token_count_out))
{
throw new RuntimeError($"Failed to load session file {_path_session}");
}
_session_tokens = session_tokens.Take((int)n_token_count_out).ToList();
if (verbose)
{
LLamaDefaultLogger.Default.Info($"Loaded a session with prompt size of {_session_tokens.Count} tokens");
}
}
_n_ctx = NativeApi.llama_n_ctx(_ctx);
WithPrompt(_params.prompt);
// prefix & suffix for instruct mode
_inp_pfx = Utils.llama_tokenize(_ctx, "\n\n### Instruction:\n\n", true, encoding);
_inp_sfx = Utils.llama_tokenize(_ctx, "\n\n### Response:\n\n", false, encoding);
// in instruct mode, we inject a prefix and a suffix to each input by the user
if (_params.instruct)
{
_params.interactive_first = true;
_params.antiprompt.Add("### Instruction:\n\n");
}
// enable interactive mode if reverse prompt or interactive start is specified
if (_params.interactive_first)
{
_params.interactive = true;
}
// determine newline token
_llama_token_newline = Utils.llama_tokenize(_ctx, "\n", false, encoding);
if (_params.verbose_prompt)
{
LLamaDefaultLogger.Default.Info("\n");
LLamaDefaultLogger.Default.Info($"prompt: '{_params.prompt}'");
LLamaDefaultLogger.Default.Info($"number of tokens in prompt = {_embed_inp.Count}");
for (int i = 0; i < _embed_inp.Count; i++)
{
LLamaDefaultLogger.Default.Info($"{_embed_inp[i]} -> '{NativeApi.llama_token_to_str(_ctx, _embed_inp[i])}'");
}
if (_params.n_keep > 0)
{
LLamaDefaultLogger.Default.Info($"static prompt based on n_keep: '");
for (int i = 0; i < _params.n_keep; i++)
{
LLamaDefaultLogger.Default.Info($"{NativeApi.llama_token_to_str(_ctx, _embed_inp[i])}");
}
LLamaDefaultLogger.Default.Info("\n");
}
LLamaDefaultLogger.Default.Info("\n");
}
if (_params.interactive && verbose)
{
LLamaDefaultLogger.Default.Info("interactive mode on.");
}
if (verbose)
{
LLamaDefaultLogger.Default.Info($"sampling: repeat_last_n = {_params.repeat_last_n}, " +
$"repeat_penalty = {_params.repeat_penalty}, presence_penalty = {_params.presence_penalty}, " +
$"frequency_penalty = {_params.frequency_penalty}, top_k = {_params.top_k}, tfs_z = {_params.tfs_z}," +
$" top_p = {_params.top_p}, typical_p = {_params.typical_p}, temp = {_params.temp}, mirostat = {_params.mirostat}," +
$" mirostat_lr = {_params.mirostat_eta}, mirostat_ent = {_params.mirostat_tau}");
LLamaDefaultLogger.Default.Info($"generate: n_ctx = {_n_ctx}, n_batch = {_params.n_batch}, n_predict = {_params.n_predict}, " +
$"n_keep = {_params.n_keep}");
LLamaDefaultLogger.Default.Info("\n");
}
_last_n_tokens = Enumerable.Repeat(0, _n_ctx).ToList();
if (_params.interactive)
{
if (verbose)
{
LLamaDefaultLogger.Default.Info("== Running in interactive mode. ==");
}
_is_interacting = _params.interactive_first;
}
_is_antiprompt = false;
_input_echo = false;
_n_past = 0;
_n_remain = _params.n_predict;
_n_consumed = 0;
_n_session_consumed = 0;
_embed = new List<llama_token>();
}
/// <summary>
/// Apply a prompt to the model.
/// </summary>
/// <param name="prompt"></param>
/// <param name="encoding"></param>
/// <returns></returns>
/// <exception cref="ArgumentException"></exception>
public LLamaModel WithPrompt(string prompt, string encoding = "UTF-8")
{
_params.prompt = prompt.Insert(0, " ");
_embed_inp = Utils.llama_tokenize(_ctx, _params.prompt, true, encoding);
if (_embed_inp.Count > _n_ctx - 4)
{
throw new ArgumentException($"prompt is too long ({_embed_inp.Count} tokens, max {_n_ctx - 4})");
}
ulong n_matching_session_tokens = 0;
if (_session_tokens.Count > 0)
{
foreach (var id in _session_tokens)
{
if (n_matching_session_tokens >= (ulong)_embed_inp.Count || id != _embed_inp[(int)n_matching_session_tokens])
{
break;
}
n_matching_session_tokens++;
}
if (n_matching_session_tokens >= (ulong)_embed_inp.Count)
{
LLamaDefaultLogger.Default.Info("Session file has exact match for prompt!");
}
else if (n_matching_session_tokens < (ulong)(_embed_inp.Count / 2))
{
LLamaDefaultLogger.Default.Warn($"session file has low similarity to prompt ({n_matching_session_tokens} " +
$"/ {_embed_inp.Count} tokens); will mostly be reevaluated.");
}
else
{
LLamaDefaultLogger.Default.Info($"Session file matches {n_matching_session_tokens} / {_embed_inp.Count} " +
$"tokens of prompt.");
}
}
// number of tokens to keep when resetting context
if (_params.n_keep < 0 || _params.n_keep > _embed_inp.Count || _params.instruct)
{
_params.n_keep = _embed_inp.Count;
}
if (_embed_inp.Count > _n_ctx - 4)
{
throw new ArgumentException($"prompt is too long ({_embed_inp.Count} tokens, max {_n_ctx - 4})");
}
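// Example of the 75% heuristic described above: with a 400-token prompt, the session
// is only re-saved when fewer than 300 of those tokens matched the loaded session
// (400 * 3 / 4 = 300).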
_need_to_save_session = !string.IsNullOrEmpty(_path_session) && n_matching_session_tokens < (ulong)(_embed_inp.Count * 3 / 4);
return this;
}
/// <summary>
/// Apply the prompt file to the model.
/// </summary>
/// <param name="promptFileName"></param>
/// <returns></returns>
public LLamaModel WithPromptFile(string promptFileName)
{
return WithPrompt(File.ReadAllText(promptFileName));
}
private void ProcessTextBeforeInfer(string text, string encoding)
{
if (!string.IsNullOrEmpty(_params.input_prefix))
{
text = _params.input_prefix + text;
}
//if (!text.EndsWith("\n"))
//{
// text += "\n";
//}
if (text.Length > 1)
{
// append input suffix if any
if (!string.IsNullOrEmpty(_params.input_suffix))
{
text += _params.input_suffix;
//yield return _params.input_suffix;
}
// instruct mode: insert instruction prefix
if (_params.instruct && !_is_antiprompt)
{
_n_consumed = _embed_inp.Count;
_embed_inp.AddRange(_inp_pfx);
}
var line_inp = Utils.llama_tokenize(_ctx, text, false, encoding);
_embed_inp.AddRange(line_inp);
// instruct mode: insert response suffix
if (_params.instruct)
{
_embed_inp.AddRange(_inp_sfx);
}
_n_remain -= line_inp.Count;
}
}
public void InitChatPrompt(string prompt, string encoding = "UTF-8")
{
WithPrompt(prompt);
}
public void InitChatAntiprompt(string[] antiprompt)
{
_params.antiprompt = antiprompt.ToList();
}
/// <summary>
/// Chat with the LLaMa model under interactive mode.
/// </summary>
/// <param name="text"></param>
/// <param name="prompt"></param>
/// <param name="encoding"></param>
/// <returns></returns>
/// <exception cref="ArgumentException"></exception>
public IEnumerable<string> Chat(string text, string? prompt = null, string encoding = "UTF-8")
{
if (!_params.interactive)
{
throw new ArgumentException("The chat API could be only used under interactive model.");
}
_input_echo = false;
if (!string.IsNullOrEmpty(prompt))
{
WithPrompt(prompt);
}
return Call(text, encoding);
}
/// <summary>
/// Save the state to specified path.
/// </summary>
/// <param name="filename"></param>
public void SaveState(string filename)
{
var stateSize = NativeApi.llama_get_state_size(_ctx);
byte[] stateMemory = new byte[stateSize];
NativeApi.llama_copy_state_data(_ctx, stateMemory);
File.WriteAllBytes(filename, stateMemory);
}
/// <summary>
/// Load the state from specified path.
/// </summary>
/// <param name="filename"></param>
/// <param name="clearPreviousEmbed">Whether to clear previous footprints of this model.</param>
/// <exception cref="RuntimeError"></exception>
public void LoadState(string filename, bool clearPreviousEmbed = true)
{
var stateMemory = File.ReadAllBytes(filename);
int stateSize = (int)NativeApi.llama_get_state_size(_ctx);
if (stateMemory.Length != stateSize)
{
throw new RuntimeError("Failed to validate state size.");
}
NativeApi.llama_set_state_data(_ctx, stateMemory);
if (clearPreviousEmbed)
{
WithPrompt(_params.prompt);
}
}
/// <summary>
/// Tokenize a string.
/// </summary>
/// <param name="text">The utf-8 encoded string to tokenize.</param>
/// <returns>A list of tokens.</returns>
/// <exception cref="RuntimeError">If the tokenization failed.</exception>
public List<llama_token> Tokenize(string text, string encoding = "UTF-8")
{
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
var n_ctx = NativeApi.llama_n_ctx(_ctx);
var tokens = new llama_token[n_ctx];
var n_tokens = NativeApi.llama_tokenize(_ctx, text, Encoding.GetEncoding(encoding), tokens, n_ctx, true);
if (n_tokens < 0)
{
throw new RuntimeError($"Failed to tokenize: text=\"{text}\" n_tokens={n_tokens}");
}
return tokens.Take(n_tokens).ToList();
}
/// <summary>
/// Detokenize a list of tokens.
/// </summary>
/// <param name="tokens">The list of tokens to detokenize.</param>
/// <returns>The detokenized string.</returns>
public string DeTokenize(IEnumerable<llama_token> tokens)
{
Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
string output = "";
foreach (var token in tokens)
{
output += Utils.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, token));
}
return output;
}
/// <summary>
/// Call the model to run inference.
/// </summary>
/// <param name="text"></param>
/// <param name="encoding"></param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
public IEnumerable<string> Call(string text, string encoding = "UTF-8")
{
_is_antiprompt = false;
if (_n_past > 0)
{
_is_interacting = false;
}
if (_is_interacting)
{
if (Verbose)
{
LLamaDefaultLogger.Default.Warn("In interacting when calling the model, automatically changed it.");
}
_is_interacting = false;
}
ProcessTextBeforeInfer(text, encoding);
while ((_n_remain != 0 || _params.interactive) && !_is_interacting)
{
if (_embed.Count > 0)
{
// infinite text generation via context swapping
// if we run out of context:
// - take the n_keep first tokens from the original prompt (via n_past)
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
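// Worked example: with n_ctx = 512, n_keep = 48 and n_past = 510, n_left = 462,
// n_past is reset to 48 and roughly n_left / 2 = 231 of the most recent tokens are
// re-inserted into _embed to be re-evaluated under the shortened context.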
if (_n_past + _embed.Count > _n_ctx)
{
int n_left = _n_past - _params.n_keep;
_n_past = Math.Max(1, _params.n_keep);
// insert n_left/2 tokens at the start of embed from last_n_tokens
_embed.InsertRange(0, _last_n_tokens.Take(_last_n_tokens.Count - _embed.Count).Skip(_n_ctx - n_left / 2 - _embed.Count));
// stop saving session if we run out of context
_path_session = "";
}
// try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
// REVIEW
if (_n_session_consumed < _session_tokens.Count)
{
int i = 0;
for (; i < _embed.Count; i++)
{
if (_embed[i] != _session_tokens[_n_session_consumed])
{
_session_tokens = _session_tokens.Take(_n_session_consumed).ToList();
break;
}
_n_past++;
_n_session_consumed++;
if (_n_session_consumed >= _session_tokens.Count)
{
i++;
break;
}
}
if (i > 0)
{
_embed.RemoveRange(0, i);
}
}
// evaluate tokens in batches
// embed is typically prepared beforehand to fit within a batch, but not always
for (int i = 0; i < _embed.Count; i += _params.n_batch)
{
int n_eval = _embed.Count - i;
if (n_eval > _params.n_batch)
{
n_eval = _params.n_batch;
}
var array = _embed.Skip(i).ToArray();
if (NativeApi.llama_eval(_ctx, array, n_eval, _n_past, _params.n_threads) != 0)
{
LLamaDefaultLogger.Default.Error($"Failed to eval.");
throw new RuntimeError("Failed to eval.");
}
_n_past += n_eval;
}
if (_embed.Count > 0 && !string.IsNullOrEmpty(_path_session))
{
_session_tokens.AddRange(_embed);
_n_session_consumed = _session_tokens.Count;
}
}
_embed.Clear();
if (_embed_inp.Count <= _n_consumed && !_is_interacting)
{
var temp = _params.temp;
var top_k = _params.top_k <= 0 ? NativeApi.llama_n_vocab(_ctx) : _params.top_k;
var top_p = _params.top_p;
var tfs_z = _params.tfs_z;
var typical_p = _params.typical_p;
var repeat_last_n = _params.repeat_last_n < 0 ? _n_ctx : _params.repeat_last_n;
var repeat_penalty = _params.repeat_penalty;
var alpha_presence = _params.presence_penalty;
var alpha_frequency = _params.frequency_penalty;
var mirostat = _params.mirostat;
var mirostat_tau = _params.mirostat_tau;
var mirostat_eta = _params.mirostat_eta;
var penalize_nl = _params.penalize_nl;
// optionally save the session on first sample (for faster prompt loading next time)
if (!string.IsNullOrEmpty(_path_session) && _need_to_save_session)
{
_need_to_save_session = false;
NativeApi.llama_save_session_file(_ctx, _path_session, _session_tokens.ToArray(), (ulong)_session_tokens.Count);
}
llama_token id;
{
var n_vocab = NativeApi.llama_n_vocab(_ctx);
var logits = Utils.llama_get_logits(_ctx, n_vocab);
// Apply params.logit_bias map
foreach (var (key, value) in _params.logit_bias)
{
logits[key] += value;
}
var candidates = new LLamaTokenData[n_vocab];
for (llama_token token_id = 0; token_id < n_vocab; token_id++)
candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f);
LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates);
// Apply penalties
float nl_logit = logits[NativeApi.llama_token_nl(_ctx)];
var last_n_repeat = Math.Min(Math.Min(_last_n_tokens.Count, repeat_last_n), _n_ctx);
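// e.g. with the default repeat_last_n = 64 and a full 512-token history, only the
// 64 most recent tokens feed the repetition/frequency/presence penalties below.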
SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p,
_last_n_tokens.Skip(_last_n_tokens.Count - last_n_repeat).ToArray(),
(ulong)last_n_repeat, repeat_penalty);
SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p,
_last_n_tokens.Skip(_last_n_tokens.Count - last_n_repeat).ToArray(),
(ulong)last_n_repeat, alpha_frequency, alpha_presence);
if (!penalize_nl)
{
logits[NativeApi.llama_token_nl(_ctx)] = nl_logit;
}
if (temp <= 0)
{
// Greedy sampling
id = SamplingApi.llama_sample_token_greedy(_ctx, candidates_p);
}
else
{
if (mirostat == 1)
{
float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
SamplingApi.llama_sample_temperature(_ctx, candidates_p, temp);
id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates_p, mirostat_tau, mirostat_eta, mirostat_m, ref mirostat_mu);
}
else if (mirostat == 2)
{
float mirostat_mu = 2.0f * mirostat_tau;
SamplingApi.llama_sample_temperature(_ctx, candidates_p, temp);
id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates_p, mirostat_tau, mirostat_eta, ref mirostat_mu);
}
else
{
// Temperature sampling
SamplingApi.llama_sample_top_k(_ctx, candidates_p, top_k, 1);
SamplingApi.llama_sample_tail_free(_ctx, candidates_p, tfs_z, 1);
SamplingApi.llama_sample_typical(_ctx, candidates_p, typical_p, 1);
SamplingApi.llama_sample_top_p(_ctx, candidates_p, top_p, 1);
SamplingApi.llama_sample_temperature(_ctx, candidates_p, temp);
id = SamplingApi.llama_sample_token(_ctx, candidates_p);
}
}
_last_n_tokens.RemoveAt(0);
_last_n_tokens.Add(id);
}
// replace end of text token with newline token when in interactive mode
if (id == NativeApi.llama_token_eos(_ctx) && _params.interactive && !_params.instruct)
{
id = _llama_token_newline[0];
if (_params.antiprompt.Count != 0)
{
// tokenize and inject first reverse prompt
var first_antiprompt = Utils.llama_tokenize(_ctx, _params.antiprompt[0], false, encoding);
_embed_inp.AddRange(first_antiprompt);
}
}
// add it to the context
_embed.Add(id);
// echo this to console
_input_echo = true;
// decrement remaining sampling budget
_n_remain--;
}
else
{
while (_embed_inp.Count > _n_consumed)
{
_embed.Add(_embed_inp[_n_consumed]);
_last_n_tokens.RemoveAt(0);
_last_n_tokens.Add(_embed_inp[_n_consumed]);
_n_consumed++;
if (_embed.Count >= _params.n_batch)
{
break;
}
}
}
if (_input_echo && !_is_interacting)
{
foreach (var id in _embed)
{
var res = Utils.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, id));
yield return res;
}
}
if (_params.interactive && _embed_inp.Count <= _n_consumed)
{
if (_params.antiprompt.Count > 0)
{
string last_output = "";
foreach (var id in _last_n_tokens)
{
last_output += Utils.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, id));
}
_is_antiprompt = false;
foreach (var antiprompt in _params.antiprompt)
{
if (last_output.EndsWith(antiprompt))
{
_is_interacting = true;
_is_antiprompt = true;
break;
}
}
}
if (_n_past > 0 && _is_interacting)
{
if (_params.instruct)
{
yield return "\n> ";
}
_input_echo = false;
break;
}
if (_embed.Count > 0 && _embed.Last() == NativeApi.llama_token_eos(_ctx))
{
if (_params.instruct)
{
_is_interacting = true;
}
else
{
LLamaDefaultLogger.Default.Info(" [end of text]");
}
}
if (_params.interactive && _n_remain <= 0 && _params.n_predict != -1)
{
_n_remain = _params.n_predict;
_is_interacting = true;
}
}
}
if (!string.IsNullOrEmpty(_path_session) && _params.prompt_cache_all)
{
LLamaDefaultLogger.Default.Info($"saving final output to session file {_path_session}");
var session_token_array = _session_tokens.ToArray();
NativeApi.llama_save_session_file(_ctx, _path_session, session_token_array, (ulong)session_token_array.Length);
}
}
/// <inheritdoc />
public void Dispose()
{
_ctx.Dispose();
}
}
}

View File

@@ -1,142 +0,0 @@
using System;
using System.Collections.Generic;
#pragma warning disable
// ReSharper disable all
namespace LLama.OldVersion
{
using llama_token = Int32;
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public struct LLamaParams
{
public int seed; // RNG seed
public int n_threads = Math.Max(Environment.ProcessorCount / 2, 1); // number of threads (-1 = autodetect)
public int n_predict = -1; // new tokens to predict
public int n_ctx = 512; // context size
public int n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
public int n_keep = 0; // number of tokens to keep from initial prompt
public int n_gpu_layers = -1; // number of layers to store in VRAM
// sampling parameters
public Dictionary<llama_token, float> logit_bias; // logit bias for specific tokens
public int top_k = 40; // <= 0 to use vocab size
public float top_p = 0.95f; // 1.0 = disabled
public float tfs_z = 1.00f; // 1.0 = disabled
public float typical_p = 1.00f; // 1.0 = disabled
public float temp = 0.80f; // 1.0 = disabled
public float repeat_penalty = 1.10f; // 1.0 = disabled
public int repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
public float frequency_penalty = 0.00f; // 0.0 = disabled
public float presence_penalty = 0.00f; // 0.0 = disabled
public int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
public float mirostat_tau = 5.00f; // target entropy
public float mirostat_eta = 0.10f; // learning rate
public string model = "models/lamma-7B/ggml-model.bin"; // model path
public string prompt = ""; // initial prompt (set to empty string for interactive mode)
public string path_session = ""; // path to file for saving/loading model eval state
public string input_prefix = ""; // string to prefix user inputs with
public string input_suffix = ""; // string to suffix user inputs with
public List<string> antiprompt; // string upon seeing which more user input is prompted
public string lora_adapter = ""; // lora adapter path
public string lora_base = ""; // base model path for the lora adapter
public bool memory_f16 = true; // use f16 instead of f32 for memory kv
public bool random_prompt = false; // randomize prompt if none provided
public bool use_color = false; // use color to distinguish generations and inputs
public bool interactive = false; // interactive mode
public bool prompt_cache_all = false; // save user input and generations to prompt cache
public bool embedding = false; // get only sentence embedding
public bool interactive_first = false; // wait for user input immediately
public bool instruct = false; // instruction mode (used for Alpaca models)
public bool penalize_nl = true; // consider newlines as a repeatable token
public bool perplexity = false; // compute perplexity over the prompt
public bool use_mmap = true; // use mmap for faster loads
public bool use_mlock = false; // use mlock to keep model in memory
public bool mem_test = false; // compute maximum memory usage
public bool verbose_prompt = false; // print prompt tokens before generation
public LLamaParams(int seed = 0, int n_threads = -1, int n_predict = -1,
int n_ctx = 512, int n_batch = 512, int n_keep = 0, int n_gpu_layers = -1,
Dictionary<llama_token, float>? logit_bias = null, int top_k = 40, float top_p = 0.95f,
float tfs_z = 1.00f, float typical_p = 1.00f, float temp = 0.80f, float repeat_penalty = 1.10f,
int repeat_last_n = 64, float frequency_penalty = 0.00f, float presence_penalty = 0.00f,
int mirostat = 0, float mirostat_tau = 5.00f, float mirostat_eta = 0.10f,
string model = "models/lamma-7B/ggml-model.bin", string prompt = "",
string path_session = "", string input_prefix = "", string input_suffix = "",
List<string> antiprompt = null, string lora_adapter = "", string lora_base = "",
bool memory_f16 = true, bool random_prompt = false, bool use_color = false, bool interactive = false,
bool prompt_cache_all = false, bool embedding = false, bool interactive_first = false,
bool instruct = false, bool penalize_nl = true,
bool perplexity = false, bool use_mmap = true, bool use_mlock = false, bool mem_test = false,
bool verbose_prompt = false)
{
this.seed = seed;
if (n_threads != -1)
{
this.n_threads = n_threads;
}
this.n_predict = n_predict;
this.n_ctx = n_ctx;
this.n_batch = n_batch;
this.n_keep = n_keep;
this.n_gpu_layers = n_gpu_layers == -1 ? 20 : n_gpu_layers;
if (logit_bias == null)
{
logit_bias = new Dictionary<llama_token, float>();
}
this.logit_bias = logit_bias;
this.top_k = top_k;
this.top_p = top_p;
this.tfs_z = tfs_z;
this.typical_p = typical_p;
this.temp = temp;
this.repeat_penalty = repeat_penalty;
this.repeat_last_n = repeat_last_n;
this.frequency_penalty = frequency_penalty;
this.presence_penalty = presence_penalty;
this.mirostat = mirostat;
this.mirostat_tau = mirostat_tau;
this.mirostat_eta = mirostat_eta;
this.model = model;
this.prompt = prompt;
this.path_session = path_session;
this.input_prefix = input_prefix;
this.input_suffix = input_suffix;
if (antiprompt == null)
{
antiprompt = new List<string>();
}
this.antiprompt = antiprompt;
this.lora_adapter = lora_adapter;
this.lora_base = lora_base;
this.memory_f16 = memory_f16;
this.random_prompt = random_prompt;
this.use_color = use_color;
this.interactive = interactive;
this.prompt_cache_all = prompt_cache_all;
this.embedding = embedding;
this.interactive_first = interactive_first;
this.instruct = instruct;
this.penalize_nl = penalize_nl;
this.perplexity = perplexity;
this.use_mmap = use_mmap;
this.use_mlock = use_mlock;
this.mem_test = mem_test;
this.verbose_prompt = verbose_prompt;
}
}
}

View File

@@ -1,59 +0,0 @@
using System;
using System.Collections.Generic;
#pragma warning disable
// ReSharper disable all
namespace LLama.OldVersion
{
public enum ChatRole
{
Human,
Assistant
}
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public record EmbeddingUsage(int PromptTokens, int TotalTokens);
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public record EmbeddingData(int Index, string Object, float[] Embedding);
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public record Embedding(string Object, string Model, EmbeddingData[] Data, EmbeddingUsage Usage);
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public record CompletionLogprobs(int[] TextOffset, float[] TokenLogProbs, string[] Tokens, Dictionary<string, float>[] TopLogprobs);
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public record CompletionChoice(string Text, int Index, CompletionLogprobs? Logprobs, string? FinishReason);
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public record CompletionUsage(int PromptTokens, int CompletionTokens, int TotalTokens);
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public record CompletionChunk(string Id, string Object, int Created, string Model, CompletionChoice[] Choices);
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public record Completion(string Id, string Object, int Created, string Model, CompletionChoice[] Choices, CompletionUsage Usage);
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public record ChatCompletionMessage(ChatRole Role, string Content, string? Name = null);
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public record ChatCompletionChoice(int Index, ChatCompletionMessage Message, string? FinishReason);
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public record ChatCompletion(string Id, string Object, int Created, string Model, ChatCompletionChoice[] Choices, CompletionUsage Usage);
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public record ChatCompletionChunkDelta(string? Role, string? Content);
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public record ChatCompletionChunkChoice(int Index, ChatCompletionChunkDelta Delta, string? FinishReason);
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public record ChatCompletionChunk(string Id, string Model, string Object, int Created, ChatCompletionChunkChoice[] Choices);
[Obsolete("The entire LLama.OldVersion namespace will be removed")]
public record ChatMessageRecord(ChatCompletionMessage Message, DateTime Time);
}

View File

@@ -1,93 +0,0 @@
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Text;
using LLama.Exceptions;
using System.Linq;
using System.Runtime.InteropServices;
using System.IO;
#pragma warning disable
// ReSharper disable all
namespace LLama.OldVersion
{
using llama_token = Int32;
internal static class Utils
{
public static SafeLLamaContextHandle llama_init_from_gpt_params(ref LLamaParams @params)
{
var lparams = NativeApi.llama_context_default_params();
lparams.n_ctx = @params.n_ctx;
lparams.n_gpu_layers = @params.n_gpu_layers;
lparams.seed = @params.seed;
lparams.f16_kv = @params.memory_f16;
lparams.use_mmap = @params.use_mmap;
lparams.use_mlock = @params.use_mlock;
lparams.logits_all = @params.perplexity;
lparams.embedding = @params.embedding;
if (!File.Exists(@params.model))
{
throw new FileNotFoundException($"The model file does not exist: {@params.model}");
}
var model = SafeLlamaModelHandle.LoadFromFile(@params.model, lparams);
var ctx = SafeLLamaContextHandle.Create(model, lparams);
if (!string.IsNullOrEmpty(@params.lora_adapter))
model.ApplyLoraFromFile(@params.lora_adapter, @params.lora_base, @params.n_threads);
return ctx;
}
public static List<llama_token> llama_tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, string encodingName)
{
var encoding = Encoding.GetEncoding(encodingName);
var cnt = encoding.GetByteCount(text);
llama_token[] res = new llama_token[cnt + (add_bos ? 1 : 0)];
int n = NativeApi.llama_tokenize(ctx, text, encoding, res, res.Length, add_bos);
if (n < 0)
{
throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
"specify the encoding.");
}
return res.Take(n).ToList();
}
public static unsafe Span<float> llama_get_logits(SafeLLamaContextHandle ctx, int length)
{
var logits = NativeApi.llama_get_logits(ctx);
return new Span<float>(logits, length);
}
public static unsafe string PtrToStringUTF8(IntPtr ptr)
{
#if NET6_0_OR_GREATER
return Marshal.PtrToStringUTF8(ptr);
#else
unsafe
{
byte* tp = (byte*)ptr.ToPointer();
List<byte> bytes = new();
while (true)
{
byte c = *tp++;
if (c == '\0')
{
break;
}
else
{
bytes.Add(c);
}
}
return Encoding.UTF8.GetString(bytes.ToArray());
}
#endif
}
}
}