Added new symbols from llama.h
This commit is contained in:
parent
37466956c7
commit
77003d763e
|
@ -1,16 +1,20 @@
|
||||||
using System.Text;
|
using System.Text;
|
||||||
using LLama.Common;
|
using LLama.Common;
|
||||||
|
using LLama.Native;
|
||||||
|
using Xunit.Abstractions;
|
||||||
|
|
||||||
namespace LLama.Unittest
|
namespace LLama.Unittest
|
||||||
{
|
{
|
||||||
public sealed class BasicTest
|
public sealed class BasicTest
|
||||||
: IDisposable
|
: IDisposable
|
||||||
{
|
{
|
||||||
|
private readonly ITestOutputHelper _testOutputHelper;
|
||||||
private readonly ModelParams _params;
|
private readonly ModelParams _params;
|
||||||
private readonly LLamaWeights _model;
|
private readonly LLamaWeights _model;
|
||||||
|
|
||||||
public BasicTest()
|
public BasicTest(ITestOutputHelper testOutputHelper)
|
||||||
{
|
{
|
||||||
|
_testOutputHelper = testOutputHelper;
|
||||||
_params = new ModelParams(Constants.ModelPath)
|
_params = new ModelParams(Constants.ModelPath)
|
||||||
{
|
{
|
||||||
ContextSize = 2048
|
ContextSize = 2048
|
||||||
|
@ -30,5 +34,57 @@ namespace LLama.Unittest
|
||||||
Assert.Equal(4096, _model.ContextSize);
|
Assert.Equal(4096, _model.ContextSize);
|
||||||
Assert.Equal(4096, _model.EmbeddingSize);
|
Assert.Equal(4096, _model.EmbeddingSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void AdvancedModelProperties()
|
||||||
|
{
|
||||||
|
var expected = new Dictionary<string, string>
|
||||||
|
{
|
||||||
|
{ "general.name", "LLaMA v2" },
|
||||||
|
{ "general.architecture", "llama" },
|
||||||
|
{ "general.quantization_version", "2" },
|
||||||
|
{ "general.file_type", "2" },
|
||||||
|
|
||||||
|
{ "llama.context_length", "4096" },
|
||||||
|
{ "llama.rope.dimension_count", "128" },
|
||||||
|
{ "llama.embedding_length", "4096" },
|
||||||
|
{ "llama.block_count", "32" },
|
||||||
|
{ "llama.feed_forward_length", "11008" },
|
||||||
|
{ "llama.attention.head_count", "32" },
|
||||||
|
{ "llama.attention.head_count_kv", "32" },
|
||||||
|
{ "llama.attention.layer_norm_rms_epsilon", "0.000001" },
|
||||||
|
|
||||||
|
{ "tokenizer.ggml.eos_token_id", "2" },
|
||||||
|
{ "tokenizer.ggml.model", "llama" },
|
||||||
|
{ "tokenizer.ggml.bos_token_id", "1" },
|
||||||
|
{ "tokenizer.ggml.unknown_token_id", "0" },
|
||||||
|
};
|
||||||
|
|
||||||
|
var metaCount = NativeApi.llama_model_meta_count(_model.NativeHandle);
|
||||||
|
Assert.Equal(expected.Count, metaCount);
|
||||||
|
|
||||||
|
Span<byte> buffer = stackalloc byte[128];
|
||||||
|
for (var i = 0; i < expected.Count; i++)
|
||||||
|
{
|
||||||
|
unsafe
|
||||||
|
{
|
||||||
|
fixed (byte* ptr = buffer)
|
||||||
|
{
|
||||||
|
var length = NativeApi.llama_model_meta_key_by_index(_model.NativeHandle, i, ptr, 128);
|
||||||
|
Assert.True(length > 0);
|
||||||
|
var key = Encoding.UTF8.GetString(buffer[..length]);
|
||||||
|
|
||||||
|
length = NativeApi.llama_model_meta_val_str_by_index(_model.NativeHandle, i, ptr, 128);
|
||||||
|
Assert.True(length > 0);
|
||||||
|
var val = Encoding.UTF8.GetString(buffer[..length]);
|
||||||
|
|
||||||
|
_testOutputHelper.WriteLine($"{key} == {val}");
|
||||||
|
|
||||||
|
Assert.True(expected.ContainsKey(key));
|
||||||
|
Assert.Equal(expected[key], val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -8,6 +8,8 @@
|
||||||
<Nullable>enable</Nullable>
|
<Nullable>enable</Nullable>
|
||||||
|
|
||||||
<IsPackable>false</IsPackable>
|
<IsPackable>false</IsPackable>
|
||||||
|
|
||||||
|
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
|
||||||
</PropertyGroup>
|
</PropertyGroup>
|
||||||
|
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
|
|
|
@ -302,6 +302,20 @@ namespace LLama.Native
|
||||||
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
|
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
|
||||||
public static extern llama_token llama_token_nl(SafeLlamaModelHandle model);
|
public static extern llama_token llama_token_nl(SafeLlamaModelHandle model);
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Returns -1 if unknown, 1 for true or 0 for false.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns></returns>
|
||||||
|
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
|
||||||
|
public static extern int llama_add_bos_token(SafeLlamaModelHandle model);
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Returns -1 if unknown, 1 for true or 0 for false.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns></returns>
|
||||||
|
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
|
||||||
|
public static extern int llama_add_eos_token(SafeLlamaModelHandle model);
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Print out timing information for this context
|
/// Print out timing information for this context
|
||||||
/// </summary>
|
/// </summary>
|
||||||
|
@ -362,9 +376,9 @@ namespace LLama.Native
|
||||||
/// <param name="key"></param>
|
/// <param name="key"></param>
|
||||||
/// <param name="buf"></param>
|
/// <param name="buf"></param>
|
||||||
/// <param name="buf_size"></param>
|
/// <param name="buf_size"></param>
|
||||||
/// <returns></returns>
|
/// <returns>The length of the string on success, or -1 on failure</returns>
|
||||||
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
|
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
|
||||||
public static extern int llama_model_meta_val_str(SafeLlamaModelHandle model, char* key, char* buf, long buf_size);
|
public static extern int llama_model_meta_val_str(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size);
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Get the number of metadata key/value pairs
|
/// Get the number of metadata key/value pairs
|
||||||
|
@ -378,23 +392,23 @@ namespace LLama.Native
|
||||||
/// Get metadata key name by index
|
/// Get metadata key name by index
|
||||||
/// </summary>
|
/// </summary>
|
||||||
/// <param name="model"></param>
|
/// <param name="model"></param>
|
||||||
/// <param name="i"></param>
|
/// <param name="index"></param>
|
||||||
/// <param name="buf"></param>
|
/// <param name="buf"></param>
|
||||||
/// <param name="buf_size"></param>
|
/// <param name="buf_size"></param>
|
||||||
/// <returns></returns>
|
/// <returns>The length of the string on success, or -1 on failure</returns>
|
||||||
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
|
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
|
||||||
public static extern int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int i, char* buf, long buf_size);
|
public static extern int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Get metadata value as a string by index
|
/// Get metadata value as a string by index
|
||||||
/// </summary>
|
/// </summary>
|
||||||
/// <param name="model"></param>
|
/// <param name="model"></param>
|
||||||
/// <param name="i"></param>
|
/// <param name="index"></param>
|
||||||
/// <param name="buf"></param>
|
/// <param name="buf"></param>
|
||||||
/// <param name="buf_size"></param>
|
/// <param name="buf_size"></param>
|
||||||
/// <returns>The functions return the length of the string on success, or -1 on failure</returns>
|
/// <returns>The length of the string on success, or -1 on failure</returns>
|
||||||
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
|
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
|
||||||
public static extern int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, int i, char* buf, long buf_size);
|
public static extern int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Get a string describing the model type
|
/// Get a string describing the model type
|
||||||
|
@ -402,15 +416,15 @@ namespace LLama.Native
|
||||||
/// <param name="model"></param>
|
/// <param name="model"></param>
|
||||||
/// <param name="buf"></param>
|
/// <param name="buf"></param>
|
||||||
/// <param name="buf_size"></param>
|
/// <param name="buf_size"></param>
|
||||||
/// <returns>The functions return the length of the string on success, or -1 on failure</returns>
|
/// <returns>The length of the string on success, or -1 on failure</returns>
|
||||||
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
|
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
|
||||||
public static extern int llama_model_desc(SafeLlamaModelHandle model, char* buf, long buf_size);
|
public static extern int llama_model_desc(SafeLlamaModelHandle model, byte* buf, long buf_size);
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Get the size of the model in bytes
|
/// Get the size of the model in bytes
|
||||||
/// </summary>
|
/// </summary>
|
||||||
/// <param name="model"></param>
|
/// <param name="model"></param>
|
||||||
/// <returns>The functions return the length of the string on success, or -1 on failure</returns>
|
/// <returns>The size of the model</returns>
|
||||||
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
|
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
|
||||||
public static extern ulong llama_model_size(SafeLlamaModelHandle model);
|
public static extern ulong llama_model_size(SafeLlamaModelHandle model);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue