Added metadata overrides to `IModelParams`

This commit is contained in:
Martin Evans 2023-12-14 02:05:40 +00:00
parent b22d8b7495
commit b868b056f7
9 changed files with 157 additions and 21 deletions

View File

@ -1,5 +1,6 @@
using System.Diagnostics;
using System.Text;
using LLama.Abstractions;
using LLama.Common;
using LLama.Native;
@ -30,6 +31,7 @@ public class BatchedDecoding
// Load model
var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);
// Tokenize prompt

View File

@ -2,7 +2,7 @@
<Import Project="..\LLama\LLamaSharp.Runtime.targets" />
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFrameworks>net6.0;net7.0;net8.0</TargetFrameworks>
<TargetFrameworks>net6.0;net8.0</TargetFrameworks>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<Platforms>AnyCPU;x64</Platforms>

View File

@ -10,8 +10,7 @@ Console.WriteLine("=============================================================
NativeLibraryConfig
.Instance
.WithCuda()
.WithLogs()
.WithAvx(NativeLibraryConfig.AvxLevel.Avx512);
.WithLogs();
NativeApi.llama_empty_call();
Console.WriteLine();

View File

@ -1,5 +1,6 @@
using LLama.Common;
using System.Text.Json;
using LLama.Abstractions;
namespace LLama.Unittest
{
@ -14,7 +15,12 @@ namespace LLama.Unittest
ContextSize = 42,
Seed = 42,
GpuLayerCount = 111,
TensorSplits = { [0] = 3 }
TensorSplits = { [0] = 3 },
MetadataOverrides =
{
MetadataOverride.Create("hello", true),
MetadataOverride.Create("world", 17),
}
};
var json = JsonSerializer.Serialize(expected);

View File

@ -59,6 +59,9 @@ namespace LLama.Web.Common
/// <inheritdoc />
public TensorSplitsCollection TensorSplits { get; set; } = new();
/// <inheritdoc />
public List<MetadataOverride> MetadataOverrides { get; } = new();
/// <inheritdoc />
public float? RopeFrequencyBase { get; set; }

View File

@ -59,6 +59,11 @@ namespace LLama.Abstractions
/// base model path for the lora adapter (lora_base)
/// </summary>
string LoraBase { get; set; }
/// <summary>
/// Override specific metadata items in the model
/// </summary>
List<MetadataOverride> MetadataOverrides { get; }
}
/// <summary>
@ -186,7 +191,7 @@ namespace LLama.Abstractions
: JsonConverter<TensorSplitsCollection>
{
/// <inheritdoc/>
public override TensorSplitsCollection? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
public override TensorSplitsCollection Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>();
return new TensorSplitsCollection(arr);
@ -198,4 +203,97 @@ namespace LLama.Abstractions
JsonSerializer.Serialize(writer, value.Splits, options);
}
}
/// <summary>
/// A single key/value override for model metadata. Instances are created through the
/// <see cref="Create(string, int)"/>, <see cref="Create(string, float)"/> and
/// <see cref="Create(string, bool)"/> factory methods.
/// </summary>
[JsonConverter(typeof(MetadataOverrideConverter))]
public abstract record MetadataOverride
{
    /// <summary>
    /// The metadata key which this override replaces
    /// </summary>
    public abstract string Key { get; init; }

    /// <summary>
    /// Store the tag and value of this override into the given native struct
    /// (the key is not written here)
    /// </summary>
    /// <param name="dest">Native override struct to fill in</param>
    internal abstract void Write(ref LLamaModelMetadataOverride dest);

    /// <summary>
    /// Create an override which sets an integer metadata value
    /// </summary>
    /// <param name="key">Metadata key to override</param>
    /// <param name="value">Integer value to use for this key</param>
    /// <returns>A new override for the given key</returns>
    public static MetadataOverride Create(string key, int value) => new IntOverride(key, value);

    /// <summary>
    /// Create an override which sets a floating point metadata value
    /// </summary>
    /// <param name="key">Metadata key to override</param>
    /// <param name="value">Float value to use for this key</param>
    /// <returns>A new override for the given key</returns>
    public static MetadataOverride Create(string key, float value) => new FloatOverride(key, value);

    /// <summary>
    /// Create an override which sets a boolean metadata value
    /// </summary>
    /// <param name="key">Metadata key to override</param>
    /// <param name="value">Boolean value to use for this key</param>
    /// <returns>A new override for the given key</returns>
    public static MetadataOverride Create(string key, bool value) => new BoolOverride(key, value);

    /// <summary>
    /// Override carrying an <see cref="int"/> value
    /// </summary>
    private record IntOverride(string Key, int Value)
        : MetadataOverride
    {
        internal override void Write(ref LLamaModelMetadataOverride dest)
        {
            dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT;
            dest.IntValue = Value;
        }
    }

    /// <summary>
    /// Override carrying a <see cref="float"/> value
    /// </summary>
    private record FloatOverride(string Key, float Value)
        : MetadataOverride
    {
        internal override void Write(ref LLamaModelMetadataOverride dest)
        {
            dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT;
            dest.FloatValue = Value;
        }
    }

    /// <summary>
    /// Override carrying a <see cref="bool"/> value
    /// </summary>
    private record BoolOverride(string Key, bool Value)
        : MetadataOverride
    {
        internal override void Write(ref LLamaModelMetadataOverride dest)
        {
            dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL;

            // The native field is numeric: true is stored as -1, false as 0
            dest.BoolValue = Value ? -1 : 0;
        }
    }
}
/// <summary>
/// JSON converter for <see cref="MetadataOverride"/>. Each override is written as an object
/// with three properties: <c>Key</c> (string), <c>Type</c> (the numeric value of
/// <see cref="LLamaModelKvOverrideType"/>) and <c>Value</c> (number or boolean, depending on the type).
/// The explicit type tag is required because a float such as <c>17</c> is indistinguishable
/// from an int once written as a JSON number.
/// </summary>
public class MetadataOverrideConverter
    : JsonConverter<MetadataOverride>
{
    /// <inheritdoc/>
    /// <exception cref="JsonException">The JSON is not an object of the expected shape, or the type tag is unknown</exception>
    public override MetadataOverride Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
    {
        using var document = JsonDocument.ParseValue(ref reader);
        var root = document.RootElement;

        if (root.ValueKind != JsonValueKind.Object
         || !root.TryGetProperty("Key", out var keyElement)
         || !root.TryGetProperty("Type", out var typeElement)
         || !root.TryGetProperty("Value", out var valueElement))
        {
            throw new JsonException("Expected an object with 'Key', 'Type' and 'Value' properties");
        }

        var key = keyElement.GetString() ?? throw new JsonException("Metadata override 'Key' must not be null");

        // Rebuild the override through the public factory methods, picking the overload from the type tag
        switch ((LLamaModelKvOverrideType)typeElement.GetInt32())
        {
            case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT:
                return MetadataOverride.Create(key, valueElement.GetInt32());

            case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT:
                return MetadataOverride.Create(key, valueElement.GetSingle());

            case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL:
                return MetadataOverride.Create(key, valueElement.GetBoolean());

            default:
                throw new JsonException("Unknown metadata override type");
        }
    }

    /// <inheritdoc/>
    /// <exception cref="JsonException">The override reports a type tag this converter does not know how to write</exception>
    public override void Write(Utf8JsonWriter writer, MetadataOverride value, JsonSerializerOptions options)
    {
        // The concrete override types are private, so recover the tag and value by asking the
        // override to write itself into a scratch native struct.
        var native = new LLamaModelMetadataOverride();
        value.Write(ref native);

        writer.WriteStartObject();
        writer.WriteString("Key", value.Key);
        writer.WriteNumber("Type", (int)native.Tag);

        // NOTE(review): the casts assume the native fields are at least as wide as int/float
        // (they are assigned from int/float in MetadataOverride) - confirm against the struct definition.
        switch (native.Tag)
        {
            case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT:
                writer.WriteNumber("Value", (int)native.IntValue);
                break;

            case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT:
                writer.WriteNumber("Value", (float)native.FloatValue);
                break;

            case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL:
                // Booleans are stored natively as a number, with any non-zero value meaning true
                writer.WriteBoolean("Value", native.BoolValue != 0);
                break;

            default:
                throw new JsonException("Unknown metadata override type");
        }

        writer.WriteEndObject();
    }
}
}

View File

@ -1,9 +1,8 @@
using LLama.Abstractions;
using System;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using LLama.Native;
using System.Collections.Generic;
namespace LLama.Common
{
@ -55,6 +54,9 @@ namespace LLama.Common
/// <inheritdoc />
public TensorSplitsCollection TensorSplits { get; set; } = new();
/// <inheritdoc />
public List<MetadataOverride> MetadataOverrides { get; } = new();
/// <inheritdoc />
public float? RopeFrequencyBase { get; set; }

View File

@ -1,6 +1,6 @@
using System.IO;
using System;
using System.Buffers;
using System.Text;
using LLama.Abstractions;
using LLama.Native;
@ -36,18 +36,44 @@ public static class IModelParamsExtensions
result.tensor_split = (float*)disposer.Add(@params.TensorSplits.Pin()).Pointer;
}
//todo: MetadataOverrides
//if (@params.MetadataOverrides.Count == 0)
//{
// unsafe
// {
// result.kv_overrides = (LLamaModelMetadataOverride*)IntPtr.Zero;
// }
//}
//else
//{
// throw new NotImplementedException("MetadataOverrides");
//}
// Convert the managed metadata overrides into a pinned native array and point `kv_overrides` at it
if (@params.MetadataOverrides.Count == 0)
{
    // Nothing to override: pass a null pointer
    unsafe
    {
        result.kv_overrides = (LLamaModelMetadataOverride*)IntPtr.Zero;
    }
}
else
{
    // Allocate enough space for all the override items, plus one extra slot which is left zeroed.
    // NOTE(review): the zeroed slot is presumably the end-of-list marker expected by llama.cpp - confirm against llama.h
    var overrides = new LLamaModelMetadataOverride[@params.MetadataOverrides.Count + 1];

    // Pin the array and register the pin with the disposer
    var overridesPin = overrides.AsMemory().Pin();
    unsafe
    {
        result.kv_overrides = (LLamaModelMetadataOverride*)disposer.Add(overridesPin).Pointer;
    }

    // Convert each item
    for (var i = 0; i < @params.MetadataOverrides.Count; i++)
    {
        var item = @params.MetadataOverrides[i];
        var native = new LLamaModelMetadataOverride();

        // Init value and tag
        item.Write(ref native);

        // Convert key to bytes
        unsafe
        {
            fixed (char* srcKey = item.Key)
            {
                // The second argument is the number of chars to encode: it must be the key length,
                // passing 0 encodes nothing and leaves every key empty.
                // Only 127 of the 128 bytes are offered as output so that the final byte of the zero-initialised
                // buffer always remains 0 (presumably the native side reads the key as a NUL-terminated string - confirm).
                // A key which does not fit makes GetBytes throw an ArgumentException.
                Encoding.UTF8.GetBytes(srcKey, item.Key.Length, native.key, 127);
            }
        }

        overrides[i] = native;
    }
}
return disposer;
}

View File

@ -12,7 +12,7 @@ public unsafe struct LLamaModelMetadataOverride
/// Key to override
/// </summary>
[FieldOffset(0)]
public fixed char key[128];
public fixed byte key[128];
/// <summary>
/// Type of value