204 lines
7.3 KiB
C#
204 lines
7.3 KiB
C#
using System;
|
|
using System.Collections.Generic;
|
|
using System.Text;
|
|
using LLama.Exceptions;
|
|
|
|
namespace LLama.Native
{
    /// <summary>
    /// A reference to a set of llama model weights
    /// </summary>
    public sealed class SafeLlamaModelHandle
        : SafeLLamaHandleBase
    {
        /// <summary>
        /// Total number of tokens in vocabulary of this model
        /// </summary>
        public int VocabCount { get; }

        /// <summary>
        /// Context size (in tokens) reported for this model by <c>llama_n_ctx_train</c>
        /// </summary>
        public int ContextSize { get; }

        /// <summary>
        /// Dimension of embedding vectors
        /// </summary>
        public int EmbeddingSize { get; }

        /// <summary>
        /// Get the size of this model in bytes
        /// </summary>
        public ulong SizeInBytes { get; }

        /// <summary>
        /// Get the number of parameters in this model
        /// </summary>
        public ulong ParameterCount { get; }

        /// <summary>
        /// Wrap a native model pointer and cache the model metadata exposed as properties.
        /// </summary>
        /// <param name="handle">Native model pointer (as returned by <c>llama_load_model_from_file</c>)</param>
        internal SafeLlamaModelHandle(IntPtr handle)
            : base(handle)
        {
            // Query every value once, up front, so the properties never need to touch native code again.
            VocabCount = NativeApi.llama_n_vocab(this);
            ContextSize = NativeApi.llama_n_ctx_train(this);
            EmbeddingSize = NativeApi.llama_n_embd(this);
            SizeInBytes = NativeApi.llama_model_size(this);
            ParameterCount = NativeApi.llama_model_n_params(this);
        }

        /// <inheritdoc />
        protected override bool ReleaseHandle()
        {
            NativeApi.llama_free_model(DangerousGetHandle());

            // Clear the stored pointer so the freed model can never be handed out again.
            SetHandle(IntPtr.Zero);
            return true;
        }

        /// <summary>
        /// Load a model from the given file path into memory
        /// </summary>
        /// <param name="modelPath">Path of the model file, forwarded to <c>llama_load_model_from_file</c></param>
        /// <param name="lparams">Model parameters, forwarded to <c>llama_load_model_from_file</c></param>
        /// <returns>A new handle which owns the loaded model</returns>
        /// <exception cref="RuntimeError">Thrown when the native loader returns a null pointer</exception>
        public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaModelParams lparams)
        {
            var modelPtr = NativeApi.llama_load_model_from_file(modelPath, lparams);
            if (modelPtr == IntPtr.Zero)
            {
                throw new RuntimeError($"Failed to load model {modelPath}.");
            }

            return new SafeLlamaModelHandle(modelPtr);
        }

        #region LoRA

        /// <summary>
        /// Apply a LoRA adapter to a loaded model
        /// </summary>
        /// <param name="lora">Path of the LoRA adapter, forwarded to <c>llama_model_apply_lora_from_file</c></param>
        /// <param name="scale">Scale value, forwarded to <c>llama_model_apply_lora_from_file</c></param>
        /// <param name="modelBase">A path to a higher quality model to use as a base for the layers modified by the
        /// adapter. Can be NULL to use the current loaded model. An empty string is treated the same as NULL.</param>
        /// <param name="threads">Thread count passed to the native call. When null, half of
        /// <see cref="Environment.ProcessorCount"/> is used (but never less than 1).</param>
        /// <exception cref="RuntimeError">Thrown when the native call returns a non-zero error code</exception>
        public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null, int? threads = null)
        {
            var err = NativeApi.llama_model_apply_lora_from_file(
                this,
                lora,
                scale,
                string.IsNullOrEmpty(modelBase) ? null : modelBase,
                threads ?? Math.Max(1, Environment.ProcessorCount / 2)
            );

            if (err != 0)
            {
                throw new RuntimeError("Failed to apply lora adapter.");
            }
        }

        #endregion

        #region tokenize

        /// <summary>
        /// Convert a single llama token into bytes
        /// </summary>
        /// <param name="llama_token">Token to decode</param>
        /// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
        /// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
        public int TokenToSpan(int llama_token, Span<byte> dest)
        {
            unsafe
            {
                fixed (byte* destPtr = dest)
                {
                    // The native result may be negative (presumably signalling that `dest` was too small);
                    // the magnitude is returned in either case.
                    var length = NativeApi.llama_token_to_piece(this, llama_token, destPtr, dest.Length);
                    return Math.Abs(length);
                }
            }
        }

        /// <summary>
        /// Convert a sequence of tokens into characters.
        /// </summary>
        /// <param name="tokens">Tokens to decode, in order</param>
        /// <param name="dest">Span to write the decoded characters into</param>
        /// <param name="encoding">Encoding given to the <c>StreamingTokenDecoder</c> used for decoding</param>
        /// <returns>The section of the span which has valid data in it.
        /// If there was insufficient space in the output span this will be
        /// filled with as many characters as possible, starting from the _last_ token.
        /// </returns>
        [Obsolete("Use a StreamingTokenDecoder instead")]
        internal Span<char> TokensToSpan(IReadOnlyList<int> tokens, Span<char> dest, Encoding encoding)
        {
            var decoder = new StreamingTokenDecoder(encoding, this);

            foreach (var token in tokens)
            {
                decoder.Add(token);
            }

            var str = decoder.Read();

            if (str.Length < dest.Length)
            {
                // Everything fits: copy it all and return only the filled prefix.
                str.AsSpan().CopyTo(dest);
                return dest.Slice(0, str.Length);
            }
            else
            {
                // Too long (or an exact fit): keep the trailing characters, filling `dest` completely.
                str.AsSpan().Slice(str.Length - dest.Length).CopyTo(dest);
                return dest;
            }
        }

        /// <summary>
        /// Convert a string of text into tokens
        /// </summary>
        /// <param name="text">Text to tokenize</param>
        /// <param name="add_bos">Forwarded to <c>llama_tokenize</c> as its <c>add_bos</c> argument</param>
        /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
        /// <param name="encoding">Encoding used to convert <paramref name="text"/> into the bytes given to the native tokenizer</param>
        /// <returns>The tokens produced for <paramref name="text"/>. An empty array if the tokenizer reports zero tokens.</returns>
        public int[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
        {
            // Convert string to bytes, adding one extra byte to the end (null terminator).
            // The array is zero-initialised and only `bytesCount` bytes are written, so the last byte stays 0.
            var bytesCount = encoding.GetByteCount(text);
            var bytes = new byte[bytesCount + 1];
            encoding.GetBytes(text, 0, text.Length, bytes, 0);

            unsafe
            {
                fixed (byte* bytesPtr = bytes)
                {
                    // Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space)
                    var count = -NativeApi.llama_tokenize(this, bytesPtr, bytesCount, (int*)IntPtr.Zero, 0, add_bos, special);

                    // Nothing to output (e.g. empty text without a BOS token). Taking the address of the
                    // first element of an empty array would throw, so exit before allocating.
                    if (count == 0)
                    {
                        return Array.Empty<int>();
                    }

                    // Tokenize again, this time outputting into an array of exactly the right size
                    var tokens = new int[count];
                    fixed (int* tokensPtr = tokens)
                    {
                        NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos, special);
                        return tokens;
                    }
                }
            }
        }

        #endregion

        #region context

        /// <summary>
        /// Create a new context for this model
        /// </summary>
        /// <param name="params">Context parameters, forwarded to <c>SafeLLamaContextHandle.Create</c></param>
        /// <returns>The context handle created by <c>SafeLLamaContextHandle.Create</c></returns>
        public SafeLLamaContextHandle CreateContext(LLamaContextParams @params)
        {
            return SafeLLamaContextHandle.Create(this, @params);
        }

        #endregion
    }
}
|