Added metadata overrides to `IModelParams`
This commit is contained in:
parent
b22d8b7495
commit
b868b056f7
|
@ -1,5 +1,6 @@
|
||||||
using System.Diagnostics;
|
using System.Diagnostics;
|
||||||
using System.Text;
|
using System.Text;
|
||||||
|
using LLama.Abstractions;
|
||||||
using LLama.Common;
|
using LLama.Common;
|
||||||
using LLama.Native;
|
using LLama.Native;
|
||||||
|
|
||||||
|
@ -30,6 +31,7 @@ public class BatchedDecoding
|
||||||
|
|
||||||
// Load model
|
// Load model
|
||||||
var parameters = new ModelParams(modelPath);
|
var parameters = new ModelParams(modelPath);
|
||||||
|
|
||||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||||
|
|
||||||
// Tokenize prompt
|
// Tokenize prompt
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
<Import Project="..\LLama\LLamaSharp.Runtime.targets" />
|
<Import Project="..\LLama\LLamaSharp.Runtime.targets" />
|
||||||
<PropertyGroup>
|
<PropertyGroup>
|
||||||
<OutputType>Exe</OutputType>
|
<OutputType>Exe</OutputType>
|
||||||
<TargetFrameworks>net6.0;net7.0;net8.0</TargetFrameworks>
|
<TargetFrameworks>net6.0;net8.0</TargetFrameworks>
|
||||||
<ImplicitUsings>enable</ImplicitUsings>
|
<ImplicitUsings>enable</ImplicitUsings>
|
||||||
<Nullable>enable</Nullable>
|
<Nullable>enable</Nullable>
|
||||||
<Platforms>AnyCPU;x64</Platforms>
|
<Platforms>AnyCPU;x64</Platforms>
|
||||||
|
|
|
@ -10,8 +10,7 @@ Console.WriteLine("=============================================================
|
||||||
NativeLibraryConfig
|
NativeLibraryConfig
|
||||||
.Instance
|
.Instance
|
||||||
.WithCuda()
|
.WithCuda()
|
||||||
.WithLogs()
|
.WithLogs();
|
||||||
.WithAvx(NativeLibraryConfig.AvxLevel.Avx512);
|
|
||||||
|
|
||||||
NativeApi.llama_empty_call();
|
NativeApi.llama_empty_call();
|
||||||
Console.WriteLine();
|
Console.WriteLine();
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
using LLama.Common;
|
using LLama.Common;
|
||||||
using System.Text.Json;
|
using System.Text.Json;
|
||||||
|
using LLama.Abstractions;
|
||||||
|
|
||||||
namespace LLama.Unittest
|
namespace LLama.Unittest
|
||||||
{
|
{
|
||||||
|
@ -14,7 +15,12 @@ namespace LLama.Unittest
|
||||||
ContextSize = 42,
|
ContextSize = 42,
|
||||||
Seed = 42,
|
Seed = 42,
|
||||||
GpuLayerCount = 111,
|
GpuLayerCount = 111,
|
||||||
TensorSplits = { [0] = 3 }
|
TensorSplits = { [0] = 3 },
|
||||||
|
MetadataOverrides =
|
||||||
|
{
|
||||||
|
MetadataOverride.Create("hello", true),
|
||||||
|
MetadataOverride.Create("world", 17),
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
var json = JsonSerializer.Serialize(expected);
|
var json = JsonSerializer.Serialize(expected);
|
||||||
|
|
|
@ -59,6 +59,9 @@ namespace LLama.Web.Common
|
||||||
/// <inheritdoc />
|
/// <inheritdoc />
|
||||||
public TensorSplitsCollection TensorSplits { get; set; } = new();
|
public TensorSplitsCollection TensorSplits { get; set; } = new();
|
||||||
|
|
||||||
|
/// <inheritdoc />
|
||||||
|
public List<MetadataOverride> MetadataOverrides { get; } = new();
|
||||||
|
|
||||||
/// <inheritdoc />
|
/// <inheritdoc />
|
||||||
public float? RopeFrequencyBase { get; set; }
|
public float? RopeFrequencyBase { get; set; }
|
||||||
|
|
||||||
|
|
|
@ -59,6 +59,11 @@ namespace LLama.Abstractions
|
||||||
/// base model path for the lora adapter (lora_base)
|
/// base model path for the lora adapter (lora_base)
|
||||||
/// </summary>
|
/// </summary>
|
||||||
string LoraBase { get; set; }
|
string LoraBase { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Override specific metadata items in the model
|
||||||
|
/// </summary>
|
||||||
|
List<MetadataOverride> MetadataOverrides { get; }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
|
@ -186,7 +191,7 @@ namespace LLama.Abstractions
|
||||||
: JsonConverter<TensorSplitsCollection>
|
: JsonConverter<TensorSplitsCollection>
|
||||||
{
|
{
|
||||||
/// <inheritdoc/>
|
/// <inheritdoc/>
|
||||||
public override TensorSplitsCollection? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
|
public override TensorSplitsCollection Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
|
||||||
{
|
{
|
||||||
var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>();
|
var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>();
|
||||||
return new TensorSplitsCollection(arr);
|
return new TensorSplitsCollection(arr);
|
||||||
|
@ -198,4 +203,97 @@ namespace LLama.Abstractions
|
||||||
JsonSerializer.Serialize(writer, value.Splits, options);
|
JsonSerializer.Serialize(writer, value.Splits, options);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// An override for a single key/value pair in model metadata
|
||||||
|
/// </summary>
|
||||||
|
[JsonConverter(typeof(MetadataOverrideConverter))]
|
||||||
|
public abstract record MetadataOverride
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Create a new override for an int key
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="key"></param>
|
||||||
|
/// <param name="value"></param>
|
||||||
|
/// <returns></returns>
|
||||||
|
public static MetadataOverride Create(string key, int value)
|
||||||
|
{
|
||||||
|
return new IntOverride(key, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Create a new override for a float key
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="key"></param>
|
||||||
|
/// <param name="value"></param>
|
||||||
|
/// <returns></returns>
|
||||||
|
public static MetadataOverride Create(string key, float value)
|
||||||
|
{
|
||||||
|
return new FloatOverride(key, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Create a new override for a boolean key
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="key"></param>
|
||||||
|
/// <param name="value"></param>
|
||||||
|
/// <returns></returns>
|
||||||
|
public static MetadataOverride Create(string key, bool value)
|
||||||
|
{
|
||||||
|
return new BoolOverride(key, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
internal abstract void Write(ref LLamaModelMetadataOverride dest);
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Get the key being overriden by this override
|
||||||
|
/// </summary>
|
||||||
|
public abstract string Key { get; init; }
|
||||||
|
|
||||||
|
private record IntOverride(string Key, int Value) : MetadataOverride
|
||||||
|
{
|
||||||
|
internal override void Write(ref LLamaModelMetadataOverride dest)
|
||||||
|
{
|
||||||
|
dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT;
|
||||||
|
dest.IntValue = Value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private record FloatOverride(string Key, float Value) : MetadataOverride
|
||||||
|
{
|
||||||
|
internal override void Write(ref LLamaModelMetadataOverride dest)
|
||||||
|
{
|
||||||
|
dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT;
|
||||||
|
dest.FloatValue = Value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private record BoolOverride(string Key, bool Value) : MetadataOverride
|
||||||
|
{
|
||||||
|
internal override void Write(ref LLamaModelMetadataOverride dest)
|
||||||
|
{
|
||||||
|
dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL;
|
||||||
|
dest.BoolValue = Value ? -1 : 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class MetadataOverrideConverter
|
||||||
|
: JsonConverter<MetadataOverride>
|
||||||
|
{
|
||||||
|
/// <inheritdoc/>
|
||||||
|
public override MetadataOverride Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
|
||||||
|
{
|
||||||
|
throw new NotImplementedException();
|
||||||
|
//var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>();
|
||||||
|
//return new TensorSplitsCollection(arr);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <inheritdoc/>
|
||||||
|
public override void Write(Utf8JsonWriter writer, MetadataOverride value, JsonSerializerOptions options)
|
||||||
|
{
|
||||||
|
throw new NotImplementedException();
|
||||||
|
//JsonSerializer.Serialize(writer, value.Splits, options);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -1,9 +1,8 @@
|
||||||
using LLama.Abstractions;
|
using LLama.Abstractions;
|
||||||
using System;
|
|
||||||
using System.Text;
|
using System.Text;
|
||||||
using System.Text.Json;
|
|
||||||
using System.Text.Json.Serialization;
|
using System.Text.Json.Serialization;
|
||||||
using LLama.Native;
|
using LLama.Native;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
|
||||||
namespace LLama.Common
|
namespace LLama.Common
|
||||||
{
|
{
|
||||||
|
@ -55,6 +54,9 @@ namespace LLama.Common
|
||||||
/// <inheritdoc />
|
/// <inheritdoc />
|
||||||
public TensorSplitsCollection TensorSplits { get; set; } = new();
|
public TensorSplitsCollection TensorSplits { get; set; } = new();
|
||||||
|
|
||||||
|
/// <inheritdoc />
|
||||||
|
public List<MetadataOverride> MetadataOverrides { get; } = new();
|
||||||
|
|
||||||
/// <inheritdoc />
|
/// <inheritdoc />
|
||||||
public float? RopeFrequencyBase { get; set; }
|
public float? RopeFrequencyBase { get; set; }
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
using System.IO;
|
using System.IO;
|
||||||
using System;
|
using System;
|
||||||
using System.Buffers;
|
using System.Text;
|
||||||
using LLama.Abstractions;
|
using LLama.Abstractions;
|
||||||
using LLama.Native;
|
using LLama.Native;
|
||||||
|
|
||||||
|
@ -36,18 +36,44 @@ public static class IModelParamsExtensions
|
||||||
result.tensor_split = (float*)disposer.Add(@params.TensorSplits.Pin()).Pointer;
|
result.tensor_split = (float*)disposer.Add(@params.TensorSplits.Pin()).Pointer;
|
||||||
}
|
}
|
||||||
|
|
||||||
//todo: MetadataOverrides
|
if (@params.MetadataOverrides.Count == 0)
|
||||||
//if (@params.MetadataOverrides.Count == 0)
|
{
|
||||||
//{
|
unsafe
|
||||||
// unsafe
|
{
|
||||||
// {
|
result.kv_overrides = (LLamaModelMetadataOverride*)IntPtr.Zero;
|
||||||
// result.kv_overrides = (LLamaModelMetadataOverride*)IntPtr.Zero;
|
}
|
||||||
// }
|
}
|
||||||
//}
|
else
|
||||||
//else
|
{
|
||||||
//{
|
// Allocate enough space for all the override items
|
||||||
// throw new NotImplementedException("MetadataOverrides");
|
var overrides = new LLamaModelMetadataOverride[@params.MetadataOverrides.Count + 1];
|
||||||
//}
|
var overridesPin = overrides.AsMemory().Pin();
|
||||||
|
unsafe
|
||||||
|
{
|
||||||
|
result.kv_overrides = (LLamaModelMetadataOverride*)disposer.Add(overridesPin).Pointer;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert each item
|
||||||
|
for (var i = 0; i < @params.MetadataOverrides.Count; i++)
|
||||||
|
{
|
||||||
|
var item = @params.MetadataOverrides[i];
|
||||||
|
var native = new LLamaModelMetadataOverride();
|
||||||
|
|
||||||
|
// Init value and tag
|
||||||
|
item.Write(ref native);
|
||||||
|
|
||||||
|
// Convert key to bytes
|
||||||
|
unsafe
|
||||||
|
{
|
||||||
|
fixed (char* srcKey = item.Key)
|
||||||
|
{
|
||||||
|
Encoding.UTF8.GetBytes(srcKey, 0, native.key, 128);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
overrides[i] = native;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return disposer;
|
return disposer;
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,7 +12,7 @@ public unsafe struct LLamaModelMetadataOverride
|
||||||
/// Key to override
|
/// Key to override
|
||||||
/// </summary>
|
/// </summary>
|
||||||
[FieldOffset(0)]
|
[FieldOffset(0)]
|
||||||
public fixed char key[128];
|
public fixed byte key[128];
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Type of value
|
/// Type of value
|
||||||
|
|
Loading…
Reference in New Issue