Merge pull request #64 from martindevans/new_llama_state_loading_mechanism
Low level new loading system
Commit 1d29b240b2
LLamaContextParams (struct, LLama.Native):

@@ -13,65 +13,101 @@ namespace LLama.Native
     /// RNG seed, -1 for random
     /// </summary>
     public int seed;

     /// <summary>
     /// text context
     /// </summary>
     public int n_ctx;

     /// <summary>
     /// prompt processing batch size
     /// </summary>
     public int n_batch;

     /// <summary>
     /// grouped-query attention (TEMP - will be moved to model hparams)
     /// </summary>
     public int n_gqa;

     /// <summary>
     /// rms norm epsilon (TEMP - will be moved to model hparams)
     /// </summary>
     public float rms_norm_eps;

     /// <summary>
     /// number of layers to store in VRAM
     /// </summary>
     public int n_gpu_layers;

     /// <summary>
     /// the GPU that is used for scratch and small tensors
     /// </summary>
     public int main_gpu;

     /// <summary>
     /// how to split layers across multiple GPUs
     /// </summary>
-    public TensorSplits tensor_split;
+    public float[] tensor_split;

     /// <summary>
     /// ref: https://github.com/ggerganov/llama.cpp/pull/2054
     /// RoPE base frequency
     /// </summary>
     public float rope_freq_base;

     /// <summary>
     /// ref: https://github.com/ggerganov/llama.cpp/pull/2054
     /// RoPE frequency scaling factor
     /// </summary>
     public float rope_freq_scale;

     /// <summary>
     /// called with a progress value between 0 and 1, pass NULL to disable
     /// </summary>
     public IntPtr progress_callback;

     /// <summary>
     /// context pointer passed to the progress callback
     /// </summary>
     public IntPtr progress_callback_user_data;

     /// <summary>
     /// if true, reduce VRAM usage at the cost of performance
     /// </summary>
     [MarshalAs(UnmanagedType.I1)]
     public bool low_vram;

     /// <summary>
     /// use fp16 for KV cache
     /// </summary>
     [MarshalAs(UnmanagedType.I1)]
     public bool f16_kv;

     /// <summary>
     /// the llama_eval() call computes all logits, not just the last one
     /// </summary>
     [MarshalAs(UnmanagedType.I1)]
     public bool logits_all;

     /// <summary>
     /// only load the vocabulary, no weights
     /// </summary>
     [MarshalAs(UnmanagedType.I1)]
     public bool vocab_only;

     /// <summary>
     /// use mmap if possible
     /// </summary>
     [MarshalAs(UnmanagedType.I1)]
     public bool use_mmap;

     /// <summary>
     /// force system to keep model in RAM
     /// </summary>
     [MarshalAs(UnmanagedType.I1)]
     public bool use_mlock;

     /// <summary>
     /// embedding mode only
     /// </summary>
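The struct above is passed by value straight across the P/Invoke boundary, so callers just fill in plain fields before loading anything. A minimal sketch of populating it — not part of this commit, and the field values shown are illustrative assumptions rather than llama.cpp defaults:

using System;
using LLama.Native;

public static class ContextParamsExample
{
    public static LLamaContextParams Build()
    {
        var lparams = new LLamaContextParams();
        lparams.seed = -1;                       // -1 requests a random RNG seed
        lparams.n_ctx = 2048;                    // text context size, in tokens
        lparams.n_batch = 512;                   // prompt processing batch size
        lparams.n_gpu_layers = 0;                // CPU-only: no layers stored in VRAM
        lparams.rope_freq_base = 10000f;         // RoPE base frequency (llama.cpp PR #2054)
        lparams.rope_freq_scale = 1f;            // no RoPE frequency scaling
        lparams.f16_kv = true;                   // fp16 KV cache
        lparams.use_mmap = true;                 // mmap the weights if possible
        lparams.progress_callback = IntPtr.Zero; // NULL disables progress reporting
        return lparams;
    }
}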
NativeApi (LLama.Native):

@@ -1,6 +1,4 @@
 using System;
 using System.Collections.Generic;
 using System.IO;
 using System.Runtime.InteropServices;
 using System.Text;
 using LLama.Exceptions;
@@ -29,7 +27,7 @@ namespace LLama.Native
         }

         private const string libraryName = "libllama";

-        [DllImport("libllama", EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)]
+        [DllImport(libraryName, EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)]
         public static extern bool llama_empty_call();

         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
@@ -56,7 +54,10 @@ namespace LLama.Native
         /// <param name="params_"></param>
         /// <returns></returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern IntPtr llama_init_from_file(string path_model, LLamaContextParams params_);
+        public static extern IntPtr llama_load_model_from_file(string path_model, LLamaContextParams params_);
+
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern IntPtr llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams params_);

         /// <summary>
         /// not great API - very likely to change.
@@ -65,6 +66,7 @@ namespace LLama.Native
         /// </summary>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern void llama_backend_init(bool numa);

         /// <summary>
         /// Frees all allocated memory
         /// </summary>
@@ -72,6 +74,13 @@ namespace LLama.Native
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern void llama_free(IntPtr ctx);

+        /// <summary>
+        /// Frees all allocated memory associated with a model
+        /// </summary>
+        /// <param name="model"></param>
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern void llama_free_model(IntPtr model);

         /// <summary>
         /// Apply a LoRA adapter to a loaded model
         /// path_base_model is the path to a higher quality model to use as a base for
@@ -79,13 +88,13 @@ namespace LLama.Native
         /// The model needs to be reloaded before applying a new adapter, otherwise the adapter
         /// will be applied on top of the previous one
         /// </summary>
-        /// <param name="ctx"></param>
+        /// <param name="model_ptr"></param>
         /// <param name="path_lora"></param>
         /// <param name="path_base_model"></param>
         /// <param name="n_threads"></param>
         /// <returns>Returns 0 on success</returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern int llama_apply_lora_from_file(SafeLLamaContextHandle ctx, string path_lora, string path_base_model, int n_threads);
+        public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, string? path_base_model, int n_threads);

         /// <summary>
         /// Returns the number of tokens in the KV cache
@@ -294,5 +303,20 @@ namespace LLama.Native
         /// <returns></returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern IntPtr llama_print_system_info();

+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern int llama_n_vocab_from_model(SafeLlamaModelHandle model);
+
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern int llama_n_ctx_from_model(SafeLlamaModelHandle model);
+
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern int llama_n_embd_from_model(SafeLlamaModelHandle model);
+
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern byte* llama_token_to_str_with_model(SafeLlamaModelHandle model, int llamaToken);
+
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern int llama_tokenize_with_model(SafeLlamaModelHandle model, byte* text, int* tokens, int n_max_tokens, bool add_bos);
     }
 }
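With llama_init_from_file split into llama_load_model_from_file and llama_new_context_with_model, the raw lifecycle becomes: load the weights once, create one or more contexts over them, then free each part separately (llama_free for contexts, llama_free_model for the weights). A sketch of that sequence using only declarations from this diff — the "model.bin" path and parameter values are placeholders:

using System;
using LLama.Native;

class LowLevelLoadExample
{
    static void Main()
    {
        var lparams = new LLamaContextParams { n_ctx = 2048, seed = -1 };

        // Step 1: load the weights once (wraps llama_load_model_from_file).
        using var model = SafeLlamaModelHandle.LoadFromFile("model.bin", lparams);

        // Step 2: create a context over those weights. Several contexts can
        // now share one copy of the weights, which is the point of the split.
        var ctxPtr = NativeApi.llama_new_context_with_model(model, lparams);
        if (ctxPtr == IntPtr.Zero)
            throw new InvalidOperationException("llama_new_context_with_model failed");

        // Wrapping immediately ensures llama_free (and, via the refcount,
        // llama_free_model) run when the handles are disposed.
        using var ctx = new SafeLLamaContextHandle(ctxPtr, model);
    }
}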
SafeLLamaContextHandle (LLama.Native):

@@ -1,26 +1,61 @@
 using System;
 using System.Collections.Generic;
 using System.Runtime.InteropServices;
 using System.Text;
+using LLama.Exceptions;

 namespace LLama.Native
 {
-    public class SafeLLamaContextHandle: SafeLLamaHandleBase
+    /// <summary>
+    /// A safe wrapper around a llama_context
+    /// </summary>
+    public class SafeLLamaContextHandle
+        : SafeLLamaHandleBase
     {
-        protected SafeLLamaContextHandle()
-        {
-        }
+        /// <summary>
+        /// This field guarantees that a reference to the model is held for as long as this handle is held
+        /// </summary>
+        private SafeLlamaModelHandle? _model;

-        public SafeLLamaContextHandle(IntPtr handle)
+        /// <summary>
+        /// Create a new SafeLLamaContextHandle
+        /// </summary>
+        /// <param name="handle">pointer to an allocated llama_context</param>
+        /// <param name="model">the model which this context was created from</param>
+        public SafeLLamaContextHandle(IntPtr handle, SafeLlamaModelHandle model)
             : base(handle)
         {
+            // Increment the model reference count while this context exists
+            _model = model;
+            var success = false;
+            _model.DangerousAddRef(ref success);
+            if (!success)
+                throw new RuntimeError("Failed to increment model refcount");
         }

+        /// <inheritdoc />
         protected override bool ReleaseHandle()
         {
+            // Decrement refcount on model
+            _model?.DangerousRelease();
+            _model = null;
+
             NativeApi.llama_free(handle);
             SetHandle(IntPtr.Zero);
             return true;
         }

+        /// <summary>
+        /// Create a new llama_state for the given model
+        /// </summary>
+        /// <param name="model"></param>
+        /// <param name="lparams"></param>
+        /// <returns></returns>
+        /// <exception cref="RuntimeError"></exception>
+        public static SafeLLamaContextHandle Create(SafeLlamaModelHandle model, LLamaContextParams lparams)
+        {
+            var ctx_ptr = NativeApi.llama_new_context_with_model(model, lparams);
+            if (ctx_ptr == IntPtr.Zero)
+                throw new RuntimeError("Failed to create context from model");
+
+            return new(ctx_ptr, model);
+        }
     }
 }
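The DangerousAddRef/DangerousRelease pair is what ties the two handles' lifetimes together: the context holds a reference on the model's SafeHandle refcount, so the weights cannot be freed out from under a live context. A sketch of the behaviour this buys — path and parameters are placeholders:

using LLama.Native;

class RefCountExample
{
    static void Main()
    {
        var lparams = new LLamaContextParams { n_ctx = 512 };

        var model = SafeLlamaModelHandle.LoadFromFile("model.bin", lparams);
        var ctx = SafeLLamaContextHandle.Create(model, lparams);

        // Dropping our own reference is safe: the refcount taken by the
        // context's constructor keeps llama_free_model from running yet.
        model.Dispose();

        // ... use ctx here ...

        // Releasing the context decrements the model refcount, and only
        // now are the native weights actually freed as well.
        ctx.Dispose();
    }
}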
SafeLLamaHandleBase (LLama.Native):

@@ -1,11 +1,13 @@
 using System;
 using System.Collections.Generic;
 using System.Runtime.InteropServices;
 using System.Text;

 namespace LLama.Native
 {
-    public abstract class SafeLLamaHandleBase: SafeHandle
+    /// <summary>
+    /// Base class for all llama handles to native resources
+    /// </summary>
+    public abstract class SafeLLamaHandleBase
+        : SafeHandle
     {
         private protected SafeLLamaHandleBase()
             : base(IntPtr.Zero, ownsHandle: true)
@@ -24,8 +26,10 @@ namespace LLama.Native
             SetHandle(handle);
         }

+        /// <inheritdoc />
         public override bool IsInvalid => handle == IntPtr.Zero;

+        /// <inheritdoc />
         public override string ToString()
             => $"0x{handle.ToString("x16")}";
     }
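For reference, all a new native resource needs in order to plug into this base class is an IntPtr constructor and a ReleaseHandle override; the private protected constructors mean derived handles must live in the same assembly. A hypothetical sketch — SafeExampleHandle is not part of the commit:

using System;
using LLama.Native;

internal sealed class SafeExampleHandle : SafeLLamaHandleBase
{
    public SafeExampleHandle(IntPtr handle)
        : base(handle)
    {
    }

    protected override bool ReleaseHandle()
    {
        // Free the native resource here, e.g. NativeApi.llama_free(handle),
        // then clear the handle so IsInvalid reports true.
        SetHandle(IntPtr.Zero);
        return true;
    }
}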
SafeLlamaModelHandle (new file, LLama.Native):

@@ -0,0 +1,161 @@
+using System;
+using System.Text;
+using LLama.Exceptions;
+
+namespace LLama.Native
+{
+    /// <summary>
+    /// A reference to a set of llama model weights
+    /// </summary>
+    public class SafeLlamaModelHandle
+        : SafeLLamaHandleBase
+    {
+        /// <summary>
+        /// Total number of tokens in vocabulary of this model
+        /// </summary>
+        public int VocabCount { get; set; }
+
+        /// <summary>
+        /// Total number of tokens in the context
+        /// </summary>
+        public int ContextSize { get; set; }
+
+        /// <summary>
+        /// Dimension of embedding vectors
+        /// </summary>
+        public int EmbeddingCount { get; set; }
+
+        internal SafeLlamaModelHandle(IntPtr handle)
+            : base(handle)
+        {
+            VocabCount = NativeApi.llama_n_vocab_from_model(this);
+            ContextSize = NativeApi.llama_n_ctx_from_model(this);
+            EmbeddingCount = NativeApi.llama_n_embd_from_model(this);
+        }
+
+        /// <inheritdoc />
+        protected override bool ReleaseHandle()
+        {
+            NativeApi.llama_free_model(handle);
+            SetHandle(IntPtr.Zero);
+            return true;
+        }
+
+        /// <summary>
+        /// Load a model from the given file path into memory
+        /// </summary>
+        /// <param name="modelPath"></param>
+        /// <param name="lparams"></param>
+        /// <returns></returns>
+        /// <exception cref="RuntimeError"></exception>
+        public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaContextParams lparams)
+        {
+            var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams);
+            if (model_ptr == IntPtr.Zero)
+                throw new RuntimeError($"Failed to load model {modelPath}.");
+
+            return new SafeLlamaModelHandle(model_ptr);
+        }
+
+        #region LoRA
+        /// <summary>
+        /// Apply a LoRA adapter to a loaded model
+        /// </summary>
+        /// <param name="lora"></param>
+        /// <param name="modelBase">A path to a higher quality model to use as a base for the layers modified by the
+        /// adapter. Can be NULL to use the current loaded model.</param>
+        /// <param name="threads"></param>
+        /// <exception cref="RuntimeError"></exception>
+        public void ApplyLoraFromFile(string lora, string? modelBase = null, int threads = -1)
+        {
+            var err = NativeApi.llama_model_apply_lora_from_file(
+                this,
+                lora,
+                string.IsNullOrEmpty(modelBase) ? null : modelBase,
+                threads
+            );
+
+            if (err != 0)
+                throw new RuntimeError("Failed to apply lora adapter.");
+        }
+        #endregion
+
+        #region tokenize
+        /// <summary>
+        /// Convert a single llama token into string bytes
+        /// </summary>
+        /// <param name="llama_token"></param>
+        /// <returns></returns>
+        public ReadOnlySpan<byte> TokenToSpan(int llama_token)
+        {
+            unsafe
+            {
+                var bytes = new ReadOnlySpan<byte>(NativeApi.llama_token_to_str_with_model(this, llama_token), int.MaxValue);
+                var terminator = bytes.IndexOf((byte)0);
+                return bytes.Slice(0, terminator);
+            }
+        }
+
+        /// <summary>
+        /// Convert a single llama token into a string
+        /// </summary>
+        /// <param name="llama_token"></param>
+        /// <param name="encoding">Encoding to use to decode the bytes into a string</param>
+        /// <returns></returns>
+        public string TokenToString(int llama_token, Encoding encoding)
+        {
+            var span = TokenToSpan(llama_token);
+
+            if (span.Length == 0)
+                return "";
+
+            unsafe
+            {
+                fixed (byte* ptr = &span[0])
+                {
+                    return encoding.GetString(ptr, span.Length);
+                }
+            }
+        }
+
+        /// <summary>
+        /// Convert a string of text into tokens
+        /// </summary>
+        /// <param name="text"></param>
+        /// <param name="add_bos"></param>
+        /// <param name="encoding"></param>
+        /// <returns></returns>
+        public int[] Tokenize(string text, bool add_bos, Encoding encoding)
+        {
+            // Convert string to bytes, adding one extra byte to the end (null terminator)
+            var bytesCount = encoding.GetByteCount(text);
+            var bytes = new byte[bytesCount + 1];
+            unsafe
+            {
+                fixed (char* charPtr = text)
+                fixed (byte* bytePtr = &bytes[0])
+                {
+                    encoding.GetBytes(charPtr, text.Length, bytePtr, bytes.Length);
+                }
+            }
+
+            unsafe
+            {
+                fixed (byte* bytesPtr = &bytes[0])
+                {
+                    // Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space)
+                    var count = -NativeApi.llama_tokenize_with_model(this, bytesPtr, (int*)IntPtr.Zero, 0, add_bos);
+
+                    // Tokenize again, this time outputting into an array of exactly the right size
+                    var tokens = new int[count];
+                    fixed (int* tokensPtr = &tokens[0])
+                    {
+                        count = NativeApi.llama_tokenize_with_model(this, bytesPtr, tokensPtr, count, add_bos);
+                        return tokens;
+                    }
+                }
+            }
+        }
+        #endregion
+    }
+}
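The tokenizer helpers above round-trip between text and token ids entirely through the *_with_model functions, so no context is required. A usage sketch — the model path is a placeholder and a real model file is needed to actually run it:

using System;
using System.Text;
using LLama.Native;

class TokenizeExample
{
    static void Main()
    {
        var lparams = new LLamaContextParams { n_ctx = 512 };
        using var model = SafeLlamaModelHandle.LoadFromFile("model.bin", lparams);

        // Two-pass tokenization happens inside Tokenize: the first native call
        // measures, the second fills an exactly-sized array.
        int[] tokens = model.Tokenize("Hello, world!", true, Encoding.UTF8);

        // Decode each id back to text to verify the round trip.
        foreach (var token in tokens)
            Console.Write(model.TokenToString(token, Encoding.UTF8));
    }
}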
Utils (LLama.OldVersion):

@@ -31,24 +31,12 @@ namespace LLama.OldVersion
                 throw new FileNotFoundException($"The model file does not exist: {@params.model}");
             }

-            var ctx_ptr = NativeApi.llama_init_from_file(@params.model, lparams);
-
-            if (ctx_ptr == IntPtr.Zero)
-            {
-                throw new RuntimeError($"Failed to load model {@params.model}.");
-            }
-
-            SafeLLamaContextHandle ctx = new(ctx_ptr);
+            var model = SafeLlamaModelHandle.LoadFromFile(@params.model, lparams);
+            var ctx = SafeLLamaContextHandle.Create(model, lparams);

             if (!string.IsNullOrEmpty(@params.lora_adapter))
             {
-                int err = NativeApi.llama_apply_lora_from_file(ctx, @params.lora_adapter,
-                    string.IsNullOrEmpty(@params.lora_base) ? null : @params.lora_base, @params.n_threads);
-                if (err != 0)
-                {
-                    throw new RuntimeError("Failed to apply lora adapter.");
-                }
+                model.ApplyLoraFromFile(@params.lora_adapter, @params.lora_base, @params.n_threads);
             }

             return ctx;
         }
Utils (LLama):

@@ -28,40 +28,25 @@ namespace LLama
             lparams.logits_all = @params.Perplexity;
             lparams.embedding = @params.EmbeddingMode;
             lparams.low_vram = @params.LowVram;

-            if(@params.TensorSplits.Length != 1)
+            if (@params.TensorSplits.Length != 1)
             {
                 throw new ArgumentException("Currently multi-gpu support is not supported by " +
                     "both llama.cpp and LLamaSharp.");
             }
-            lparams.tensor_split = new TensorSplits()
-            {
-                Item1 = @params.TensorSplits[0]
-            };
+            lparams.tensor_split = @params.TensorSplits;

             if (!File.Exists(@params.ModelPath))
             {
                 throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}");
             }

-            var ctx_ptr = NativeApi.llama_init_from_file(@params.ModelPath, lparams);
-
-            if (ctx_ptr == IntPtr.Zero)
-            {
-                throw new RuntimeError($"Failed to load model {@params.ModelPath}.");
-            }
-
-            SafeLLamaContextHandle ctx = new(ctx_ptr);
+            var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
+            var ctx = SafeLLamaContextHandle.Create(model, lparams);

             if (!string.IsNullOrEmpty(@params.LoraAdapter))
             {
-                int err = NativeApi.llama_apply_lora_from_file(ctx, @params.LoraAdapter,
-                    string.IsNullOrEmpty(@params.LoraBase) ? null : @params.LoraBase, @params.Threads);
-                if (err != 0)
-                {
-                    throw new RuntimeError("Failed to apply lora adapter.");
-                }
+                model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
             }

             return ctx;
         }
@@ -78,7 +63,7 @@ namespace LLama
             return res.Take(n);
         }

-        public unsafe static Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
+        public static unsafe Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
         {
             var logits = NativeApi.llama_get_logits(ctx);
             return new Span<float>(logits, length);
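GetLogits hands back a Span over llama.cpp's internal logits buffer, so the caller must pass the right length — one float per vocabulary entry, which the new SafeLlamaModelHandle.VocabCount exposes. A sketch of picking the most likely next token, assuming these helpers sit on LLama's Utils class as the surrounding hunks suggest:

using LLama;
using LLama.Native;

static class LogitsExample
{
    // Greedy pick: the index of the largest logit is the most likely token id.
    // Assumes ctx and model were created as in the loading sketches above.
    public static int ArgMaxLogit(SafeLLamaContextHandle ctx, SafeLlamaModelHandle model)
    {
        var logits = Utils.GetLogits(ctx, model.VocabCount);

        var best = 0;
        for (var i = 1; i < logits.Length; i++)
            if (logits[i] > logits[best])
                best = i;
        return best;
    }
}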