Added metadata overrides to `IModelParams`

This commit is contained in:
Martin Evans 2023-12-14 02:05:40 +00:00
parent b22d8b7495
commit b868b056f7
9 changed files with 157 additions and 21 deletions

View File

@ -1,5 +1,6 @@
using System.Diagnostics; using System.Diagnostics;
using System.Text; using System.Text;
using LLama.Abstractions;
using LLama.Common; using LLama.Common;
using LLama.Native; using LLama.Native;
@ -30,6 +31,7 @@ public class BatchedDecoding
// Load model // Load model
var parameters = new ModelParams(modelPath); var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters); using var model = LLamaWeights.LoadFromFile(parameters);
// Tokenize prompt // Tokenize prompt

View File

@ -2,7 +2,7 @@
<Import Project="..\LLama\LLamaSharp.Runtime.targets" /> <Import Project="..\LLama\LLamaSharp.Runtime.targets" />
<PropertyGroup> <PropertyGroup>
<OutputType>Exe</OutputType> <OutputType>Exe</OutputType>
<TargetFrameworks>net6.0;net7.0;net8.0</TargetFrameworks> <TargetFrameworks>net6.0;net8.0</TargetFrameworks>
<ImplicitUsings>enable</ImplicitUsings> <ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable> <Nullable>enable</Nullable>
<Platforms>AnyCPU;x64</Platforms> <Platforms>AnyCPU;x64</Platforms>

View File

@ -10,8 +10,7 @@ Console.WriteLine("=============================================================
NativeLibraryConfig NativeLibraryConfig
.Instance .Instance
.WithCuda() .WithCuda()
.WithLogs() .WithLogs();
.WithAvx(NativeLibraryConfig.AvxLevel.Avx512);
NativeApi.llama_empty_call(); NativeApi.llama_empty_call();
Console.WriteLine(); Console.WriteLine();

View File

@ -1,5 +1,6 @@
using LLama.Common; using LLama.Common;
using System.Text.Json; using System.Text.Json;
using LLama.Abstractions;
namespace LLama.Unittest namespace LLama.Unittest
{ {
@ -14,7 +15,12 @@ namespace LLama.Unittest
ContextSize = 42, ContextSize = 42,
Seed = 42, Seed = 42,
GpuLayerCount = 111, GpuLayerCount = 111,
TensorSplits = { [0] = 3 } TensorSplits = { [0] = 3 },
MetadataOverrides =
{
MetadataOverride.Create("hello", true),
MetadataOverride.Create("world", 17),
}
}; };
var json = JsonSerializer.Serialize(expected); var json = JsonSerializer.Serialize(expected);

View File

@ -59,6 +59,9 @@ namespace LLama.Web.Common
/// <inheritdoc /> /// <inheritdoc />
public TensorSplitsCollection TensorSplits { get; set; } = new(); public TensorSplitsCollection TensorSplits { get; set; } = new();
/// <inheritdoc />
public List<MetadataOverride> MetadataOverrides { get; } = new();
/// <inheritdoc /> /// <inheritdoc />
public float? RopeFrequencyBase { get; set; } public float? RopeFrequencyBase { get; set; }

View File

@ -59,6 +59,11 @@ namespace LLama.Abstractions
/// base model path for the lora adapter (lora_base) /// base model path for the lora adapter (lora_base)
/// </summary> /// </summary>
string LoraBase { get; set; } string LoraBase { get; set; }
/// <summary>
/// Override specific metadata items in the model
/// </summary>
List<MetadataOverride> MetadataOverrides { get; }
} }
/// <summary> /// <summary>
@ -186,7 +191,7 @@ namespace LLama.Abstractions
: JsonConverter<TensorSplitsCollection> : JsonConverter<TensorSplitsCollection>
{ {
/// <inheritdoc/> /// <inheritdoc/>
public override TensorSplitsCollection? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) public override TensorSplitsCollection Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{ {
var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>(); var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>();
return new TensorSplitsCollection(arr); return new TensorSplitsCollection(arr);
@ -198,4 +203,97 @@ namespace LLama.Abstractions
JsonSerializer.Serialize(writer, value.Splits, options); JsonSerializer.Serialize(writer, value.Splits, options);
} }
} }
/// <summary>
/// An override for a single key/value pair in model metadata
/// </summary>
[JsonConverter(typeof(MetadataOverrideConverter))]
public abstract record MetadataOverride
{
/// <summary>
/// Create a new override for an int key
/// </summary>
/// <param name="key"></param>
/// <param name="value"></param>
/// <returns></returns>
public static MetadataOverride Create(string key, int value)
{
return new IntOverride(key, value);
}
/// <summary>
/// Create a new override for a float key
/// </summary>
/// <param name="key"></param>
/// <param name="value"></param>
/// <returns></returns>
public static MetadataOverride Create(string key, float value)
{
return new FloatOverride(key, value);
}
/// <summary>
/// Create a new override for a boolean key
/// </summary>
/// <param name="key"></param>
/// <param name="value"></param>
/// <returns></returns>
public static MetadataOverride Create(string key, bool value)
{
return new BoolOverride(key, value);
}
internal abstract void Write(ref LLamaModelMetadataOverride dest);
/// <summary>
/// Get the key being overriden by this override
/// </summary>
public abstract string Key { get; init; }
private record IntOverride(string Key, int Value) : MetadataOverride
{
internal override void Write(ref LLamaModelMetadataOverride dest)
{
dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT;
dest.IntValue = Value;
}
}
private record FloatOverride(string Key, float Value) : MetadataOverride
{
internal override void Write(ref LLamaModelMetadataOverride dest)
{
dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT;
dest.FloatValue = Value;
}
}
private record BoolOverride(string Key, bool Value) : MetadataOverride
{
internal override void Write(ref LLamaModelMetadataOverride dest)
{
dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL;
dest.BoolValue = Value ? -1 : 0;
}
}
}
public class MetadataOverrideConverter
: JsonConverter<MetadataOverride>
{
/// <inheritdoc/>
public override MetadataOverride Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
throw new NotImplementedException();
//var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>();
//return new TensorSplitsCollection(arr);
}
/// <inheritdoc/>
public override void Write(Utf8JsonWriter writer, MetadataOverride value, JsonSerializerOptions options)
{
throw new NotImplementedException();
//JsonSerializer.Serialize(writer, value.Splits, options);
}
}
} }

View File

@ -1,9 +1,8 @@
using LLama.Abstractions; using LLama.Abstractions;
using System;
using System.Text; using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using LLama.Native; using LLama.Native;
using System.Collections.Generic;
namespace LLama.Common namespace LLama.Common
{ {
@ -55,6 +54,9 @@ namespace LLama.Common
/// <inheritdoc /> /// <inheritdoc />
public TensorSplitsCollection TensorSplits { get; set; } = new(); public TensorSplitsCollection TensorSplits { get; set; } = new();
/// <inheritdoc />
public List<MetadataOverride> MetadataOverrides { get; } = new();
/// <inheritdoc /> /// <inheritdoc />
public float? RopeFrequencyBase { get; set; } public float? RopeFrequencyBase { get; set; }

View File

@ -1,6 +1,6 @@
using System.IO; using System.IO;
using System; using System;
using System.Buffers; using System.Text;
using LLama.Abstractions; using LLama.Abstractions;
using LLama.Native; using LLama.Native;
@ -36,18 +36,44 @@ public static class IModelParamsExtensions
result.tensor_split = (float*)disposer.Add(@params.TensorSplits.Pin()).Pointer; result.tensor_split = (float*)disposer.Add(@params.TensorSplits.Pin()).Pointer;
} }
//todo: MetadataOverrides if (@params.MetadataOverrides.Count == 0)
//if (@params.MetadataOverrides.Count == 0) {
//{ unsafe
// unsafe {
// { result.kv_overrides = (LLamaModelMetadataOverride*)IntPtr.Zero;
// result.kv_overrides = (LLamaModelMetadataOverride*)IntPtr.Zero; }
// } }
//} else
//else {
//{ // Allocate enough space for all the override items
// throw new NotImplementedException("MetadataOverrides"); var overrides = new LLamaModelMetadataOverride[@params.MetadataOverrides.Count + 1];
//} var overridesPin = overrides.AsMemory().Pin();
unsafe
{
result.kv_overrides = (LLamaModelMetadataOverride*)disposer.Add(overridesPin).Pointer;
}
// Convert each item
for (var i = 0; i < @params.MetadataOverrides.Count; i++)
{
var item = @params.MetadataOverrides[i];
var native = new LLamaModelMetadataOverride();
// Init value and tag
item.Write(ref native);
// Convert key to bytes
unsafe
{
fixed (char* srcKey = item.Key)
{
Encoding.UTF8.GetBytes(srcKey, 0, native.key, 128);
}
}
overrides[i] = native;
}
}
return disposer; return disposer;
} }

View File

@ -12,7 +12,7 @@ public unsafe struct LLamaModelMetadataOverride
/// Key to override /// Key to override
/// </summary> /// </summary>
[FieldOffset(0)] [FieldOffset(0)]
public fixed char key[128]; public fixed byte key[128];
/// <summary> /// <summary>
/// Type of value /// Type of value