Merge pull request #64 from martindevans/new_llama_state_loading_mechanism

New low-level loading system
Rinne 2023-08-05 08:47:28 +08:00 committed by GitHub
commit 1d29b240b2
7 changed files with 290 additions and 57 deletions


@@ -13,65 +13,101 @@ namespace LLama.Native
/// RNG seed, -1 for random
/// </summary>
public int seed;
/// <summary>
/// text context
/// </summary>
public int n_ctx;
/// <summary>
/// prompt processing batch size
/// </summary>
public int n_batch;
/// <summary>
/// grouped-query attention (TEMP - will be moved to model hparams)
/// </summary>
public int n_gqa;
/// <summary>
/// rms norm epsilon (TEMP - will be moved to model hparams)
/// </summary>
public float rms_norm_eps;
/// <summary>
/// number of layers to store in VRAM
/// </summary>
public int n_gpu_layers;
/// <summary>
/// the GPU that is used for scratch and small tensors
/// </summary>
public int main_gpu;
/// <summary>
/// how to split layers across multiple GPUs
/// </summary>
public TensorSplits tensor_split;
public float[] tensor_split;
/// <summary>
/// ref: https://github.com/ggerganov/llama.cpp/pull/2054
/// RoPE base frequency
/// </summary>
public float rope_freq_base;
/// <summary>
/// ref: https://github.com/ggerganov/llama.cpp/pull/2054
/// RoPE frequency scaling factor
/// </summary>
public float rope_freq_scale;
/// <summary>
/// called with a progress value between 0 and 1, pass NULL to disable
/// </summary>
public IntPtr progress_callback;
/// <summary>
/// context pointer passed to the progress callback
/// </summary>
public IntPtr progress_callback_user_data;
/// <summary>
/// if true, reduce VRAM usage at the cost of performance
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool low_vram;
/// <summary>
/// use fp16 for KV cache
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool f16_kv;
/// <summary>
/// the llama_eval() call computes all logits, not just the last one
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool logits_all;
/// <summary>
/// only load the vocabulary, no weights
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool vocab_only;
/// <summary>
/// use mmap if possible
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool use_mmap;
/// <summary>
/// force system to keep model in RAM
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool use_mlock;
/// <summary>
/// embedding mode only
/// </summary>
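
For orientation, the struct above is the raw parameter block handed across the native boundary. A minimal sketch of filling it before a load might look as follows; the field names come from the diff above, while the concrete values are illustrative assumptions rather than llama.cpp's actual defaults.

// Hedged sketch: hand-initialise the native parameter block.
// All values below are assumptions for illustration only.
var lparams = new LLamaContextParams
{
    seed = -1,                            // -1 => random seed
    n_ctx = 2048,                         // text context length (assumed)
    n_batch = 512,                        // prompt processing batch size (assumed)
    n_gpu_layers = 0,                     // keep every layer on the CPU
    main_gpu = 0,
    tensor_split = new float[] { 1.0f },  // note: the field is now a float[] rather than a TensorSplits struct
    rope_freq_base = 10000.0f,            // assumed llama.cpp default
    rope_freq_scale = 1.0f,
    f16_kv = true,
    use_mmap = true,
};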


@@ -1,6 +1,4 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;
@@ -29,7 +27,7 @@ namespace LLama.Native
}
private const string libraryName = "libllama";
[DllImport("libllama", EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)]
[DllImport(libraryName, EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)]
public static extern bool llama_empty_call();
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
@@ -56,7 +54,10 @@ namespace LLama.Native
/// <param name="params_"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_init_from_file(string path_model, LLamaContextParams params_);
public static extern IntPtr llama_load_model_from_file(string path_model, LLamaContextParams params_);
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams params_);
/// <summary>
/// not great API - very likely to change.
@@ -65,6 +66,7 @@ namespace LLama.Native
/// </summary>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_backend_init(bool numa);
/// <summary>
/// Frees all allocated memory
/// </summary>
@@ -72,6 +74,13 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_free(IntPtr ctx);
/// <summary>
/// Frees all allocated memory associated with a model
/// </summary>
/// <param name="model"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_free_model(IntPtr model);
/// <summary>
/// Apply a LoRA adapter to a loaded model
/// path_base_model is the path to a higher quality model to use as a base for
@@ -79,13 +88,13 @@ namespace LLama.Native
/// The model needs to be reloaded before applying a new adapter, otherwise the adapter
/// will be applied on top of the previous one
/// </summary>
/// <param name="ctx"></param>
/// <param name="model_ptr"></param>
/// <param name="path_lora"></param>
/// <param name="path_base_model"></param>
/// <param name="n_threads"></param>
/// <returns>Returns 0 on success</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_apply_lora_from_file(SafeLLamaContextHandle ctx, string path_lora, string path_base_model, int n_threads);
public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, string? path_base_model, int n_threads);
/// <summary>
/// Returns the number of tokens in the KV cache
@@ -294,5 +303,20 @@ namespace LLama.Native
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_print_system_info();
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_n_vocab_from_model(SafeLlamaModelHandle model);
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_n_ctx_from_model(SafeLlamaModelHandle model);
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_n_embd_from_model(SafeLlamaModelHandle model);
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern byte* llama_token_to_str_with_model(SafeLlamaModelHandle model, int llamaToken);
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_tokenize_with_model(SafeLlamaModelHandle model, byte* text, int* tokens, int n_max_tokens, bool add_bos);
}
}
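
Taken together, the bindings above replace the single llama_init_from_file entry point with separate model-load and context-creation steps. A minimal start-up sketch using only calls visible in this file might look like this; reading the info string with Marshal.PtrToStringAnsi (from System.Runtime.InteropServices) is an illustrative assumption, not something this PR prescribes.

// Initialise the llama.cpp backend once per process, before loading any model.
NativeApi.llama_backend_init(numa: false);

// llama_print_system_info returns a pointer to a static, null-terminated info string.
var info = Marshal.PtrToStringAnsi(NativeApi.llama_print_system_info());
Console.WriteLine(info);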


@@ -1,26 +1,61 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;
namespace LLama.Native
{
public class SafeLLamaContextHandle: SafeLLamaHandleBase
/// <summary>
/// A safe wrapper around a llama_context
/// </summary>
public class SafeLLamaContextHandle
: SafeLLamaHandleBase
{
protected SafeLLamaContextHandle()
{
}
/// <summary>
/// This field guarantees that a reference to the model is held for as long as this handle is held
/// </summary>
private SafeLlamaModelHandle? _model;
public SafeLLamaContextHandle(IntPtr handle)
/// <summary>
/// Create a new SafeLLamaContextHandle
/// </summary>
/// <param name="handle">pointer to an allocated llama_context</param>
/// <param name="model">the model which this context was created from</param>
public SafeLLamaContextHandle(IntPtr handle, SafeLlamaModelHandle model)
: base(handle)
{
// Increment the model reference count while this context exists
_model = model;
var success = false;
_model.DangerousAddRef(ref success);
if (!success)
throw new RuntimeError("Failed to increment model refcount");
}
/// <inheritdoc />
protected override bool ReleaseHandle()
{
// Decrement refcount on model
_model?.DangerousRelease();
_model = null;
NativeApi.llama_free(handle);
SetHandle(IntPtr.Zero);
return true;
}
/// <summary>
/// Create a new llama_context for the given model
/// </summary>
/// <param name="model"></param>
/// <param name="lparams"></param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
public static SafeLLamaContextHandle Create(SafeLlamaModelHandle model, LLamaContextParams lparams)
{
var ctx_ptr = NativeApi.llama_new_context_with_model(model, lparams);
if (ctx_ptr == IntPtr.Zero)
throw new RuntimeError("Failed to create context from model");
return new(ctx_ptr, model);
}
}
}
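
A minimal usage sketch of the new handle pair, assuming an LLamaContextParams value named lparams like the one sketched earlier and an illustrative model path: the constructor above takes a reference on the model handle, so the weights are not actually freed until the context has also been released.

// Hedged usage sketch: load the weights once, then create a context over them.
using (var model = SafeLlamaModelHandle.LoadFromFile("path/to/model.bin", lparams)) // path is illustrative
using (var ctx = SafeLLamaContextHandle.Create(model, lparams))
{
    // evaluate prompts / sample tokens against ctx here
}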


@@ -1,11 +1,13 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
namespace LLama.Native
{
public abstract class SafeLLamaHandleBase: SafeHandle
/// <summary>
/// Base class for all llama handles to native resources
/// </summary>
public abstract class SafeLLamaHandleBase
: SafeHandle
{
private protected SafeLLamaHandleBase()
: base(IntPtr.Zero, ownsHandle: true)
@@ -24,8 +26,10 @@ namespace LLama.Native
SetHandle(handle);
}
/// <inheritdoc />
public override bool IsInvalid => handle == IntPtr.Zero;
/// <inheritdoc />
public override string ToString()
=> $"0x{handle.ToString("x16")}";
}


@@ -0,0 +1,161 @@
using System;
using System.Text;
using LLama.Exceptions;
namespace LLama.Native
{
/// <summary>
/// A reference to a set of llama model weights
/// </summary>
public class SafeLlamaModelHandle
: SafeLLamaHandleBase
{
/// <summary>
/// Total number of tokens in vocabulary of this model
/// </summary>
public int VocabCount { get; set; }
/// <summary>
/// Total number of tokens in the context
/// </summary>
public int ContextSize { get; set; }
/// <summary>
/// Dimension of embedding vectors
/// </summary>
public int EmbeddingCount { get; set; }
internal SafeLlamaModelHandle(IntPtr handle)
: base(handle)
{
VocabCount = NativeApi.llama_n_vocab_from_model(this);
ContextSize = NativeApi.llama_n_ctx_from_model(this);
EmbeddingCount = NativeApi.llama_n_embd_from_model(this);
}
/// <inheritdoc />
protected override bool ReleaseHandle()
{
NativeApi.llama_free_model(handle);
SetHandle(IntPtr.Zero);
return true;
}
/// <summary>
/// Load a model from the given file path into memory
/// </summary>
/// <param name="modelPath"></param>
/// <param name="lparams"></param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaContextParams lparams)
{
var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams);
if (model_ptr == IntPtr.Zero)
throw new RuntimeError($"Failed to load model {modelPath}.");
return new SafeLlamaModelHandle(model_ptr);
}
#region LoRA
/// <summary>
/// Apply a LoRA adapter to a loaded model
/// </summary>
/// <param name="lora"></param>
/// <param name="modelBase">A path to a higher quality model to use as a base for the layers modified by the
/// adapter. Can be NULL to use the current loaded model.</param>
/// <param name="threads"></param>
/// <exception cref="RuntimeError"></exception>
public void ApplyLoraFromFile(string lora, string? modelBase = null, int threads = -1)
{
var err = NativeApi.llama_model_apply_lora_from_file(
this,
lora,
string.IsNullOrEmpty(modelBase) ? null : modelBase,
threads
);
if (err != 0)
throw new RuntimeError("Failed to apply lora adapter.");
}
#endregion
#region tokenize
/// <summary>
/// Convert a single llama token into string bytes
/// </summary>
/// <param name="llama_token"></param>
/// <returns></returns>
public ReadOnlySpan<byte> TokenToSpan(int llama_token)
{
unsafe
{
var bytes = new ReadOnlySpan<byte>(NativeApi.llama_token_to_str_with_model(this, llama_token), int.MaxValue);
var terminator = bytes.IndexOf((byte)0);
return bytes.Slice(0, terminator);
}
}
/// <summary>
/// Convert a single llama token into a string
/// </summary>
/// <param name="llama_token"></param>
/// <param name="encoding">Encoding to use to decode the bytes into a string</param>
/// <returns></returns>
public string TokenToString(int llama_token, Encoding encoding)
{
var span = TokenToSpan(llama_token);
if (span.Length == 0)
return "";
unsafe
{
fixed (byte* ptr = &span[0])
{
return encoding.GetString(ptr, span.Length);
}
}
}
/// <summary>
/// Convert a string of text into tokens
/// </summary>
/// <param name="text"></param>
/// <param name="add_bos"></param>
/// <param name="encoding"></param>
/// <returns></returns>
public int[] Tokenize(string text, bool add_bos, Encoding encoding)
{
// Convert string to bytes, adding one extra byte to the end (null terminator)
var bytesCount = encoding.GetByteCount(text);
var bytes = new byte[bytesCount + 1];
unsafe
{
fixed (char* charPtr = text)
fixed (byte* bytePtr = &bytes[0])
{
encoding.GetBytes(charPtr, text.Length, bytePtr, bytes.Length);
}
}
unsafe
{
fixed (byte* bytesPtr = &bytes[0])
{
// Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space)
var count = -NativeApi.llama_tokenize_with_model(this, bytesPtr, (int*)IntPtr.Zero, 0, add_bos);
// Tokenize again, this time outputting into an array of exactly the right size
var tokens = new int[count];
fixed (int* tokensPtr = &tokens[0])
{
count = NativeApi.llama_tokenize_with_model(this, bytesPtr, tokensPtr, count, add_bos);
return tokens;
}
}
}
}
#endregion
}
}
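
A short round-trip sketch of the tokenisation helpers added in this file, assuming model is a loaded SafeLlamaModelHandle, that UTF-8 is an appropriate encoding, and that the prompt text is purely illustrative (Encoding comes from System.Text).

int[] tokens = model.Tokenize("Hello, world!", add_bos: true, Encoding.UTF8);
Console.WriteLine($"Prompt became {tokens.Length} tokens (vocab size {model.VocabCount})");

// Convert the tokens back into text one at a time.
foreach (var token in tokens)
    Console.Write(model.TokenToString(token, Encoding.UTF8));
Console.WriteLine();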


@@ -31,24 +31,12 @@ namespace LLama.OldVersion
throw new FileNotFoundException($"The model file does not exist: {@params.model}");
}
var ctx_ptr = NativeApi.llama_init_from_file(@params.model, lparams);
if (ctx_ptr == IntPtr.Zero)
{
throw new RuntimeError($"Failed to load model {@params.model}.");
}
SafeLLamaContextHandle ctx = new(ctx_ptr);
var model = SafeLlamaModelHandle.LoadFromFile(@params.model, lparams);
var ctx = SafeLLamaContextHandle.Create(model, lparams);
if (!string.IsNullOrEmpty(@params.lora_adapter))
{
int err = NativeApi.llama_apply_lora_from_file(ctx, @params.lora_adapter,
string.IsNullOrEmpty(@params.lora_base) ? null : @params.lora_base, @params.n_threads);
if (err != 0)
{
throw new RuntimeError("Failed to apply lora adapter.");
}
}
model.ApplyLoraFromFile(@params.lora_adapter, @params.lora_base, @params.n_threads);
return ctx;
}


@@ -28,40 +28,25 @@ namespace LLama
lparams.logits_all = @params.Perplexity;
lparams.embedding = @params.EmbeddingMode;
lparams.low_vram = @params.LowVram;
if(@params.TensorSplits.Length != 1)
if (@params.TensorSplits.Length != 1)
{
throw new ArgumentException("Multi-GPU is not currently supported by " +
"either llama.cpp or LLamaSharp.");
}
lparams.tensor_split = new TensorSplits()
{
Item1 = @params.TensorSplits[0]
};
lparams.tensor_split = @params.TensorSplits;
if (!File.Exists(@params.ModelPath))
{
throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}");
}
var ctx_ptr = NativeApi.llama_init_from_file(@params.ModelPath, lparams);
if (ctx_ptr == IntPtr.Zero)
{
throw new RuntimeError($"Failed to load model {@params.ModelPath}.");
}
SafeLLamaContextHandle ctx = new(ctx_ptr);
var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
var ctx = SafeLLamaContextHandle.Create(model, lparams);
if (!string.IsNullOrEmpty(@params.LoraAdapter))
{
int err = NativeApi.llama_apply_lora_from_file(ctx, @params.LoraAdapter,
string.IsNullOrEmpty(@params.LoraBase) ? null : @params.LoraBase, @params.Threads);
if (err != 0)
{
throw new RuntimeError("Failed to apply lora adapter.");
}
}
model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
return ctx;
}
@@ -78,7 +63,7 @@ namespace LLama
return res.Take(n);
}
public unsafe static Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
public static unsafe Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
{
var logits = NativeApi.llama_get_logits(ctx);
return new Span<float>(logits, length);
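
As a usage note, the span returned here is a view over native memory owned by the context. A hedged sketch of a greedy read of the last token's logits follows, assuming this method lives on the Utils helper class and that ctx and vocabSize are a valid context handle and the model's vocabulary size.

Span<float> logits = Utils.GetLogits(ctx, vocabSize);

// Pick the highest-scoring token id (greedy argmax), purely as an illustration.
int best = 0;
for (int i = 1; i < logits.Length; i++)
{
    if (logits[i] > logits[best])
        best = i;
}
Console.WriteLine($"Greedy next token id: {best}");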