Added metadata overrides to `IModelParams`
This commit is contained in:
parent
b22d8b7495
commit
b868b056f7
|
@ -1,5 +1,6 @@
|
|||
using System.Diagnostics;
|
||||
using System.Text;
|
||||
using LLama.Abstractions;
|
||||
using LLama.Common;
|
||||
using LLama.Native;
|
||||
|
||||
|
@ -30,6 +31,7 @@ public class BatchedDecoding
|
|||
|
||||
// Load model
|
||||
var parameters = new ModelParams(modelPath);
|
||||
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
|
||||
// Tokenize prompt
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
<Import Project="..\LLama\LLamaSharp.Runtime.targets" />
|
||||
<PropertyGroup>
|
||||
<OutputType>Exe</OutputType>
|
||||
<TargetFrameworks>net6.0;net7.0;net8.0</TargetFrameworks>
|
||||
<TargetFrameworks>net6.0;net8.0</TargetFrameworks>
|
||||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<Nullable>enable</Nullable>
|
||||
<Platforms>AnyCPU;x64</Platforms>
|
||||
|
|
|
@ -10,8 +10,7 @@ Console.WriteLine("=============================================================
|
|||
NativeLibraryConfig
|
||||
.Instance
|
||||
.WithCuda()
|
||||
.WithLogs()
|
||||
.WithAvx(NativeLibraryConfig.AvxLevel.Avx512);
|
||||
.WithLogs();
|
||||
|
||||
NativeApi.llama_empty_call();
|
||||
Console.WriteLine();
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
using LLama.Common;
|
||||
using System.Text.Json;
|
||||
using LLama.Abstractions;
|
||||
|
||||
namespace LLama.Unittest
|
||||
{
|
||||
|
@ -14,7 +15,12 @@ namespace LLama.Unittest
|
|||
ContextSize = 42,
|
||||
Seed = 42,
|
||||
GpuLayerCount = 111,
|
||||
TensorSplits = { [0] = 3 }
|
||||
TensorSplits = { [0] = 3 },
|
||||
MetadataOverrides =
|
||||
{
|
||||
MetadataOverride.Create("hello", true),
|
||||
MetadataOverride.Create("world", 17),
|
||||
}
|
||||
};
|
||||
|
||||
var json = JsonSerializer.Serialize(expected);
|
||||
|
|
|
@ -59,6 +59,9 @@ namespace LLama.Web.Common
|
|||
/// <inheritdoc />
|
||||
public TensorSplitsCollection TensorSplits { get; set; } = new();
|
||||
|
||||
/// <inheritdoc />
|
||||
public List<MetadataOverride> MetadataOverrides { get; } = new();
|
||||
|
||||
/// <inheritdoc />
|
||||
public float? RopeFrequencyBase { get; set; }
|
||||
|
||||
|
|
|
@ -59,6 +59,11 @@ namespace LLama.Abstractions
|
|||
/// base model path for the lora adapter (lora_base)
|
||||
/// </summary>
|
||||
string LoraBase { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Override specific metadata items in the model
|
||||
/// </summary>
|
||||
List<MetadataOverride> MetadataOverrides { get; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -186,7 +191,7 @@ namespace LLama.Abstractions
|
|||
: JsonConverter<TensorSplitsCollection>
|
||||
{
|
||||
/// <inheritdoc/>
|
||||
public override TensorSplitsCollection? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
|
||||
public override TensorSplitsCollection Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
|
||||
{
|
||||
var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>();
|
||||
return new TensorSplitsCollection(arr);
|
||||
|
@ -198,4 +203,97 @@ namespace LLama.Abstractions
|
|||
JsonSerializer.Serialize(writer, value.Splits, options);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// An override for a single key/value pair in model metadata
|
||||
/// </summary>
|
||||
[JsonConverter(typeof(MetadataOverrideConverter))]
|
||||
public abstract record MetadataOverride
|
||||
{
|
||||
/// <summary>
|
||||
/// Create a new override for an int key
|
||||
/// </summary>
|
||||
/// <param name="key"></param>
|
||||
/// <param name="value"></param>
|
||||
/// <returns></returns>
|
||||
public static MetadataOverride Create(string key, int value)
|
||||
{
|
||||
return new IntOverride(key, value);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Create a new override for a float key
|
||||
/// </summary>
|
||||
/// <param name="key"></param>
|
||||
/// <param name="value"></param>
|
||||
/// <returns></returns>
|
||||
public static MetadataOverride Create(string key, float value)
|
||||
{
|
||||
return new FloatOverride(key, value);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Create a new override for a boolean key
|
||||
/// </summary>
|
||||
/// <param name="key"></param>
|
||||
/// <param name="value"></param>
|
||||
/// <returns></returns>
|
||||
public static MetadataOverride Create(string key, bool value)
|
||||
{
|
||||
return new BoolOverride(key, value);
|
||||
}
|
||||
|
||||
internal abstract void Write(ref LLamaModelMetadataOverride dest);
|
||||
|
||||
/// <summary>
|
||||
/// Get the key being overriden by this override
|
||||
/// </summary>
|
||||
public abstract string Key { get; init; }
|
||||
|
||||
private record IntOverride(string Key, int Value) : MetadataOverride
|
||||
{
|
||||
internal override void Write(ref LLamaModelMetadataOverride dest)
|
||||
{
|
||||
dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT;
|
||||
dest.IntValue = Value;
|
||||
}
|
||||
}
|
||||
|
||||
private record FloatOverride(string Key, float Value) : MetadataOverride
|
||||
{
|
||||
internal override void Write(ref LLamaModelMetadataOverride dest)
|
||||
{
|
||||
dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT;
|
||||
dest.FloatValue = Value;
|
||||
}
|
||||
}
|
||||
|
||||
private record BoolOverride(string Key, bool Value) : MetadataOverride
|
||||
{
|
||||
internal override void Write(ref LLamaModelMetadataOverride dest)
|
||||
{
|
||||
dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL;
|
||||
dest.BoolValue = Value ? -1 : 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public class MetadataOverrideConverter
|
||||
: JsonConverter<MetadataOverride>
|
||||
{
|
||||
/// <inheritdoc/>
|
||||
public override MetadataOverride Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
//var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>();
|
||||
//return new TensorSplitsCollection(arr);
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public override void Write(Utf8JsonWriter writer, MetadataOverride value, JsonSerializerOptions options)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
//JsonSerializer.Serialize(writer, value.Splits, options);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,9 +1,8 @@
|
|||
using LLama.Abstractions;
|
||||
using System;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Serialization;
|
||||
using LLama.Native;
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace LLama.Common
|
||||
{
|
||||
|
@ -55,6 +54,9 @@ namespace LLama.Common
|
|||
/// <inheritdoc />
|
||||
public TensorSplitsCollection TensorSplits { get; set; } = new();
|
||||
|
||||
/// <inheritdoc />
|
||||
public List<MetadataOverride> MetadataOverrides { get; } = new();
|
||||
|
||||
/// <inheritdoc />
|
||||
public float? RopeFrequencyBase { get; set; }
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
using System.IO;
|
||||
using System;
|
||||
using System.Buffers;
|
||||
using System.Text;
|
||||
using LLama.Abstractions;
|
||||
using LLama.Native;
|
||||
|
||||
|
@ -36,18 +36,44 @@ public static class IModelParamsExtensions
|
|||
result.tensor_split = (float*)disposer.Add(@params.TensorSplits.Pin()).Pointer;
|
||||
}
|
||||
|
||||
//todo: MetadataOverrides
|
||||
//if (@params.MetadataOverrides.Count == 0)
|
||||
//{
|
||||
// unsafe
|
||||
// {
|
||||
// result.kv_overrides = (LLamaModelMetadataOverride*)IntPtr.Zero;
|
||||
// }
|
||||
//}
|
||||
//else
|
||||
//{
|
||||
// throw new NotImplementedException("MetadataOverrides");
|
||||
//}
|
||||
if (@params.MetadataOverrides.Count == 0)
|
||||
{
|
||||
unsafe
|
||||
{
|
||||
result.kv_overrides = (LLamaModelMetadataOverride*)IntPtr.Zero;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Allocate enough space for all the override items
|
||||
var overrides = new LLamaModelMetadataOverride[@params.MetadataOverrides.Count + 1];
|
||||
var overridesPin = overrides.AsMemory().Pin();
|
||||
unsafe
|
||||
{
|
||||
result.kv_overrides = (LLamaModelMetadataOverride*)disposer.Add(overridesPin).Pointer;
|
||||
}
|
||||
|
||||
// Convert each item
|
||||
for (var i = 0; i < @params.MetadataOverrides.Count; i++)
|
||||
{
|
||||
var item = @params.MetadataOverrides[i];
|
||||
var native = new LLamaModelMetadataOverride();
|
||||
|
||||
// Init value and tag
|
||||
item.Write(ref native);
|
||||
|
||||
// Convert key to bytes
|
||||
unsafe
|
||||
{
|
||||
fixed (char* srcKey = item.Key)
|
||||
{
|
||||
Encoding.UTF8.GetBytes(srcKey, 0, native.key, 128);
|
||||
}
|
||||
}
|
||||
|
||||
overrides[i] = native;
|
||||
}
|
||||
}
|
||||
|
||||
return disposer;
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ public unsafe struct LLamaModelMetadataOverride
|
|||
/// Key to override
|
||||
/// </summary>
|
||||
[FieldOffset(0)]
|
||||
public fixed char key[128];
|
||||
public fixed byte key[128];
|
||||
|
||||
/// <summary>
|
||||
/// Type of value
|
||||
|
|
Loading…
Reference in New Issue