Normalizing embeddings in `LLamaEmbedder`. As is done in llama.cpp: 2891c8aa9a/examples/embedding/embedding.cpp (L92)

This commit is contained in:
Martin Evans 2024-02-13 02:09:35 +00:00
parent a5eba9463f
commit d47b6afe4d
2 changed files with 33 additions and 35 deletions

View File

@ -25,18 +25,6 @@ public sealed class LLamaEmbedderTests
_embedder.Dispose();
}
private static float Magnitude(float[] a)
{
return MathF.Sqrt(a.Zip(a, (x, y) => x * y).Sum());
}
private static void Normalize(float[] a)
{
var mag = Magnitude(a);
for (var i = 0; i < a.Length; i++)
a[i] /= mag;
}
private static float Dot(float[] a, float[] b)
{
Assert.Equal(a.Length, b.Length);
@ -46,21 +34,16 @@ public sealed class LLamaEmbedderTests
[Fact]
public async Task EmbedCompare()
{
var cat = await _embedder.GetEmbeddings("cat");
var kitten = await _embedder.GetEmbeddings("kitten");
var spoon = await _embedder.GetEmbeddings("spoon");
Normalize(cat);
Normalize(kitten);
Normalize(spoon);
var close = 1 - Dot(cat, kitten);
var far = 1 - Dot(cat, spoon);
Assert.True(close < far);
var cat = await _embedder.GetEmbeddings("The cat is cute");
var kitten = await _embedder.GetEmbeddings("The kitten is kawaii");
var spoon = await _embedder.GetEmbeddings("The spoon is not real");
_testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
_testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]");
_testOutputHelper.WriteLine($"Spoon = [{string.Join(",", spoon.AsMemory().Slice(0, 7).ToArray())}...]");
var close = 1 - Dot(cat, kitten);
var far = 1 - Dot(cat, spoon);
Assert.True(close < far);
}
}

View File

@ -88,20 +88,35 @@ namespace LLama
// Remove everything we just evaluated from the context cache
Context.NativeHandle.KvCacheClear();
return embeddings;
// Normalize the embeddings vector
// https://github.com/ggerganov/llama.cpp/blob/2891c8aa9af17f4ff636ff3868bc34ff72b56e25/examples/embedding/embedding.cpp#L92
Normalize(embeddings);
float[] GetEmbeddingsArray()
{
var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
if (embeddings == null)
return Array.Empty<float>();
return embeddings.ToArray();
}
return embeddings;
}
/// <summary>
///
/// </summary>
private float[] GetEmbeddingsArray()
{
var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
if (embeddings == null)
return Array.Empty<float>();
return embeddings.ToArray();
}
private static void Normalize(Span<float> embeddings)
{
// Calculate length
var lengthSqr = 0.0;
foreach (var value in embeddings)
lengthSqr += value * value;
var length = (float)Math.Sqrt(lengthSqr);
// Normalize
for (var i = 0; i < embeddings.Length; i++)
embeddings[i] /= length;
}
/// <inheritdoc />
public void Dispose()
{
Context.Dispose();