Merge pull request #185 from martindevans/wip_major_api_change

Major llama.cpp API Change
This commit is contained in:
Martin Evans 2023-10-18 20:50:32 +01:00 committed by GitHub
commit d8434ea9d6
45 changed files with 1437 additions and 919 deletions

View File

@ -8,7 +8,7 @@ namespace LLama.Examples.NewVersion
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim();
var prompt = (await File.ReadAllTextAsync("Assets/chat-with-bob.txt")).Trim();
var parameters = new ModelParams(modelPath)
{
@ -50,7 +50,7 @@ namespace LLama.Examples.NewVersion
Console.ForegroundColor = ConsoleColor.White;
ex.Context.Dispose();
ex = new(new LLamaContext(parameters));
ex = new(new LLamaContext(model, parameters));
session = new ChatSession(ex);
session.LoadSession(statePath);
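With the new API, weights are loaded once and a context is created from them, so the executor above is rebuilt from `new LLamaContext(model, parameters)` instead of from parameters alone. A minimal sketch of the pattern (the model path is a placeholder):

var parameters = new ModelParams("model.gguf");           // placeholder path
using var model = LLamaWeights.LoadFromFile(parameters);  // load the weights once
using var context = new LLamaContext(model, parameters);  // create a context over them
var executor = new InteractiveExecutor(context);
var session = new ChatSession(executor);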

View File

@ -1,13 +1,7 @@
using System.Reflection.Metadata;
using System.Security.Cryptography;
using System.Text;
using LLama.Abstractions;
using System.Security.Cryptography;
using LLama.Common;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.AI.ChatCompletion;
using Microsoft.SemanticKernel.AI.TextCompletion;
using LLamaSharp.SemanticKernel.ChatCompletion;
using LLamaSharp.SemanticKernel.TextCompletion;
namespace LLama.Examples.NewVersion
{
@ -22,7 +16,7 @@ namespace LLama.Examples.NewVersion
// Load weights into memory
var parameters = new ModelParams(modelPath)
{
Seed = RandomNumberGenerator.GetInt32(int.MaxValue),
Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue)),
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);

View File

@ -18,7 +18,7 @@ namespace LLama.Examples.NewVersion
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();
var seed = 1337;
var seed = 1337u;
// Load weights into memory
var parameters = new ModelParams(modelPath)
{

View File

@ -18,7 +18,7 @@ namespace LLama.Examples.NewVersion
// Load weights into memory
var parameters = new ModelParams(modelPath)
{
Seed = RandomNumberGenerator.GetInt32(int.MaxValue),
Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue))
};
using var model = LLamaWeights.LoadFromFile(parameters);
var ex = new StatelessExecutor(model, parameters);

View File

@ -15,7 +15,7 @@ namespace LLama.Examples.NewVersion
// Load weights into memory
var @params = new ModelParams(modelPath)
{
Seed = RandomNumberGenerator.GetInt32(int.MaxValue)
Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue))
};
using var weights = LLamaWeights.LoadFromFile(@params);

View File

@ -1,4 +1,5 @@
using LLama.Examples.NewVersion;
using LLama.Native;
Console.WriteLine("======================================================================================================");
@ -7,7 +8,7 @@ Console.WriteLine(" __ __ ____ _
Console.WriteLine("======================================================================================================");
NativeApi.llama_empty_call();
Console.WriteLine();
await NewVersionTestRunner.Run();

View File

@ -27,36 +27,8 @@ namespace LLama.Unittest
public void BasicModelProperties()
{
Assert.Equal(32000, _model.VocabCount);
Assert.Equal(2048, _model.ContextSize);
Assert.Equal(4096, _model.ContextSize);
Assert.Equal(4096, _model.EmbeddingSize);
Assert.Equal(Encoding.UTF8, _model.Encoding);
}
[Fact]
public void CloneContext()
{
var original = _model.CreateContext(_params);
// Evaluate something (doesn't matter what, as long as it begins with token 1)
original.Eval(new[] { 1, 42, 321 }, 0);
// Clone current state
var clone = original.Clone();
// Now evaluate something more
var reply1a = original.Eval(new[] { 4, 5, 6 }, 3);
var reply2a = original.Eval(new[] { 7, 8, 9 }, 6);
// Assert that the context replied differently each time
Assert.NotEqual(reply1a, reply2a);
// Give the same prompts to the cloned state
var reply1b = clone.Eval(new[] { 4, 5, 6 }, 3);
var reply2b = clone.Eval(new[] { 7, 8, 9 }, 6);
// Assert that the cloned context replied in the same way as originally
Assert.Equal(reply1a, reply1b);
Assert.Equal(reply2a, reply2b);
}
}
}

View File

@ -2,7 +2,7 @@
namespace LLama.Unittest
{
public class LLamaContextTests
public sealed class LLamaContextTests
: IDisposable
{
private readonly LLamaWeights _weights;
@ -30,7 +30,6 @@ namespace LLama.Unittest
Assert.Equal(768, _context.ContextSize);
Assert.Equal(4096, _context.EmbeddingSize);
Assert.Equal(32000, _context.VocabCount);
Assert.Equal(0, _context.KVCacheTokenCount);
}
[Fact]

View File

@ -13,7 +13,6 @@ namespace LLama.Unittest
{
BatchSize = 17,
ContextSize = 42,
LoraAdapter = "adapter",
Seed = 42,
GpuLayerCount = 111
};
@ -31,9 +30,13 @@ namespace LLama.Unittest
{
BatchSize = 17,
ContextSize = 42,
LoraAdapter = "adapter",
Seed = 42,
GpuLayerCount = 111
GpuLayerCount = 111,
LoraAdapters =
{
new("abc", 1),
new("def", 0)
}
};
var settings = new Newtonsoft.Json.JsonSerializerSettings();

View File

@ -16,7 +16,7 @@ namespace LLama.Unittest
_params = new ModelParams(Constants.ModelPath)
{
ContextSize = 60,
Seed = 1754
Seed = 1754,
};
_weights = LLamaWeights.LoadFromFile(_params);
}
@ -48,13 +48,13 @@ namespace LLama.Unittest
{
var executor = new StatelessExecutor(_weights, _params);
const string question = " Question. why is a cat the best pet?\nAnswer: ";
const string question = " Question. cats or dogs?\nAnswer: ";
// The context size is set to 60. Generate more than that, forcing it to generate a coherent response
// with a modified context
var @params = new InferenceParams()
{
MaxTokens = 100,
MaxTokens = 65,
TokensKeep = question.Length,
};

View File

@ -27,7 +27,7 @@ public sealed class TokenTests
[Fact]
public void TokensEndWith()
{
var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8);
var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);
var result = tokens.TokensEndsWithAnyString(new[]
{
@ -41,7 +41,7 @@ public sealed class TokenTests
[Fact]
public void TokensEndSubstring()
{
var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8);
var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);
var result = tokens.TokensEndsWithAnyString((IList<string>)new[]
{
@ -53,7 +53,7 @@ public sealed class TokenTests
[Fact]
public void TokensNotEndWith()
{
var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8);
var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);
var result = tokens.TokensEndsWithAnyString((IList<string>)new[]
{
@ -67,7 +67,7 @@ public sealed class TokenTests
[Fact]
public void TokensNotEndWithNothing()
{
var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8);
var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);
var result = tokens.TokensEndsWithAnyString((IList<string>)Array.Empty<string>(), _model.NativeHandle, Encoding.UTF8);
Assert.False(result);
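All of these call sites pick up the new `special` parameter on `Tokenize`, which controls whether special/control tokens embedded in the text are tokenized as such rather than treated as plain text. A rough sketch of the difference (the prompt and the handling of "<s>" are illustrative):

var withSpecial = _model.NativeHandle.Tokenize("<s>The cat sat", false, true, Encoding.UTF8);  // special tokens allowed
var asPlainText = _model.NativeHandle.Tokenize("<s>The cat sat", false, false, Encoding.UTF8); // "<s>" treated as ordinary text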

View File

@ -4,7 +4,7 @@ using LLama.Abstractions;
namespace LLama.Web.Common
{
public class ModelOptions
: IModelParams
: ILLamaParams
{
public string Name { get; set; }
@ -14,7 +14,7 @@ namespace LLama.Web.Common
/// <summary>
/// Model context size (n_ctx)
/// </summary>
public int ContextSize { get; set; } = 512;
public uint ContextSize { get; set; } = 512;
/// <summary>
/// the GPU that is used for scratch and small tensors
/// </summary>
@ -30,7 +30,7 @@ namespace LLama.Web.Common
/// <summary>
/// Seed for the random number generator (seed)
/// </summary>
public int Seed { get; set; } = 1686349486;
public uint Seed { get; set; } = 1686349486;
/// <summary>
/// Use f16 instead of f32 for memory kv (memory_f16)
/// </summary>
@ -51,26 +51,31 @@ namespace LLama.Web.Common
/// Model path (model)
/// </summary>
public string ModelPath { get; set; }
/// <summary>
/// model alias
/// List of LoRAs to apply
/// </summary>
public string ModelAlias { get; set; } = "unknown";
public AdapterCollection LoraAdapters { get; set; } = new();
/// <summary>
/// base model path for the lora adapter (lora_base)
/// </summary>
public string LoraBase { get; set; } = string.Empty;
/// <summary>
/// lora adapter path (lora_adapter)
/// Number of threads (null = autodetect) (n_threads)
/// </summary>
public string LoraAdapter { get; set; } = string.Empty;
/// <summary>
/// base model path for the lora adapter (lora_base)
/// </summary>
public string LoraBase { get; set; } = string.Empty;
/// <summary>
/// Number of threads (-1 = autodetect) (n_threads)
/// </summary>
public int Threads { get; set; } = Math.Max(Environment.ProcessorCount / 2, 1);
/// <summary>
/// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
/// </summary>
public int BatchSize { get; set; } = 512;
public uint? Threads { get; set; }
/// <summary>
/// Number of threads to use for batch processing (null = autodetect) (n_threads)
/// </summary>
public uint? BatchThreads { get; set; }
/// <summary>
/// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
/// </summary>
public uint BatchSize { get; set; } = 512;
/// <summary>
/// Whether to convert eos to newline during the inference.
@ -107,5 +112,10 @@ namespace LLama.Web.Common
/// The encoding to use for models
/// </summary>
public Encoding Encoding { get; set; } = Encoding.UTF8;
/// <summary>
/// Load vocab only (no weights)
/// </summary>
public bool VocabOnly { get; set; }
}
}
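`ModelOptions` now implements `ILLamaParams` and adopts the new unsigned and nullable parameter types. A configuration sketch (the name and path are placeholders):

var options = new ModelOptions
{
    Name = "example",                  // placeholder
    ModelPath = "models/example.gguf", // placeholder
    ContextSize = 2048,                // now uint
    Seed = 42,                         // now uint
    Threads = null,                    // null = autodetect
    VocabOnly = false,
};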

View File

@ -3,7 +3,6 @@ using LLama.Web.Common;
using LLama.Web.Models;
using Microsoft.Extensions.Options;
using System.Collections.Concurrent;
using System.Drawing;
namespace LLama.Web.Services
{
@ -50,15 +49,16 @@ namespace LLama.Web.Services
if (modelOption.MaxInstances > -1 && currentInstances >= modelOption.MaxInstances)
return Task.FromResult(ServiceResult.FromError<ModelSession>("Maximum model instances reached"));
// Create model
var llamaModel = new LLamaContext(modelOption);
// Load weights
// todo: it would be better to have a central service which loads weights and shares them between all contexts that need them!
using var weights = LLamaWeights.LoadFromFile(modelOption);
// Create executor
ILLamaExecutor executor = executorType switch
{
LLamaExecutorType.Interactive => new InteractiveExecutor(llamaModel),
LLamaExecutorType.Instruct => new InstructExecutor(llamaModel),
LLamaExecutorType.Stateless => new StatelessExecutor(llamaModel),
LLamaExecutorType.Interactive => new InteractiveExecutor(new LLamaContext(weights, modelOption)), //todo: properly dispose of LLamaContext
LLamaExecutorType.Instruct => new InstructExecutor(new LLamaContext(weights, modelOption)), //todo: properly dispose of LLamaContext
LLamaExecutorType.Stateless => new StatelessExecutor(weights, modelOption),
_ => default
};

View File

@ -16,10 +16,15 @@ public class StatefulChatService : IDisposable
public StatefulChatService(IConfiguration configuration)
{
_context = new LLamaContext(new Common.ModelParams(configuration["ModelPath"])
var @params = new Common.ModelParams(configuration["ModelPath"])
{
ContextSize = 512
});
ContextSize = 512,
};
// todo: share weights from a central service
using var weights = LLamaWeights.LoadFromFile(@params);
_context = new LLamaContext(weights, @params);
_session = new ChatSession(new InteractiveExecutor(_context));
}

View File

@ -12,10 +12,16 @@ namespace LLama.WebAPI.Services
public StatelessChatService(IConfiguration configuration)
{
_context = new LLamaContext(new ModelParams(configuration["ModelPath"])
var @params = new Common.ModelParams(configuration["ModelPath"])
{
ContextSize = 512,
});
};
// todo: share weights from a central service
using var weights = LLamaWeights.LoadFromFile(@params);
_context = new LLamaContext(weights, @params);
// TODO: replace with a stateless executor
_session = new ChatSession(new InteractiveExecutor(_context))
.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Assistant:" }, redundancyLength: 8))

View File

@ -0,0 +1,70 @@
using System.Text;
namespace LLama.Abstractions;
/// <summary>
/// The parameters for initializing a LLama context from a model.
/// </summary>
public interface IContextParams
{
/// <summary>
/// Model context size (n_ctx)
/// </summary>
uint ContextSize { get; set; }
/// <summary>
/// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
/// </summary>
uint BatchSize { get; set; }
/// <summary>
/// Seed for the random number generator (seed)
/// </summary>
uint Seed { get; set; }
/// <summary>
/// Use f16 instead of f32 for memory kv (memory_f16)
/// </summary>
bool UseFp16Memory { get; set; }
/// <summary>
/// Compute perplexity over the prompt (perplexity)
/// </summary>
bool Perplexity { get; set; }
/// <summary>
/// Whether to use embedding mode. (embedding) Note that if this is set to true,
/// the LLamaModel won't produce a text response anymore.
/// </summary>
bool EmbeddingMode { get; set; }
/// <summary>
/// RoPE base frequency
/// </summary>
float RopeFrequencyBase { get; set; }
/// <summary>
/// RoPE frequency scaling factor
/// </summary>
float RopeFrequencyScale { get; set; }
/// <summary>
/// Use experimental mul_mat_q kernels
/// </summary>
bool MulMatQ { get; set; }
/// <summary>
/// The encoding to use for models
/// </summary>
Encoding Encoding { get; set; }
/// <summary>
/// Number of threads (null = autodetect) (n_threads)
/// </summary>
uint? Threads { get; set; }
/// <summary>
/// Number of threads to use for batch processing (null = autodetect) (n_threads)
/// </summary>
uint? BatchThreads { get; set; }
}

View File

@ -36,7 +36,7 @@ namespace LLama.Abstractions
/// </summary>
public int TopK { get; set; }
/// <summary>
/// <summary>llama_eval
/// 1.0 = disabled
/// </summary>
public float TopP { get; set; }

View File

@ -0,0 +1,11 @@
namespace LLama.Abstractions
{
/// <summary>
/// Convenience interface for implementing both types of parameters.
/// </summary>
/// <remarks>Mostly exists for backwards compatibility reasons, when these two were not split.</remarks>
public interface ILLamaParams
: IModelParams, IContextParams
{
}
}
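Because `ModelParams` implements `ILLamaParams`, one parameters object can still be handed to both halves of the split API. A sketch with a placeholder path:

ILLamaParams @params = new ModelParams("model.gguf");
using var weights = LLamaWeights.LoadFromFile(@params); // consumes the IModelParams half
using var context = weights.CreateContext(@params);     // consumes the IContextParams half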

View File

@ -1,4 +1,6 @@
using System.Text;
using System;
using System.Collections.Generic;
using System.Linq;
namespace LLama.Abstractions
{
@ -7,36 +9,16 @@ namespace LLama.Abstractions
/// </summary>
public interface IModelParams
{
/// <summary>
/// Model context size (n_ctx)
/// </summary>
int ContextSize { get; set; }
/// <summary>
/// the GPU that is used for scratch and small tensors
/// </summary>
int MainGpu { get; set; }
/// <summary>
/// if true, reduce VRAM usage at the cost of performance
/// </summary>
bool LowVram { get; set; }
/// <summary>
/// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
/// </summary>
int GpuLayerCount { get; set; }
/// <summary>
/// Seed for the random number generator (seed)
/// </summary>
int Seed { get; set; }
/// <summary>
/// Use f16 instead of f32 for memory kv (memory_f16)
/// </summary>
bool UseFp16Memory { get; set; }
/// <summary>
/// Use mmap for faster loads (use_mmap)
/// </summary>
@ -47,41 +29,15 @@ namespace LLama.Abstractions
/// </summary>
bool UseMemoryLock { get; set; }
/// <summary>
/// Compute perplexity over the prompt (perplexity)
/// </summary>
bool Perplexity { get; set; }
/// <summary>
/// Model path (model)
/// </summary>
string ModelPath { get; set; }
/// <summary>
/// lora adapter path (lora_adapter)
/// </summary>
string LoraAdapter { get; set; }
/// <summary>
/// base model path for the lora adapter (lora_base)
/// </summary>
string LoraBase { get; set; }
/// <summary>
/// Number of threads (-1 = autodetect) (n_threads)
/// </summary>
int Threads { get; set; }
/// <summary>
/// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
/// </summary>
int BatchSize { get; set; }
/// <summary>
/// Whether to use embedding mode. (embedding) Note that if this is set to true,
/// The LLamaModel won't produce text response anymore.
/// </summary>
bool EmbeddingMode { get; set; }
uint? Threads { get; set; }
/// <summary>
/// how split tensors should be distributed across GPUs
@ -89,23 +45,62 @@ namespace LLama.Abstractions
float[]? TensorSplits { get; set; }
/// <summary>
/// RoPE base frequency
/// Load vocab only (no weights)
/// </summary>
float RopeFrequencyBase { get; set; }
bool VocabOnly { get; set; }
/// <summary>
/// RoPE frequency scaling factor
/// List of LoRA adapters to apply
/// </summary>
float RopeFrequencyScale { get; set; }
AdapterCollection LoraAdapters { get; }
/// <summary>
/// Use experimental mul_mat_q kernels
/// base model path for the lora adapter (lora_base)
/// </summary>
bool MulMatQ { get; set; }
string LoraBase { get; set; }
}
/// <summary>
/// The encoding to use for models
/// </summary>
Encoding Encoding { get; set; }
/// <summary>
/// A LoRA adapter to apply to a model
/// </summary>
/// <param name="Path">Path to the LoRA file</param>
/// <param name="Scale">Strength of this LoRA</param>
public readonly record struct LoraAdapter(string Path, float Scale);
/// <summary>
/// A list of LoraAdapter objects
/// </summary>
public sealed class AdapterCollection
: List<LoraAdapter>, IEquatable<AdapterCollection>
{
/// <inheritdoc />
public bool Equals(AdapterCollection? other)
{
if (other == null)
return false;
return this.SequenceEqual(other);
}
/// <inheritdoc/>
public override bool Equals(object? obj)
{
return Equals(obj as AdapterCollection);
}
/// <inheritdoc/>
public override int GetHashCode()
{
unchecked
{
var hash = 17;
for (var i = 0; i < Count; i++)
{
hash += this[i].GetHashCode();
hash *= 7823;
}
return hash;
}
}
}
}
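The single `LoraAdapter` path is replaced by a collection of `LoraAdapter` records, each with its own scale; adapters with an empty path or a scale of zero or less are skipped at load time. A sketch with placeholder paths:

var @params = new ModelParams("model.gguf")
{
    LoraBase = "base-model.gguf",
    LoraAdapters =
    {
        new LoraAdapter("adapters/first.bin", 1.0f),  // full strength
        new LoraAdapter("adapters/second.bin", 0.5f), // half strength
    }
};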

View File

@ -10,20 +10,17 @@ namespace LLama.Common
/// The parameters for initializing a LLama model.
/// </summary>
public record ModelParams
: IModelParams
: ILLamaParams
{
/// <summary>
/// Model context size (n_ctx)
/// </summary>
public int ContextSize { get; set; } = 512;
public uint ContextSize { get; set; } = 512;
/// <summary>
/// the GPU that is used for scratch and small tensors
/// </summary>
public int MainGpu { get; set; } = 0;
/// <summary>
/// if true, reduce VRAM usage at the cost of performance
/// </summary>
public bool LowVram { get; set; } = false;
/// <summary>
/// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
/// </summary>
@ -31,7 +28,7 @@ namespace LLama.Common
/// <summary>
/// Seed for the random number generator (seed)
/// </summary>
public int Seed { get; set; } = 1686349486;
public uint Seed { get; set; } = 1686349486;
/// <summary>
/// Use f16 instead of f32 for memory kv (memory_f16)
/// </summary>
@ -52,22 +49,31 @@ namespace LLama.Common
/// Model path (model)
/// </summary>
public string ModelPath { get; set; }
/// <summary>
/// lora adapter path (lora_adapter)
/// List of LoRAs to apply
/// </summary>
public string LoraAdapter { get; set; } = string.Empty;
public AdapterCollection LoraAdapters { get; set; } = new();
/// <summary>
/// base model path for the lora adapter (lora_base)
/// </summary>
public string LoraBase { get; set; } = string.Empty;
/// <summary>
/// Number of threads (-1 = autodetect) (n_threads)
/// Number of threads (null = autodetect) (n_threads)
/// </summary>
public int Threads { get; set; } = Math.Max(Environment.ProcessorCount / 2, 1);
public uint? Threads { get; set; }
/// <summary>
/// Number of threads to use for batch processing (null = autodetect) (n_threads)
/// </summary>
public uint? BatchThreads { get; set; }
/// <summary>
/// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
/// </summary>
public int BatchSize { get; set; } = 512;
public uint BatchSize { get; set; } = 512;
/// <summary>
/// Whether to use embedding mode. (embedding) Note that if this is set to true,
@ -95,6 +101,11 @@ namespace LLama.Common
/// </summary>
public bool MulMatQ { get; set; }
/// <summary>
/// Load vocab only (no weights)
/// </summary>
public bool VocabOnly { get; set; }
/// <summary>
/// The encoding to use to convert text for the model
/// </summary>
@ -138,10 +149,10 @@ namespace LLama.Common
/// <param name="mulMatQ">Use experimental mul_mat_q kernels</param>
/// <param name="encoding">The encoding to use to convert text for the model</param>
[Obsolete("Use object initializer to set all optional parameters")]
public ModelParams(string modelPath, int contextSize = 512, int gpuLayerCount = 20,
int seed = 1337, bool useFp16Memory = true,
public ModelParams(string modelPath, uint contextSize = 512, int gpuLayerCount = 20,
uint seed = 1337, bool useFp16Memory = true,
bool useMemorymap = true, bool useMemoryLock = false, bool perplexity = false,
string loraAdapter = "", string loraBase = "", int threads = -1, int batchSize = 512,
string loraAdapter = "", string loraBase = "", int threads = -1, uint batchSize = 512,
bool embeddingMode = false,
float ropeFrequencyBase = 10000.0f, float ropeFrequencyScale = 1f, bool mulMatQ = false,
string encoding = "UTF-8")
@ -154,15 +165,15 @@ namespace LLama.Common
UseMemoryLock = useMemoryLock;
Perplexity = perplexity;
ModelPath = modelPath;
LoraAdapter = loraAdapter;
LoraBase = loraBase;
Threads = threads == -1 ? Math.Max(Environment.ProcessorCount / 2, 1) : threads;
Threads = threads < 1 ? null : (uint)threads;
BatchSize = batchSize;
EmbeddingMode = embeddingMode;
RopeFrequencyBase = ropeFrequencyBase;
RopeFrequencyScale = ropeFrequencyScale;
MulMatQ = mulMatQ;
Encoding = Encoding.GetEncoding(encoding);
LoraAdapters.Add(new LoraAdapter(loraAdapter, 1));
}
}

View File

@ -0,0 +1,46 @@
using System;
using System.IO;
using LLama.Abstractions;
using LLama.Native;
namespace LLama.Extensions
{
/// <summary>
/// Extension methods to the IContextParams interface
/// </summary>
public static class IContextParamsExtensions
{
/// <summary>
/// Convert the given `IContextParams` into a `LLamaContextParams`
/// </summary>
/// <param name="params"></param>
/// <param name="result"></param>
/// <returns></returns>
/// <exception cref="FileNotFoundException"></exception>
/// <exception cref="ArgumentException"></exception>
public static void ToLlamaContextParams(this IContextParams @params, out LLamaContextParams result)
{
result = NativeApi.llama_context_default_params();
result.n_ctx = @params.ContextSize;
result.n_batch = @params.BatchSize;
result.seed = @params.Seed;
result.f16_kv = @params.UseFp16Memory;
result.logits_all = @params.Perplexity;
result.embedding = @params.EmbeddingMode;
result.rope_freq_base = @params.RopeFrequencyBase;
result.rope_freq_scale = @params.RopeFrequencyScale;
result.mul_mat_q = @params.MulMatQ;
result.n_threads = Threads(@params.Threads);
result.n_threads_batch = Threads(@params.BatchThreads);
}
private static uint Threads(uint? value)
{
if (value is > 0)
return (uint)value;
return (uint)Math.Max(Environment.ProcessorCount / 2, 1);
}
}
}
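A usage sketch for the new extension (values are placeholders); `ModelParams` satisfies `IContextParams` through `ILLamaParams`:

IContextParams @params = new ModelParams("model.gguf")
{
    ContextSize = 2048,
    Threads = 8,
};
@params.ToLlamaContextParams(out LLamaContextParams native);
// native.n_ctx == 2048 and native.n_threads == 8; native.n_threads_batch
// falls back to half the processor count (at least 1) because BatchThreads is null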

View File

@ -12,41 +12,30 @@ namespace LLama.Extensions
public static class IModelParamsExtensions
{
/// <summary>
/// Convert the given `IModelParams` into a `LLamaContextParams`
/// Convert the given `IModelParams` into a `LLamaModelParams`
/// </summary>
/// <param name="params"></param>
/// <param name="result"></param>
/// <returns></returns>
/// <exception cref="FileNotFoundException"></exception>
/// <exception cref="ArgumentException"></exception>
public static MemoryHandle ToLlamaContextParams(this IModelParams @params, out LLamaContextParams result)
public static MemoryHandle ToLlamaModelParams(this IModelParams @params, out LLamaModelParams result)
{
if (!File.Exists(@params.ModelPath))
throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}");
if (@params.TensorSplits != null && @params.TensorSplits.Length != 1)
throw new ArgumentException("Currently multi-gpu support is not supported by both llama.cpp and LLamaSharp.");
result = NativeApi.llama_context_default_params();
result.n_ctx = @params.ContextSize;
result.n_batch = @params.BatchSize;
result = NativeApi.llama_model_default_params();
result.main_gpu = @params.MainGpu;
result.n_gpu_layers = @params.GpuLayerCount;
result.seed = @params.Seed;
result.f16_kv = @params.UseFp16Memory;
result.use_mmap = @params.UseMemorymap;
result.use_mlock = @params.UseMemoryLock;
result.logits_all = @params.Perplexity;
result.embedding = @params.EmbeddingMode;
result.low_vram = @params.LowVram;
result.rope_freq_base = @params.RopeFrequencyBase;
result.rope_freq_scale = @params.RopeFrequencyScale;
result.mul_mat_q = @params.MulMatQ;
result.use_mmap = @params.UseMemorymap;
result.vocab_only = @params.VocabOnly;
var pin = @params.TensorSplits.AsMemory().Pin();
unsafe
{
result.tensor_split = (nint)pin.Pointer;
result.tensor_split = (float*)pin.Pointer;
}
return pin;
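The returned `MemoryHandle` pins `TensorSplits` for as long as the native struct is in use, so callers keep it alive until the model has been loaded. A sketch of the calling pattern (mirroring `LLamaWeights.LoadFromFile` below), assuming `@params` is an `IModelParams`:

using var pin = @params.ToLlamaModelParams(out LLamaModelParams mparams);
var handle = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, mparams);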

View File

@ -42,14 +42,9 @@ namespace LLama
public int EmbeddingSize => _ctx.EmbeddingSize;
/// <summary>
/// Get the number of tokens in the KV Cache for this context
/// The context params set for this context
/// </summary>
public int KVCacheTokenCount => _ctx.KVCacheTokenCount;
/// <summary>
/// The model params set for this model.
/// </summary>
public IModelParams Params { get; set; }
public IContextParams Params { get; set; }
/// <summary>
/// The native handle, which is used to be passed to the native APIs
@ -62,24 +57,7 @@ namespace LLama
/// </summary>
public Encoding Encoding => _encoding;
/// <summary>
///
/// </summary>
/// <param name="params">Model params.</param>
/// <param name="logger">The logger.</param>
[Obsolete("Use the LLamaWeights.CreateContext instead")]
public LLamaContext(IModelParams @params, ILogger? logger = null)
{
Params = @params;
_logger = logger;
_encoding = @params.Encoding;
_logger?.LogInformation($"[LLamaContext] Initializing LLama model with params: {this.Params}");
_ctx = Utils.InitLLamaContextFromModelParams(Params);
}
internal LLamaContext(SafeLLamaContextHandle nativeContext, IModelParams @params, ILogger? logger = null)
internal LLamaContext(SafeLLamaContextHandle nativeContext, IContextParams @params, ILogger? logger = null)
{
Params = @params;
@ -95,7 +73,7 @@ namespace LLama
/// <param name="params"></param>
/// <param name="logger"></param>
/// <exception cref="ObjectDisposedException"></exception>
public LLamaContext(LLamaWeights model, IModelParams @params, ILogger? logger = null)
public LLamaContext(LLamaWeights model, IContextParams @params, ILogger? logger = null)
{
if (model.NativeHandle.IsClosed)
throw new ObjectDisposedException("Cannot create context, model weights have been disposed");
@ -105,30 +83,20 @@ namespace LLama
_logger = logger;
_encoding = @params.Encoding;
using var pin = @params.ToLlamaContextParams(out var lparams);
@params.ToLlamaContextParams(out var lparams);
_ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams);
}
/// <summary>
/// Create a copy of the current state of this context
/// </summary>
/// <returns></returns>
public LLamaContext Clone()
{
using var pin = Params.ToLlamaContextParams(out var lparams);
var clone = _ctx.Clone(lparams);
return new LLamaContext(clone, Params);
}
/// <summary>
/// Tokenize a string.
/// </summary>
/// <param name="text"></param>
/// <param name="addBos">Whether to add a bos to the text.</param>
/// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
/// <returns></returns>
public llama_token[] Tokenize(string text, bool addBos = true)
public llama_token[] Tokenize(string text, bool addBos = true, bool special = false)
{
return _ctx.Tokenize(text, addBos, _encoding);
return _ctx.Tokenize(text, addBos, special, _encoding);
}
/// <summary>
@ -177,19 +145,6 @@ namespace LLama
fileStream.SetLength(writtenBytes);
}
/// <summary>
/// Get the state data as a byte array.
/// </summary>
/// <returns></returns>
[Obsolete("Use `GetState` instead, this supports larger states (over 2GB)")]
public byte[] GetStateData()
{
var stateSize = NativeApi.llama_get_state_size(_ctx);
byte[] stateMemory = new byte[stateSize];
NativeApi.llama_copy_state_data(_ctx, stateMemory);
return stateMemory;
}
/// <summary>
/// Get the state data as an opaque handle
/// </summary>
@ -198,31 +153,28 @@ namespace LLama
{
var stateSize = _ctx.GetStateSize();
unsafe
// Allocate a chunk of memory large enough to hold the entire state
var memory = Marshal.AllocHGlobal((nint)stateSize);
try
{
// Allocate a chunk of memory large enough to hold the entire state
var memory = Marshal.AllocHGlobal((nint)stateSize);
try
{
// Copy the state data into memory, discover the actual size required
var actualSize = _ctx.GetState(memory, stateSize);
// Copy the state data into memory, discover the actual size required
var actualSize = _ctx.GetState(memory, stateSize);
// Shrink to size
memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);
// Shrink to size
memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);
// Wrap memory in a "state"
var state = new State(memory);
// Wrap memory in a "state"
var state = new State(memory);
// Set memory to zero, to prevent it being freed in finally block
memory = IntPtr.Zero;
// Set memory to zero, to prevent it being freed in finally block
memory = IntPtr.Zero;
return state;
}
finally
{
if (memory != IntPtr.Zero)
Marshal.FreeHGlobal(memory);
}
return state;
}
finally
{
if (memory != IntPtr.Zero)
Marshal.FreeHGlobal(memory);
}
}
@ -247,21 +199,6 @@ namespace LLama
}
}
/// <summary>
/// Load the state from memory.
/// </summary>
/// <param name="stateData"></param>
/// <exception cref="RuntimeError"></exception>
public void LoadState(byte[] stateData)
{
int stateSize = (int)NativeApi.llama_get_state_size(_ctx);
if (stateData.Length > stateSize)
{
throw new RuntimeError("Failed to validate state size.");
}
NativeApi.llama_set_state_data(_ctx, stateData);
}
/// <summary>
/// Load the state from memory.
/// </summary>
@ -463,15 +400,15 @@ namespace LLama
public int Eval(ReadOnlySpan<llama_token> tokens, int pastTokensCount)
{
var total = tokens.Length;
for(var i = 0; i < total; i += Params.BatchSize)
for(var i = 0; i < total; i += (int)Params.BatchSize)
{
var n_eval = total - i;
if (n_eval > Params.BatchSize)
{
n_eval = Params.BatchSize;
n_eval = (int)Params.BatchSize;
}
if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount, Params.Threads))
if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount))
{
_logger?.LogError($"[LLamaContext] Failed to eval.");
throw new RuntimeError("Failed to eval.");

View File

@ -18,19 +18,22 @@ namespace LLama
/// </summary>
public int EmbeddingSize => _ctx.EmbeddingSize;
/// <summary>
///
/// </summary>
/// <param name="params"></param>
public LLamaEmbedder(IModelParams @params)
public LLamaEmbedder(ILLamaParams allParams)
: this(allParams, allParams)
{
@params.EmbeddingMode = true;
using var weights = LLamaWeights.LoadFromFile(@params);
_ctx = weights.CreateContext(@params);
}
public LLamaEmbedder(LLamaWeights weights, IModelParams @params)
public LLamaEmbedder(IModelParams modelParams, IContextParams contextParams)
{
using var weights = LLamaWeights.LoadFromFile(modelParams);
contextParams.EmbeddingMode = true;
_ctx = weights.CreateContext(contextParams);
}
public LLamaEmbedder(LLamaWeights weights, IContextParams @params)
{
@params.EmbeddingMode = true;
_ctx = weights.CreateContext(@params);
}
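The embedder can now be constructed either from a combined `ILLamaParams` or from already-loaded weights plus context parameters. A sketch with a placeholder path:

var @params = new ModelParams("model.gguf");
// Option 1: the embedder loads the weights itself
var embedder = new LLamaEmbedder(@params);
// Option 2: reuse weights that were loaded elsewhere
var weights = LLamaWeights.LoadFromFile(@params);
var embedderFromWeights = new LLamaEmbedder(weights, @params);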

View File

@ -7,6 +7,7 @@ using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using LLama.Extensions;
using LLama.Native;
namespace LLama
{
@ -20,7 +21,7 @@ namespace LLama
: ILLamaExecutor
{
private readonly LLamaWeights _weights;
private readonly IModelParams _params;
private readonly IContextParams _params;
/// <summary>
/// The context used by the executor when running the inference.
@ -32,7 +33,7 @@ namespace LLama
/// </summary>
/// <param name="weights"></param>
/// <param name="params"></param>
public StatelessExecutor(LLamaWeights weights, IModelParams @params)
public StatelessExecutor(LLamaWeights weights, IContextParams @params)
{
_weights = weights;
_params = @params;
@ -41,20 +42,6 @@ namespace LLama
Context.Dispose();
}
/// <summary>
/// Create a new stateless executor which will use the model used to create the given context
/// </summary>
/// <param name="context"></param>
[Obsolete("Use the constructor which automatically creates contexts using the LLamaWeights")]
public StatelessExecutor(LLamaContext context)
{
_weights = new LLamaWeights(context.NativeHandle.ModelHandle, context.Params.Encoding);
_params = context.Params;
Context = _weights.CreateContext(_params);
Context.Dispose();
}
/// <inheritdoc />
public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
@ -114,15 +101,16 @@ namespace LLama
break;
// when run out of context
// based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L433
if (n_past + tokens.Count > Context.ContextSize)
// based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
if (n_past + tokens.Count >= Context.ContextSize)
{
var n_left = n_past - inferenceParams.TokensKeep;
var n_left = n_past - inferenceParams.TokensKeep - 1;
var n_discard = n_left / 2;
n_past = Math.Max(1, inferenceParams.TokensKeep);
NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
NativeApi.llama_kv_cache_seq_shift(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
tokens.Clear();
tokens.AddRange(lastTokens.Skip(lastTokens.Count - n_left / 2).Take(n_left / 2));
n_past -= n_discard;
}
// ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
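For intuition, a worked example of the new context-shift arithmetic (the numbers are illustrative):

// Suppose ContextSize = 60, TokensKeep = 10 and n_past = 59 when the window fills.
// n_left    = 59 - 10 - 1 = 48
// n_discard = 48 / 2      = 24
// llama_kv_cache_seq_rm removes cells [11, 35) from sequence 0,
// llama_kv_cache_seq_shift slides cells [35, 59) left by 24 positions,
// and n_past drops to 59 - 24 = 35, leaving room for new tokens.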

View File

@ -1,5 +1,4 @@
using System;
using System.Text;
using LLama.Abstractions;
using LLama.Extensions;
using LLama.Native;
@ -20,11 +19,6 @@ namespace LLama
/// <remarks>Be careful how you use this!</remarks>
public SafeLlamaModelHandle NativeHandle => _weights;
/// <summary>
/// Encoding to use to convert text into bytes for the model
/// </summary>
public Encoding Encoding { get; }
/// <summary>
/// Total number of tokens in vocabulary of this model
/// </summary>
@ -35,15 +29,24 @@ namespace LLama
/// </summary>
public int ContextSize => NativeHandle.ContextSize;
/// <summary>
/// Get the size of this model in bytes
/// </summary>
public ulong SizeInBytes => NativeHandle.SizeInBytes;
/// <summary>
/// Get the number of parameters in this model
/// </summary>
public ulong ParameterCount => NativeHandle.ParameterCount;
/// <summary>
/// Dimension of embedding vectors
/// </summary>
public int EmbeddingSize => NativeHandle.EmbeddingSize;
internal LLamaWeights(SafeLlamaModelHandle weights, Encoding encoding)
internal LLamaWeights(SafeLlamaModelHandle weights)
{
_weights = weights;
Encoding = encoding;
}
/// <summary>
@ -53,13 +56,20 @@ namespace LLama
/// <returns></returns>
public static LLamaWeights LoadFromFile(IModelParams @params)
{
using var pin = @params.ToLlamaContextParams(out var lparams);
using var pin = @params.ToLlamaModelParams(out var lparams);
var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
if (!string.IsNullOrEmpty(@params.LoraAdapter))
weights.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
foreach (var adapter in @params.LoraAdapters)
{
if (string.IsNullOrEmpty(adapter.Path))
continue;
if (adapter.Scale <= 0)
continue;
return new LLamaWeights(weights, @params.Encoding);
weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase, @params.Threads);
}
return new LLamaWeights(weights);
}
/// <inheritdoc />
@ -73,7 +83,7 @@ namespace LLama
/// </summary>
/// <param name="params"></param>
/// <returns></returns>
public LLamaContext CreateContext(IModelParams @params)
public LLamaContext CreateContext(IContextParams @params)
{
return new LLamaContext(this, @params);
}
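The weights now expose their size and parameter count via the new `llama_model_size` and `llama_model_n_params` native calls. A sketch, assuming `@params` is a `ModelParams` pointing at a real model file:

using var weights = LLamaWeights.LoadFromFile(@params);
Console.WriteLine($"{weights.ParameterCount} parameters, {weights.SizeInBytes} bytes");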

View File

@ -0,0 +1,106 @@
using System;
namespace LLama.Native;
using llama_token = Int32;
public sealed class LLamaBatchSafeHandle
: SafeLLamaHandleBase
{
private readonly int _embd;
public LLamaNativeBatch Batch { get; private set; }
/// <summary>
/// the token ids of the input (used when embd is NULL)
/// </summary>
public Span<llama_token> Token
{
get
{
unsafe
{
if (_embd != 0)
return new Span<int>(null, 0);
else
return new Span<int>(Batch.token, Batch.n_tokens);
}
}
}
/// <summary>
/// token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
/// </summary>
public Span<llama_token> Embed
{
get
{
unsafe
{
// If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
if (_embd != 0)
return new Span<llama_token>(Batch.embd, Batch.n_tokens * _embd);
else
return new Span<llama_token>(null, 0);
}
}
}
/// <summary>
/// the positions of the respective token in the sequence
/// </summary>
public Span<LLamaPos> Pos
{
get
{
unsafe
{
return new Span<LLamaPos>(Batch.pos, Batch.n_tokens);
}
}
}
/// <summary>
/// the sequence to which the respective token belongs
/// </summary>
public Span<LLamaSeqId> Sequence_ID
{
get
{
unsafe
{
return new Span<LLamaSeqId>(Batch.seq_id, Batch.n_tokens);
}
}
}
/// <summary>
/// if zero, the logits for the respective token will not be output
/// </summary>
public Span<byte> Logits
{
get
{
unsafe
{
return new Span<byte>(Batch.logits, Batch.n_tokens);
}
}
}
public LLamaBatchSafeHandle(int n_tokens, int embd)
: base((nint)1)
{
_embd = embd;
Batch = NativeApi.llama_batch_init(n_tokens, embd);
}
protected override bool ReleaseHandle()
{
NativeApi.llama_batch_free(Batch);
Batch = default;
SetHandle(IntPtr.Zero);
return true;
}
}
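A minimal allocation sketch for the new batch handle: with `embd == 0` the batch stores token ids, and the native memory is released when the handle is disposed.

// Room for up to 512 token ids (embd == 0 selects token mode)
using var batch = new LLamaBatchSafeHandle(512, 0);
// Token, Pos, Sequence_ID and Logits expose the native arrays as spans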

View File

@ -19,32 +19,27 @@ namespace LLama.Native
/// <summary>
/// RNG seed, -1 for random
/// </summary>
public int seed;
public uint seed;
/// <summary>
/// text context
/// </summary>
public int n_ctx;
public uint n_ctx;
/// <summary>
/// prompt processing batch size
/// </summary>
public int n_batch;
public uint n_batch;
/// <summary>
/// number of layers to store in VRAM
/// number of threads to use for generation
/// </summary>
public int n_gpu_layers;
public uint n_threads;
/// <summary>
/// the GPU that is used for scratch and small tensors
/// number of threads to use for batch processing
/// </summary>
public int main_gpu;
/// <summary>
/// how to split layers across multiple GPUs
/// </summary>
public nint tensor_split;
public uint n_threads_batch;
/// <summary>
/// ref: https://github.com/ggerganov/llama.cpp/pull/2054
@ -58,26 +53,6 @@ namespace LLama.Native
/// </summary>
public float rope_freq_scale;
/// <summary>
/// called with a progress value between 0 and 1, pass NULL to disable
/// </summary>
public IntPtr progress_callback;
/// <summary>
/// context pointer passed to the progress callback
/// </summary>
public IntPtr progress_callback_user_data;
/// <summary>
/// if true, reduce VRAM usage at the cost of performance
/// </summary>
public bool low_vram
{
readonly get => Convert.ToBoolean(_low_vram);
set => _low_vram = Convert.ToSByte(value);
}
private sbyte _low_vram;
/// <summary>
/// if true, use experimental mul_mat_q kernels
/// </summary>
@ -108,36 +83,6 @@ namespace LLama.Native
}
private sbyte _logits_all;
/// <summary>
/// only load the vocabulary, no weights
/// </summary>
public bool vocab_only
{
readonly get => Convert.ToBoolean(_vocab_only);
set => _vocab_only = Convert.ToSByte(value);
}
private sbyte _vocab_only;
/// <summary>
/// use mmap if possible
/// </summary>
public bool use_mmap
{
readonly get => Convert.ToBoolean(_use_mmap);
set => _use_mmap = Convert.ToSByte(value);
}
private sbyte _use_mmap;
/// <summary>
/// force system to keep model in RAM
/// </summary>
public bool use_mlock
{
readonly get => Convert.ToBoolean(_use_mlock);
set => _use_mlock = Convert.ToSByte(value);
}
private sbyte _use_mlock;
/// <summary>
/// embedding mode only
/// </summary>

View File

@ -0,0 +1,67 @@
using System;
using System.Runtime.InteropServices;
namespace LLama.Native
{
/// <summary>
/// A C# representation of the llama.cpp `llama_model_params` struct
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public unsafe struct LLamaModelParams
{
/// <summary>
/// number of layers to store in VRAM
/// </summary>
public int n_gpu_layers;
/// <summary>
/// the GPU that is used for scratch and small tensors
/// </summary>
public int main_gpu;
/// <summary>
/// how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
/// </summary>
public float* tensor_split;
/// <summary>
/// called with a progress value between 0 and 1, pass NULL to disable
/// </summary>
LlamaProgressCallback progress_callback;
/// <summary>
/// context pointer passed to the progress callback
/// </summary>
void* progress_callback_user_data;
/// <summary>
/// only load the vocabulary, no weights
/// </summary>
public bool vocab_only
{
readonly get => Convert.ToBoolean(_vocab_only);
set => _vocab_only = Convert.ToSByte(value);
}
private sbyte _vocab_only;
/// <summary>
/// use mmap if possible
/// </summary>
public bool use_mmap
{
readonly get => Convert.ToBoolean(_use_mmap);
set => _use_mmap = Convert.ToSByte(value);
}
private sbyte _use_mmap;
/// <summary>
/// force system to keep model in RAM
/// </summary>
public bool use_mlock
{
readonly get => Convert.ToBoolean(_use_mlock);
set => _use_mlock = Convert.ToSByte(value);
}
private sbyte _use_mlock;
}
}
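These native defaults are normally filled in by `IModelParamsExtensions.ToLlamaModelParams`, but the struct can also be populated by hand. A sketch (values are illustrative):

var mparams = NativeApi.llama_model_default_params();
mparams.n_gpu_layers = 20; // offload 20 layers to the GPU
mparams.use_mmap = true;   // map the model file rather than reading it into memory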

View File

@ -36,5 +36,15 @@ namespace LLama.Native
set => _quantize_output_tensor = Convert.ToSByte(value);
}
private sbyte _quantize_output_tensor;
/// <summary>
/// only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
/// </summary>
public bool only_copy
{
get => Convert.ToBoolean(_only_copy);
set => _only_copy = Convert.ToSByte(value);
}
private sbyte _only_copy;
}
}

View File

@ -0,0 +1,45 @@
using System;
using System.Runtime.InteropServices;
namespace LLama.Native;
using llama_token = Int32;
/// <summary>
/// Input data for llama_decode
/// A llama_batch object can contain input about one or many sequences
/// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public readonly unsafe struct LLamaNativeBatch
{
/// <summary>
/// The number of items pointed at by pos, seq_id and logits.
/// </summary>
public readonly int n_tokens;
/// <summary>
/// Either `n_tokens` of `llama_token`, or `NULL`, depending on how this batch was created
/// </summary>
public readonly llama_token* token;
/// <summary>
/// Either `n_tokens * embd * sizeof(float)` or `NULL`, depending on how this batch was created
/// </summary>
public readonly float* embd;
/// <summary>
/// the positions of the respective token in the sequence
/// </summary>
public readonly LLamaPos* pos;
/// <summary>
/// the sequence to which the respective token belongs
/// </summary>
public readonly LLamaSeqId* seq_id;
/// <summary>
/// if zero, the logits for the respective token will not be output
/// </summary>
public readonly byte* logits;
}

LLama/Native/LLamaPos.cs Normal file
View File

@ -0,0 +1,15 @@
namespace LLama.Native;
public record struct LLamaPos
{
public int Value;
public LLamaPos(int value)
{
Value = value;
}
public static explicit operator int(LLamaPos pos) => pos.Value;
public static implicit operator LLamaPos(int value) => new(value);
}

View File

@ -0,0 +1,15 @@
namespace LLama.Native;
public record struct LLamaSeqId
{
public int Value;
public LLamaSeqId(int value)
{
Value = value;
}
public static explicit operator int(LLamaSeqId pos) => pos.Value;
public static explicit operator LLamaSeqId(int value) => new(value);
}
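These two wrapper structs correspond to `llama_pos` and `llama_seq_id`; note that `LLamaPos` converts implicitly from `int`, while `LLamaSeqId` needs an explicit cast in both directions:

LLamaPos pos = 5;        // implicit int -> LLamaPos
int rawPos = (int)pos;   // explicit LLamaPos -> int
var seq = (LLamaSeqId)1; // explicit int -> LLamaSeqId
int rawSeq = (int)seq;   // explicit LLamaSeqId -> int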

View File

@ -2,7 +2,6 @@
using System.Buffers;
using System.Runtime.InteropServices;
using System.Text;
using LLama.Common;
using LLama.Exceptions;
#pragma warning disable IDE1006 // Naming Styles
@ -110,6 +109,13 @@ namespace LLama.Native
[DllImport(libraryName, EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)]
public static extern bool llama_empty_call();
/// <summary>
/// Create a LLamaModelParams with default values
/// </summary>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern LLamaModelParams llama_model_default_params();
/// <summary>
/// Create a LLamaContextParams with default values
/// </summary>
@ -138,18 +144,6 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern bool llama_mlock_supported();
/// <summary>
/// Export a static computation graph for context of 511 and batch size of 1
/// NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these
/// parameters here to keep things simple
/// IMPORTANT: do not use for anything else other than debugging and testing!
/// </summary>
/// <param name="ctx"></param>
/// <param name="fname"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_eval_export(SafeLLamaContextHandle ctx, string fname);
/// <summary>
/// Various functions for loading a ggml llama model.
/// Allocate (almost) all memory needed for the model.
@ -159,7 +153,7 @@ namespace LLama.Native
/// <param name="params"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_load_model_from_file(string path_model, LLamaContextParams @params);
public static extern IntPtr llama_load_model_from_file(string path_model, LLamaModelParams @params);
/// <summary>
/// Create a new llama_context with the given model.
@ -192,7 +186,7 @@ namespace LLama.Native
/// <param name="model"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_free_model(IntPtr model);
/// <summary>
/// Apply a LoRA adapter to a loaded model
/// path_base_model is the path to a higher quality model to use as a base for
@ -202,19 +196,12 @@ namespace LLama.Native
/// </summary>
/// <param name="model_ptr"></param>
/// <param name="path_lora"></param>
/// <param name="scale"></param>
/// <param name="path_base_model"></param>
/// <param name="n_threads"></param>
/// <returns>Returns 0 on success</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, string? path_base_model, int n_threads);
/// <summary>
/// Returns the number of tokens in the KV cache
/// </summary>
/// <param name="ctx"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_get_kv_cache_token_count(SafeLLamaContextHandle ctx);
public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, float scale, string? path_base_model, int n_threads);
/// <summary>
/// Sets the current rng seed.
@ -222,7 +209,7 @@ namespace LLama.Native
/// <param name="ctx"></param>
/// <param name="seed"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_set_rng_seed(SafeLLamaContextHandle ctx, int seed);
public static extern void llama_set_rng_seed(SafeLLamaContextHandle ctx, uint seed);
/// <summary>
/// Returns the maximum size in bytes of the state (rng, logits, embedding
@ -243,21 +230,6 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern ulong llama_copy_state_data(SafeLLamaContextHandle ctx, byte* dest);
/// <summary>
/// Copies the state to the specified destination address.
/// Destination needs to have allocated enough memory (see llama_get_state_size)
/// </summary>
/// <param name="ctx"></param>
/// <param name="dest"></param>
/// <returns>the number of bytes copied</returns>
public static ulong llama_copy_state_data(SafeLLamaContextHandle ctx, byte[] dest)
{
fixed (byte* dstPtr = &dest[0])
{
return llama_copy_state_data(ctx, dstPtr);
}
}
/// <summary>
/// Set the state reading from the specified address
/// </summary>
@ -267,20 +239,6 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte* src);
/// <summary>
/// Set the state reading from the specified address
/// </summary>
/// <param name="ctx"></param>
/// <param name="src"></param>
/// <returns>the number of bytes read</returns>
public static ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte[] src)
{
fixed (byte* srcPtr = &src[0])
{
return llama_set_state_data(ctx, srcPtr);
}
}
/// <summary>
/// Load session file
/// </summary>
@ -313,24 +271,9 @@ namespace LLama.Native
/// <param name="tokens"></param>
/// <param name="n_tokens"></param>
/// <param name="n_past"></param>
/// <param name="n_threads"></param>
/// <returns>Returns 0 on success</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int n_tokens, int n_past, int n_threads);
/// <summary>
/// Run the llama inference to obtain the logits and probabilities for the next token.
/// tokens + n_tokens is the provided batch of new tokens to process
/// n_past is the number of tokens to use from previous eval calls
/// </summary>
/// <param name="ctx"></param>
/// <param name="tokens"></param>
/// <param name="n_tokens"></param>
/// <param name="n_past"></param>
/// <param name="n_threads"></param>
/// <returns>Returns 0 on success</returns>
[DllImport(libraryName, EntryPoint = "llama_eval", CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_eval_with_pointer(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past, int n_threads);
public static extern int llama_eval(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past);
/// <summary>
/// Convert the provided text into tokens.
@ -341,10 +284,11 @@ namespace LLama.Native
/// <param name="tokens"></param>
/// <param name="n_max_tokens"></param>
/// <param name="add_bos"></param>
/// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.</param>
/// <returns>Returns the number of tokens on success, no more than n_max_tokens.
/// Returns a negative number on failure - the number of tokens that would have been returned
/// </returns>
public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos)
public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos, bool special)
{
// Calculate number of bytes in text and borrow an array that large (+1 for nul byte)
var byteCount = encoding.GetByteCount(text);
@ -364,7 +308,7 @@ namespace LLama.Native
// Do the actual tokenization
fixed (byte* arrayPtr = array)
fixed (llama_token* tokensPtr = tokens)
return llama_tokenize_native(ctx, arrayPtr, tokensPtr, n_max_tokens, add_bos);
return llama_tokenize(ctx.ModelHandle, arrayPtr, byteCount, tokensPtr, n_max_tokens, add_bos, special);
}
finally
{
@ -372,28 +316,6 @@ namespace LLama.Native
}
}
/// <summary>
/// Convert the provided text into tokens.
/// </summary>
/// <param name="ctx"></param>
/// <param name="text"></param>
/// <param name="tokens"></param>
/// <param name="n_max_tokens"></param>
/// <param name="add_bos"></param>
/// <returns>Returns the number of tokens on success, no more than n_max_tokens.
/// Returns a negative number on failure - the number of tokens that would have been returned
/// </returns>
[DllImport(libraryName, EntryPoint = "llama_tokenize", CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_tokenize_native(SafeLLamaContextHandle ctx, byte* text, llama_token* tokens, int n_max_tokens, bool add_bos);
/// <summary>
/// Get the number of tokens in the model vocabulary for this context
/// </summary>
/// <param name="ctx"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_n_vocab(SafeLLamaContextHandle ctx);
/// <summary>
/// Get the size of the context window for the model for this context
/// </summary>
@ -402,14 +324,6 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_n_ctx(SafeLLamaContextHandle ctx);
/// <summary>
/// Get the dimension of embedding vectors from the model for this context
/// </summary>
/// <param name="ctx"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_n_embd(SafeLLamaContextHandle ctx);
/// <summary>
/// Token logits obtained from the last call to llama_eval()
/// The logits for the last token are stored in the last row
@ -422,6 +336,14 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern float* llama_get_logits(SafeLLamaContextHandle ctx);
/// <summary>
/// Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab
/// </summary>
/// <param name="ctx"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx);
/// <summary>
/// Get the embeddings for the input
/// shape: [n_embd] (1-dimensional)
@ -431,15 +353,6 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern float* llama_get_embeddings(SafeLLamaContextHandle ctx);
/// <summary>
/// Token Id -> String. Uses the vocabulary in the provided context
/// </summary>
/// <param name="ctx"></param>
/// <param name="token"></param>
/// <returns>Pointer to a string.</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_token_to_str(SafeLLamaContextHandle ctx, llama_token token);
/// <summary>
/// Get the "Beginning of sentence" token
/// </summary>
@ -488,7 +401,7 @@ namespace LLama.Native
/// <param name="model"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_model_n_vocab(SafeLlamaModelHandle model);
public static extern int llama_n_vocab(SafeLlamaModelHandle model);
/// <summary>
/// Get the size of the context window for the model
@ -496,7 +409,7 @@ namespace LLama.Native
/// <param name="model"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_model_n_ctx(SafeLlamaModelHandle model);
public static extern int llama_n_ctx_train(SafeLlamaModelHandle model);
/// <summary>
/// Get the dimension of embedding vectors from this model
@ -504,7 +417,23 @@ namespace LLama.Native
/// <param name="model"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_model_n_embd(SafeLlamaModelHandle model);
public static extern int llama_n_embd(SafeLlamaModelHandle model);
/// <summary>
/// Get the size of the model in bytes
/// </summary>
/// <param name="model"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern ulong llama_model_size(SafeLlamaModelHandle model);
/// <summary>
/// Get the number of parameters in this model
/// </summary>
/// <param name="model"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern ulong llama_model_n_params(SafeLlamaModelHandle model);
/// <summary>
/// Convert a single token into text
@ -515,21 +444,23 @@ namespace LLama.Native
/// <param name="length">size of the buffer</param>
/// <returns>The length written, or if the buffer is too small a negative value that indicates the length required</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_token_to_piece_with_model(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length);
public static extern int llama_token_to_piece(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length);
/// <summary>
/// Convert text into tokens
/// </summary>
/// <param name="model"></param>
/// <param name="text"></param>
/// <param name="text_len"></param>
/// <param name="tokens"></param>
/// <param name="n_max_tokens"></param>
/// <param name="add_bos"></param>
/// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.</param>
/// <returns>Returns the number of tokens on success, no more than n_max_tokens.
/// Returns a negative number on failure - the number of tokens that would have been returned
/// </returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_tokenize_with_model(SafeLlamaModelHandle model, byte* text, int* tokens, int n_max_tokens, bool add_bos);
public static extern int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, int* tokens, int n_max_tokens, bool add_bos, bool special);
/// <summary>
/// Register a callback to receive llama log messages
@ -537,5 +468,98 @@ namespace LLama.Native
/// <param name="logCallback"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_log_set(LLamaLogCallback logCallback);
}
/// <summary>
    /// Removes all token data from cells in [c0, c1)
/// </summary>
/// <param name="ctx"></param>
/// <param name="c0"></param>
/// <param name="c1"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_kv_cache_tokens_rm(SafeLLamaContextHandle ctx, int c0, int c1);
/// <summary>
/// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
/// </summary>
/// <param name="ctx"></param>
/// <param name="seq"></param>
/// <param name="p0"></param>
/// <param name="p1"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_kv_cache_seq_rm(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1);
/// <summary>
/// Copy all tokens that belong to the specified sequence to another sequence
/// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
/// </summary>
/// <param name="ctx"></param>
/// <param name="src"></param>
/// <param name="dest"></param>
/// <param name="p0"></param>
/// <param name="p1"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_kv_cache_seq_cp(SafeLLamaContextHandle ctx, LLamaSeqId src, LLamaSeqId dest, LLamaPos p0, LLamaPos p1);
/// <summary>
/// Removes all tokens that do not belong to the specified sequence
/// </summary>
/// <param name="ctx"></param>
/// <param name="seq"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_kv_cache_seq_keep(SafeLLamaContextHandle ctx, LLamaSeqId seq);
/// <summary>
/// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
/// If the KV cache is RoPEd, the KV data is updated accordingly
/// </summary>
/// <param name="ctx"></param>
/// <param name="seq"></param>
/// <param name="p0"></param>
/// <param name="p1"></param>
/// <param name="delta"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_kv_cache_seq_shift(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta);
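These sequence-level functions are the building blocks for context shifting: drop the oldest positions of a sequence, then shift the remainder back. A rough sketch, assuming `ctx`, `nDiscard` and `nPast` exist, and assuming LLamaSeqId/LLamaPos can be constructed from plain integers (their definitions are not shown in this diff):

    // Drop positions [0, nDiscard) of sequence 0, then shift [nDiscard, nPast) back by nDiscard.
    var seq = new LLamaSeqId(0);
    NativeApi.llama_kv_cache_seq_rm(ctx, seq, new LLamaPos(0), new LLamaPos(nDiscard));
    NativeApi.llama_kv_cache_seq_shift(ctx, seq, new LLamaPos(nDiscard), new LLamaPos(nPast), new LLamaPos(-nDiscard));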
/// <summary>
/// Allocates a batch of tokens on the heap
/// The batch has to be freed with llama_batch_free()
/// If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
/// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
/// The rest of the llama_batch members are allocated with size n_tokens
/// All members are left uninitialized
/// </summary>
/// <param name="n_tokens"></param>
/// <param name="embd"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern LLamaNativeBatch llama_batch_init(int n_tokens, int embd);
/// <summary>
/// Frees a batch of tokens allocated with llama_batch_init()
/// </summary>
/// <param name="batch"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_batch_free(LLamaNativeBatch batch);
    /// <summary>
    /// Process a batch of tokens with the model and compute logits for them
    /// </summary>
    /// <param name="ctx"></param>
    /// <param name="batch"></param>
    /// <returns>A positive return value does not mean a fatal error, but rather a warning:<br />
    /// - 0: success<br />
    /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increasing the context)<br />
    /// - &lt; 0: error<br />
    /// </returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_decode(SafeLLamaContextHandle ctx, LLamaNativeBatch batch);
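A minimal sketch of the allocate/decode/free lifecycle implied by llama_batch_init, llama_decode and llama_batch_free; populating the batch contents (tokens, positions, sequence ids) is omitted here because the LLamaNativeBatch layout is not part of this hunk:

    var batch = NativeApi.llama_batch_init(512, 0);
    try
    {
        // Batch contents would be populated here before decoding.
        var rc = NativeApi.llama_decode(ctx, batch);
        if (rc == 1)
        {
            // No KV slot available - retry with a smaller batch or a larger context.
        }
        else if (rc < 0)
        {
            throw new InvalidOperationException($"llama_decode failed: {rc}");
        }
    }
    finally
    {
        NativeApi.llama_batch_free(batch);
    }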
/// <summary>
/// Set the number of threads used for decoding
/// </summary>
/// <param name="ctx"></param>
/// <param name="n_threads">n_threads is the number of threads used for generation (single token)</param>
/// <param name="n_threads_batch">n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)</param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_set_n_threads(SafeLLamaContextHandle ctx, uint n_threads, uint n_threads_batch);
}
}
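The new llama_set_n_threads binding splits the thread configuration between single-token generation and prompt/batch processing. A small sketch with purely illustrative values:

    // 8 threads for single-token generation, 16 for prompt/batch processing.
    NativeApi.llama_set_n_threads(ctx, 8, 16);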

View File

@ -1,5 +1,6 @@
using System;
using System.Buffers;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;
@ -21,26 +22,13 @@ namespace LLama.Native
/// <summary>
/// Total number of tokens in the context
/// </summary>
public int ContextSize => ThrowIfDisposed().ContextSize;
public int ContextSize => NativeApi.llama_n_ctx(this);
/// <summary>
/// Dimension of embedding vectors
/// </summary>
public int EmbeddingSize => ThrowIfDisposed().EmbeddingSize;
/// <summary>
/// Get the number of tokens in the KV Cache for this context
/// </summary>
public int KVCacheTokenCount
{
get
{
if (IsClosed)
throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - it has been disposed");
return NativeApi.llama_get_kv_cache_token_count(this);
}
}
/// <summary>
/// Get the model which this context is using
/// </summary>
@ -64,17 +52,20 @@ namespace LLama.Native
_model.DangerousAddRef(ref success);
if (!success)
throw new RuntimeError("Failed to increment model refcount");
}
/// <inheritdoc />
protected override bool ReleaseHandle()
{
NativeApi.llama_free(DangerousGetHandle());
SetHandle(IntPtr.Zero);
// Decrement refcount on model
_model?.DangerousRelease();
_model = null!;
NativeApi.llama_free(handle);
SetHandle(IntPtr.Zero);
return true;
}
@ -103,80 +94,8 @@ namespace LLama.Native
return new(ctx_ptr, model);
}
/// <summary>
/// Create a new llama context with a clone of the current llama context state
/// </summary>
/// <param name="lparams"></param>
/// <returns></returns>
public SafeLLamaContextHandle Clone(LLamaContextParams lparams)
{
// Allocate space to read the state of the current context
var stateSize = GetStateSize();
var stateMemory = Marshal.AllocHGlobal((nint)stateSize);
try
{
// Copy state from this context into memory
GetState(stateMemory, stateSize);
// Create a new context
var newCtx = Create(ModelHandle, lparams);
// Copy state into new context
newCtx.SetState(stateMemory);
return newCtx;
}
finally
{
Marshal.FreeHGlobal(stateMemory);
}
}
#endregion
/// <summary>
/// Convert the given text into tokens
/// </summary>
/// <param name="text">The text to tokenize</param>
/// <param name="add_bos">Whether the "BOS" token should be added</param>
/// <param name="encoding">Encoding to use for the text</param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
public int[] Tokenize(string text, bool add_bos, Encoding encoding)
{
ThrowIfDisposed();
if (string.IsNullOrEmpty(text) && !add_bos)
return Array.Empty<int>();
// Calculate number of bytes in string, this is a pessimistic estimate of token count. It can't
// possibly be more than this.
var count = encoding.GetByteCount(text) + (add_bos ? 1 : 0);
// "Rent" an array to write results into (avoiding an allocation of a large array)
var temporaryArray = ArrayPool<int>.Shared.Rent(count);
try
{
// Do the actual conversion
var n = NativeApi.llama_tokenize(this, text, encoding, temporaryArray, count, add_bos);
if (n < 0)
{
throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
"specify the encoding.");
}
// Copy the results from the rented into an array which is exactly the right size
var result = new int[n];
Array.ConstrainedCopy(temporaryArray, 0, result, 0, n);
return result;
}
finally
{
ArrayPool<int>.Shared.Return(temporaryArray);
}
}
/// <summary>
/// Token logits obtained from the last call to llama_eval()
/// The logits for the last token are stored in the last row
@ -196,6 +115,51 @@ namespace LLama.Native
}
}
#region tokens
/// <summary>
/// Convert the given text into tokens
/// </summary>
/// <param name="text">The text to tokenize</param>
/// <param name="add_bos">Whether the "BOS" token should be added</param>
/// <param name="encoding">Encoding to use for the text</param>
/// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
public int[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
{
ThrowIfDisposed();
if (string.IsNullOrEmpty(text) && !add_bos)
return Array.Empty<int>();
// Calculate number of bytes in string, this is a pessimistic estimate of token count. It can't
// possibly be more than this.
var count = encoding.GetByteCount(text) + (add_bos ? 1 : 0);
// "Rent" an array to write results into (avoiding an allocation of a large array)
var temporaryArray = ArrayPool<int>.Shared.Rent(count);
try
{
// Do the actual conversion
var n = NativeApi.llama_tokenize(this, text, encoding, temporaryArray, count, add_bos, special);
if (n < 0)
{
throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
"specify the encoding.");
}
// Copy the results from the rented into an array which is exactly the right size
var result = new int[n];
Array.ConstrainedCopy(temporaryArray, 0, result, 0, n);
return result;
}
finally
{
ArrayPool<int>.Shared.Return(temporaryArray);
}
}
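A usage sketch of the updated signature, assuming `ctx` is a SafeLLamaContextHandle created elsewhere:

    // add_bos = true, special = false
    var tokens = ctx.Tokenize("Hello, world!", true, false, Encoding.UTF8);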
/// <summary>
/// Convert a token into a string
/// </summary>
@ -228,25 +192,31 @@ namespace LLama.Native
{
return ThrowIfDisposed().TokenToSpan(token, dest);
}
#endregion
/// <summary>
/// Run the llama inference to obtain the logits and probabilities for the next token.
/// </summary>
/// <param name="tokens">The provided batch of new tokens to process</param>
/// <param name="n_past">the number of tokens to use from previous eval calls</param>
/// <param name="n_threads"></param>
/// <returns>Returns true on success</returns>
public bool Eval(ReadOnlySpan<int> tokens, int n_past, int n_threads)
public bool Eval(ReadOnlySpan<int> tokens, int n_past)
{
unsafe
{
fixed (int* pinned = tokens)
{
return NativeApi.llama_eval_with_pointer(this, pinned, tokens.Length, n_past, n_threads) == 0;
var ret = NativeApi.llama_eval(this, pinned, tokens.Length, n_past);
return ret == 0;
}
}
}
public int Decode(LLamaBatchSafeHandle batch)
{
return NativeApi.llama_decode(this, batch.Batch);
}
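A short sketch of feeding tokens through the simplified Eval path, assuming `tokens` holds the prompt tokens (e.g. from the Tokenize call above):

    if (!ctx.Eval(tokens, 0))
        throw new RuntimeError("Eval failed");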
#region state
/// <summary>
/// Get the size of the state, when saved as bytes

View File

@ -29,18 +29,30 @@ namespace LLama.Native
/// </summary>
public int EmbeddingSize { get; }
/// <summary>
/// Get the size of this model in bytes
/// </summary>
public ulong SizeInBytes { get; }
/// <summary>
/// Get the number of parameters in this model
/// </summary>
public ulong ParameterCount { get; }
internal SafeLlamaModelHandle(IntPtr handle)
: base(handle)
{
VocabCount = NativeApi.llama_model_n_vocab(this);
ContextSize = NativeApi.llama_model_n_ctx(this);
EmbeddingSize = NativeApi.llama_model_n_embd(this);
VocabCount = NativeApi.llama_n_vocab(this);
ContextSize = NativeApi.llama_n_ctx_train(this);
EmbeddingSize = NativeApi.llama_n_embd(this);
SizeInBytes = NativeApi.llama_model_size(this);
ParameterCount = NativeApi.llama_model_n_params(this);
}
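The two new properties make a model's footprint visible without further native calls. For example, given a loaded handle `model`:

    Console.WriteLine($"{model.ParameterCount:N0} parameters, {model.SizeInBytes / (1024.0 * 1024.0):F1} MiB");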
/// <inheritdoc />
protected override bool ReleaseHandle()
{
NativeApi.llama_free_model(handle);
NativeApi.llama_free_model(DangerousGetHandle());
SetHandle(IntPtr.Zero);
return true;
}
@ -52,7 +64,7 @@ namespace LLama.Native
/// <param name="lparams"></param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaContextParams lparams)
public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaModelParams lparams)
{
var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams);
if (model_ptr == IntPtr.Zero)
@ -62,21 +74,24 @@ namespace LLama.Native
}
#region LoRA
/// <summary>
/// Apply a LoRA adapter to a loaded model
/// </summary>
/// <param name="lora"></param>
/// <param name="scale"></param>
/// <param name="modelBase">A path to a higher quality model to use as a base for the layers modified by the
/// adapter. Can be NULL to use the current loaded model.</param>
/// <param name="threads"></param>
/// <exception cref="RuntimeError"></exception>
public void ApplyLoraFromFile(string lora, string? modelBase = null, int threads = -1)
public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null, uint? threads = null)
{
var err = NativeApi.llama_model_apply_lora_from_file(
this,
lora,
scale,
string.IsNullOrEmpty(modelBase) ? null : modelBase,
threads
(int?)threads ?? -1
);
if (err != 0)
@ -97,7 +112,7 @@ namespace LLama.Native
{
fixed (byte* destPtr = dest)
{
var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, destPtr, dest.Length);
var length = NativeApi.llama_token_to_piece(this, llama_token, destPtr, dest.Length);
return Math.Abs(length);
}
}
@ -113,7 +128,7 @@ namespace LLama.Native
{
unsafe
{
var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, null, 0);
var length = NativeApi.llama_token_to_piece(this, llama_token, null, 0);
if (length == 0)
return "";
@ -121,7 +136,7 @@ namespace LLama.Native
fixed (byte* bytePtr = bytes)
{
var written = NativeApi.llama_token_to_piece_with_model(this, llama_token, bytePtr, bytes.Length);
var written = NativeApi.llama_token_to_piece(this, llama_token, bytePtr, bytes.Length);
Debug.Assert(written == bytes.Length);
return encoding.GetString(bytePtr, bytes.Length);
@ -139,7 +154,7 @@ namespace LLama.Native
{
unsafe
{
var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, null, 0);
var length = NativeApi.llama_token_to_piece(this, llama_token, null, 0);
if (length == 0)
return;
@ -147,7 +162,7 @@ namespace LLama.Native
fixed (byte* bytePtr = bytes)
{
// Decode into bytes
var written = NativeApi.llama_token_to_piece_with_model(this, llama_token, bytePtr, bytes.Length);
var written = NativeApi.llama_token_to_piece(this, llama_token, bytePtr, bytes.Length);
Debug.Assert(written == bytes.Length);
// Decode into chars
@ -256,8 +271,9 @@ namespace LLama.Native
/// <param name="text"></param>
/// <param name="add_bos"></param>
/// <param name="encoding"></param>
/// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
/// <returns></returns>
public int[] Tokenize(string text, bool add_bos, Encoding encoding)
public int[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
{
// Convert string to bytes, adding one extra byte to the end (null terminator)
var bytesCount = encoding.GetByteCount(text);
@ -276,13 +292,13 @@ namespace LLama.Native
fixed (byte* bytesPtr = &bytes[0])
{
// Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space)
var count = -NativeApi.llama_tokenize_with_model(this, bytesPtr, (int*)IntPtr.Zero, 0, add_bos);
var count = -NativeApi.llama_tokenize(this, bytesPtr, bytesCount, (int*)IntPtr.Zero, 0, add_bos, special);
// Tokenize again, this time outputting into an array of exactly the right size
var tokens = new int[count];
fixed (int* tokensPtr = &tokens[0])
{
NativeApi.llama_tokenize_with_model(this, bytesPtr, tokensPtr, count, add_bos);
NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos, special);
return tokens;
}
}
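A short sketch tying the model-level changes together, assuming `model` is a loaded SafeLlamaModelHandle; the adapter path is a placeholder, not a file from this repository:

    // Tokenize with add_bos = true, special = false.
    var tokens = model.Tokenize("The quick brown fox", true, false, Encoding.UTF8);
    // Apply a LoRA adapter with the new scale parameter ("adapter.bin" is a placeholder path).
    model.ApplyLoraFromFile("adapter.bin", 1.0f, modelBase: null, threads: 4);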

View File

@ -1,108 +0,0 @@
using LLama.Abstractions;
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using LLama.Extensions;
namespace LLama
{
using llama_token = Int32;
/// <summary>
/// Assorted llama utilities
/// </summary>
public static class Utils
{
[Obsolete("Use LLamaWeights.LoadFromFile and LLamaWeights.CreateContext instead")]
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParams @params)
#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
{
using var weights = LLamaWeights.LoadFromFile(@params);
using (@params.ToLlamaContextParams(out var lparams))
return SafeLLamaContextHandle.Create(weights.NativeHandle, lparams);
}
[Obsolete("Use SafeLLamaContextHandle Tokenize method instead")]
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
public static IEnumerable<llama_token> Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding)
#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
{
return ctx.Tokenize(text, add_bos, encoding);
}
[Obsolete("Use SafeLLamaContextHandle GetLogits method instead")]
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
public static Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
{
if (length != ctx.VocabCount)
throw new ArgumentException("length must be the VocabSize");
return ctx.GetLogits();
}
[Obsolete("Use SafeLLamaContextHandle Eval method instead")]
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
public static int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads)
#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
{
var slice = tokens.AsSpan().Slice(startIndex, n_tokens);
return ctx.Eval(slice, n_past, n_threads) ? 0 : 1;
}
[Obsolete("Use SafeLLamaContextHandle TokenToString method instead")]
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
public static string TokenToString(llama_token token, SafeLLamaContextHandle ctx, Encoding encoding)
#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
{
return ctx.TokenToString(token, encoding);
}
[Obsolete("No longer used internally by LlamaSharp")]
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
public static string PtrToString(IntPtr ptr, Encoding encoding)
#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
{
#if NET6_0_OR_GREATER
// ReSharper disable once PossibleUnintendedReferenceComparison
if(encoding == Encoding.UTF8)
{
return Marshal.PtrToStringUTF8(ptr)!;
}
// ReSharper disable once PossibleUnintendedReferenceComparison
else if(encoding == Encoding.Unicode)
{
return Marshal.PtrToStringUni(ptr)!;
}
else
{
return Marshal.PtrToStringAuto(ptr)!;
}
#else
unsafe
{
byte* tp = (byte*)ptr.ToPointer();
List<byte> bytes = new();
while (true)
{
byte c = *tp++;
if (c == '\0')
{
break;
}
else
{
bytes.Add(c);
}
}
return encoding.GetString(bytes.ToArray());
}
#endif
}
}
}

File diff suppressed because it is too large

Binary files not shown