67 lines
2.1 KiB
C#
67 lines
2.1 KiB
C#
using LLama.Common;
|
|
using Xunit.Abstractions;
|
|
|
|
namespace LLama.Unittest;
|
|
|
|
public sealed class LLamaEmbedderTests
{
    private readonly ITestOutputHelper _testOutputHelper;

    public LLamaEmbedderTests(ITestOutputHelper testOutputHelper)
    {
        _testOutputHelper = testOutputHelper;
    }

    /// <summary>
    /// Dot product of two vectors. Fails the test if the lengths differ.
    /// </summary>
    private static float Dot(float[] a, float[] b)
    {
        Assert.Equal(a.Length, b.Length);
        return a.Zip(b, (x, y) => x * y).Sum();
    }

    /// <summary>
    /// Cosine distance (1 - cosine similarity) between two vectors.
    /// Divides by both norms, so the result does not depend on whether the
    /// embedder already returns unit-length vectors; for unit vectors it equals
    /// <c>1 - Dot(a, b)</c>.
    /// </summary>
    private static float CosineDistance(float[] a, float[] b)
    {
        var normA = MathF.Sqrt(Dot(a, a));
        var normB = MathF.Sqrt(Dot(b, b));

        // A zero vector has no direction; fail clearly instead of propagating NaN.
        Assert.True(normA > 0 && normB > 0, "Embedding vectors must be non-zero to compute a cosine distance.");

        return 1 - Dot(a, b) / (normA * normB);
    }

    /// <summary>
    /// Formats the first <paramref name="count"/> components of an embedding for test output.
    /// Uses <c>Take</c> so short vectors do not throw, and the invariant culture so the
    /// decimal separator cannot be confused with the "," joiner.
    /// </summary>
    private static string Preview(float[] embedding, int count = 7)
    {
        var head = embedding
            .Take(count)
            .Select(v => v.ToString(System.Globalization.CultureInfo.InvariantCulture));
        return $"[{string.Join(",", head)}...]";
    }

    /// <summary>
    /// Loads the model at <paramref name="modelPath"/> in embedding mode, embeds three
    /// sentences, and asserts that two semantically similar sentences are closer
    /// (by cosine distance) than a dissimilar one.
    /// </summary>
    private async Task CompareEmbeddings(string modelPath)
    {
        var @params = new ModelParams(modelPath)
        {
            ContextSize = 8,
            Threads = 4,
            Embeddings = true,
            GpuLayerCount = Constants.CIGpuLayerCount,
        };
        using var weights = LLamaWeights.LoadFromFile(@params);
        using var embedder = new LLamaEmbedder(weights, @params);

        // Embed a sentence and make sure the model did not produce NaNs.
        async Task<float[]> EmbedAsync(string text)
        {
            var embedding = await embedder.GetEmbeddings(text);
            Assert.DoesNotContain(float.NaN, embedding);
            return embedding;
        }

        var cat = await EmbedAsync("The cat is cute");
        var kitten = await EmbedAsync("The kitten is kawaii");
        var spoon = await EmbedAsync("The spoon is not real");

        _testOutputHelper.WriteLine($"Cat = {Preview(cat)}");
        _testOutputHelper.WriteLine($"Kitten = {Preview(kitten)}");
        _testOutputHelper.WriteLine($"Spoon = {Preview(spoon)}");

        var close = CosineDistance(cat, kitten);
        var far = CosineDistance(cat, spoon);

        _testOutputHelper.WriteLine("");
        _testOutputHelper.WriteLine($"Cat.Kitten (Close): {close:F4}");
        _testOutputHelper.WriteLine($"Cat.Spoon (Far): {far:F4}");

        Assert.True(close < far, $"Expected cat/kitten distance ({close:F4}) to be smaller than cat/spoon distance ({far:F4}).");
    }

    [Fact]
    public async Task EmbedCompareEmbeddingModel()
    {
        await CompareEmbeddings(Constants.EmbeddingModelPath);
    }

    [Fact]
    public async Task EmbedCompareGenerateModel()
    {
        await CompareEmbeddings(Constants.GenerativeModelPath);
    }
}