Add a unit test for model metadata and switch the metadata P/Invoke buffers to byte*.

* BasicTest: new AdvancedModelProperties test enumerates every metadata
  key/value pair through NativeApi.llama_model_meta_*_by_index and compares
  it with an expected table. The test now receives an ITestOutputHelper so
  the pairs it reads are written to the test log.
* LLama.Unittest.csproj: the new test contains an `unsafe` block, which
  needs AllowUnsafeBlocks.
* NativeApi: add llama_add_bos_token / llama_add_eos_token; the metadata and
  description functions take byte* buffers instead of char*.

NOTE(review): this patch was recovered from a copy in which line breaks and
angle-bracket markup (generic arguments, XML-doc tags, csproj elements) had
been lost. Added/removed lines are certain; blank context lines, the
`: IDisposable` line split and the csproj context element names were
reconstructed to satisfy the hunk line counts -- confirm against the target
files before applying.

diff --git a/LLama.Unittest/BasicTest.cs b/LLama.Unittest/BasicTest.cs
index 2cd1806f..2c06dd47 100644
--- a/LLama.Unittest/BasicTest.cs
+++ b/LLama.Unittest/BasicTest.cs
@@ -1,16 +1,20 @@
 using System.Text;
 using LLama.Common;
+using LLama.Native;
+using Xunit.Abstractions;
 
 namespace LLama.Unittest
 {
     public sealed class BasicTest
         : IDisposable
     {
+        private readonly ITestOutputHelper _testOutputHelper;
         private readonly ModelParams _params;
         private readonly LLamaWeights _model;
 
-        public BasicTest()
+        public BasicTest(ITestOutputHelper testOutputHelper)
         {
+            _testOutputHelper = testOutputHelper;
             _params = new ModelParams(Constants.ModelPath)
             {
                 ContextSize = 2048
@@ -30,5 +34,62 @@ namespace LLama.Unittest
             Assert.Equal(4096, _model.ContextSize);
             Assert.Equal(4096, _model.EmbeddingSize);
         }
+
+        /// <summary>
+        /// Verifies that every metadata key/value pair reported by the native API matches the expected table.
+        /// </summary>
+        [Fact]
+        public void AdvancedModelProperties()
+        {
+            var expected = new Dictionary<string, string>
+            {
+                { "general.name", "LLaMA v2" },
+                { "general.architecture", "llama" },
+                { "general.quantization_version", "2" },
+                { "general.file_type", "2" },
+
+                { "llama.context_length", "4096" },
+                { "llama.rope.dimension_count", "128" },
+                { "llama.embedding_length", "4096" },
+                { "llama.block_count", "32" },
+                { "llama.feed_forward_length", "11008" },
+                { "llama.attention.head_count", "32" },
+                { "llama.attention.head_count_kv", "32" },
+                { "llama.attention.layer_norm_rms_epsilon", "0.000001" },
+
+                { "tokenizer.ggml.eos_token_id", "2" },
+                { "tokenizer.ggml.model", "llama" },
+                { "tokenizer.ggml.bos_token_id", "1" },
+                { "tokenizer.ggml.unknown_token_id", "0" },
+            };
+
+            var metaCount = NativeApi.llama_model_meta_count(_model.NativeHandle);
+            Assert.Equal(expected.Count, metaCount);
+
+            // One scratch buffer is reused for every key and every value.
+            Span<byte> buffer = stackalloc byte[128];
+            for (var i = 0; i < metaCount; i++)
+            {
+                unsafe
+                {
+                    fixed (byte* ptr = buffer)
+                    {
+                        var length = NativeApi.llama_model_meta_key_by_index(_model.NativeHandle, i, ptr, buffer.Length);
+                        // -1 signals failure; anything above buffer.Length would make the slice below throw.
+                        Assert.InRange(length, 1, buffer.Length);
+                        var key = Encoding.UTF8.GetString(buffer[..length]);
+
+                        length = NativeApi.llama_model_meta_val_str_by_index(_model.NativeHandle, i, ptr, buffer.Length);
+                        Assert.InRange(length, 1, buffer.Length);
+                        var val = Encoding.UTF8.GetString(buffer[..length]);
+
+                        _testOutputHelper.WriteLine($"{key} == {val}");
+
+                        Assert.True(expected.TryGetValue(key, out var expectedVal), $"Unexpected metadata key: {key}");
+                        Assert.Equal(expectedVal, val);
+                    }
+                }
+            }
+        }
     }
 }
\ No newline at end of file
diff --git a/LLama.Unittest/LLama.Unittest.csproj b/LLama.Unittest/LLama.Unittest.csproj
index 0532244d..8effd951 100644
--- a/LLama.Unittest/LLama.Unittest.csproj
+++ b/LLama.Unittest/LLama.Unittest.csproj
@@ -8,6 +8,8 @@
     <Nullable>enable</Nullable>
 
     <IsPackable>false</IsPackable>
+
+    <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
   </PropertyGroup>
 
   <ItemGroup>
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 5c48d97d..a4f97a00 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -302,6 +302,20 @@ namespace LLama.Native
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern llama_token llama_token_nl(SafeLlamaModelHandle model);
 
+        /// <summary>
+        /// Returns -1 if unknown, 1 for true or 0 for false.
+        /// </summary>
+        /// <returns>-1 if unknown, 1 for true or 0 for false</returns>
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern int llama_add_bos_token(SafeLlamaModelHandle model);
+
+        /// <summary>
+        /// Returns -1 if unknown, 1 for true or 0 for false.
+        /// </summary>
+        /// <returns>-1 if unknown, 1 for true or 0 for false</returns>
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern int llama_add_eos_token(SafeLlamaModelHandle model);
+
         /// <summary>
         /// Print out timing information for this context
         /// </summary>
@@ -362,9 +376,9 @@ namespace LLama.Native
         /// <param name="key"></param>
         /// <param name="buf"></param>
         /// <param name="buf_size"></param>
-        /// <returns></returns>
+        /// <returns>The length of the string on success, or -1 on failure</returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern int llama_model_meta_val_str(SafeLlamaModelHandle model, char* key, char* buf, long buf_size);
+        public static extern int llama_model_meta_val_str(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size);
 
         /// <summary>
         /// Get the number of metadata key/value pairs
@@ -378,23 +392,23 @@ namespace LLama.Native
         /// Get metadata key name by index
         /// </summary>
         /// <param name="model"></param>
-        /// <param name="i"></param>
+        /// <param name="index">Index of the metadata key/value pair</param>
         /// <param name="buf"></param>
         /// <param name="buf_size"></param>
-        /// <returns></returns>
+        /// <returns>The length of the string on success, or -1 on failure</returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int i, char* buf, long buf_size);
+        public static extern int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
 
         /// <summary>
         /// Get metadata value as a string by index
         /// </summary>
         /// <param name="model"></param>
-        /// <param name="i"></param>
+        /// <param name="index">Index of the metadata key/value pair</param>
         /// <param name="buf"></param>
         /// <param name="buf_size"></param>
-        /// <returns>The functions return the length of the string on success, or -1 on failure</returns>
+        /// <returns>The length of the string on success, or -1 on failure</returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, int i, char* buf, long buf_size);
+        public static extern int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
 
         /// <summary>
         /// Get a string describing the model type
@@ -402,15 +416,15 @@ namespace LLama.Native
         /// <param name="model"></param>
         /// <param name="buf"></param>
         /// <param name="buf_size"></param>
-        /// <returns>The functions return the length of the string on success, or -1 on failure</returns>
+        /// <returns>The length of the string on success, or -1 on failure</returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern int llama_model_desc(SafeLlamaModelHandle model, char* buf, long buf_size);
+        public static extern int llama_model_desc(SafeLlamaModelHandle model, byte* buf, long buf_size);
 
         /// <summary>
         /// Get the size of the model in bytes
         /// </summary>
         /// <param name="model"></param>
-        /// <returns>The functions return the length of the string on success, or -1 on failure</returns>
+        /// <returns>The size of the model</returns>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern ulong llama_model_size(SafeLlamaModelHandle model);
 