LLamaSharp/LLama.Unittest/LLamaEmbedderTests.cs

67 lines
2.1 KiB
C#

using LLama.Common;
using Xunit.Abstractions;
namespace LLama.Unittest;
/// <summary>
/// Tests for <see cref="LLamaEmbedder"/>: embeds three short sentences with a given model and
/// checks that two related sentences ("cat" / "kitten") end up closer together than an
/// unrelated one ("spoon").
/// </summary>
public sealed class LLamaEmbedderTests
{
    private readonly ITestOutputHelper _testOutputHelper;

    public LLamaEmbedderTests(ITestOutputHelper testOutputHelper)
    {
        _testOutputHelper = testOutputHelper;
    }

    /// <summary>Dot product of two vectors; asserts they have the same length.</summary>
    private static float Dot(float[] a, float[] b)
    {
        Assert.Equal(a.Length, b.Length);
        return a.Zip(b, (x, y) => x * y).Sum();
    }

    /// <summary>
    /// Cosine distance, i.e. <c>1 - cos(a, b)</c>.
    /// </summary>
    /// <remarks>
    /// The norms are divided out explicitly, so the result does not rely on the embedder
    /// returning unit-length vectors. For unit vectors this equals <c>1 - Dot(a, b)</c>.
    /// </remarks>
    private static float CosineDistance(float[] a, float[] b)
    {
        var normA = (float)Math.Sqrt(Dot(a, a));
        var normB = (float)Math.Sqrt(Dot(b, b));

        // A zero vector has no direction; cosine similarity would be NaN.
        Assert.True(normA > 0 && normB > 0, "Cannot compute cosine distance of a zero-length embedding.");

        return 1 - Dot(a, b) / (normA * normB);
    }

    /// <summary>
    /// Embeds <paramref name="text"/> and asserts the result is non-empty and contains only
    /// finite values (no NaN, no +/-Infinity).
    /// </summary>
    private static async Task<float[]> EmbedAsync(LLamaEmbedder embedder, string text)
    {
        var embedding = await embedder.GetEmbeddings(text);

        Assert.NotEmpty(embedding);
        Assert.DoesNotContain(embedding, v => float.IsNaN(v) || float.IsInfinity(v));

        return embedding;
    }

    /// <summary>
    /// Formats the first few components of an embedding for test output.
    /// The count is bounded by the vector length, so short vectors do not throw.
    /// </summary>
    private static string Preview(float[] embedding, int count = 7)
    {
        return $"[{string.Join(",", embedding.Take(count))}...]";
    }

    /// <summary>
    /// Loads the model at <paramref name="modelPath"/> in embedding mode, embeds three
    /// sentences, and asserts cat/kitten are closer than cat/spoon by cosine distance.
    /// </summary>
    private async Task CompareEmbeddingsAsync(string modelPath)
    {
        var @params = new ModelParams(modelPath)
        {
            // NOTE(review): the context must hold each tokenized sentence; 8 presumably
            // suffices for these short inputs — confirm if the sentences change.
            ContextSize = 8,
            Threads = 4,
            Embeddings = true,
            GpuLayerCount = Constants.CIGpuLayerCount,
        };

        using var weights = LLamaWeights.LoadFromFile(@params);
        using var embedder = new LLamaEmbedder(weights, @params);

        var cat = await EmbedAsync(embedder, "The cat is cute");
        var kitten = await EmbedAsync(embedder, "The kitten is kawaii");
        var spoon = await EmbedAsync(embedder, "The spoon is not real");

        _testOutputHelper.WriteLine($"Cat = {Preview(cat)}");
        _testOutputHelper.WriteLine($"Kitten = {Preview(kitten)}");
        _testOutputHelper.WriteLine($"Spoon = {Preview(spoon)}");

        var close = CosineDistance(cat, kitten);
        var far = CosineDistance(cat, spoon);

        _testOutputHelper.WriteLine("");
        _testOutputHelper.WriteLine($"Cat.Kitten (Close): {close:F4}");
        _testOutputHelper.WriteLine($"Cat.Spoon (Far): {far:F4}");

        Assert.True(
            close < far,
            $"Expected cat/kitten distance ({close:F4}) to be smaller than cat/spoon distance ({far:F4}).");
    }

    [Fact]
    public async Task EmbedCompareEmbeddingModel()
    {
        await CompareEmbeddingsAsync(Constants.EmbeddingModelPath);
    }

    [Fact]
    public async Task EmbedCompareGenerateModel()
    {
        await CompareEmbeddingsAsync(Constants.GenerativeModelPath);
    }
}