Normalizing embeddings in `LLamaEmbedder`. As is done in llama.cpp: 2891c8aa9a/examples/embedding/embedding.cpp (L92)
This commit is contained in:
parent
a5eba9463f
commit
d47b6afe4d
|
@ -25,18 +25,6 @@ public sealed class LLamaEmbedderTests
|
|||
_embedder.Dispose();
|
||||
}
|
||||
|
||||
/// <summary>
/// Compute the Euclidean (L2) length of a vector.
/// </summary>
/// <param name="a">Vector components</param>
/// <returns>Square root of the sum of the squared components</returns>
private static float Magnitude(float[] a)
{
    // Square each component, then add the squares together
    var sumOfSquares = a.Sum(component => component * component);

    return MathF.Sqrt(sumOfSquares);
}
|
||||
|
||||
/// <summary>
/// Divide every component of a vector by the vector's magnitude, in place.
/// </summary>
/// <param name="a">Vector to rescale; its contents are overwritten</param>
private static void Normalize(float[] a)
{
    var norm = Magnitude(a);

    var index = 0;
    while (index < a.Length)
    {
        a[index] = a[index] / norm;
        index++;
    }
}
|
||||
|
||||
private static float Dot(float[] a, float[] b)
|
||||
{
|
||||
Assert.Equal(a.Length, b.Length);
|
||||
|
@ -46,21 +34,16 @@ public sealed class LLamaEmbedderTests
|
|||
[Fact]
|
||||
public async Task EmbedCompare()
|
||||
{
|
||||
var cat = await _embedder.GetEmbeddings("cat");
|
||||
var kitten = await _embedder.GetEmbeddings("kitten");
|
||||
var spoon = await _embedder.GetEmbeddings("spoon");
|
||||
|
||||
Normalize(cat);
|
||||
Normalize(kitten);
|
||||
Normalize(spoon);
|
||||
|
||||
var close = 1 - Dot(cat, kitten);
|
||||
var far = 1 - Dot(cat, spoon);
|
||||
|
||||
Assert.True(close < far);
|
||||
var cat = await _embedder.GetEmbeddings("The cat is cute");
|
||||
var kitten = await _embedder.GetEmbeddings("The kitten is kawaii");
|
||||
var spoon = await _embedder.GetEmbeddings("The spoon is not real");
|
||||
|
||||
_testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
|
||||
_testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]");
|
||||
_testOutputHelper.WriteLine($"Spoon = [{string.Join(",", spoon.AsMemory().Slice(0, 7).ToArray())}...]");
|
||||
|
||||
var close = 1 - Dot(cat, kitten);
|
||||
var far = 1 - Dot(cat, spoon);
|
||||
Assert.True(close < far);
|
||||
}
|
||||
}
|
|
@ -88,20 +88,35 @@ namespace LLama
|
|||
// Remove everything we just evaluated from the context cache
|
||||
Context.NativeHandle.KvCacheClear();
|
||||
|
||||
return embeddings;
|
||||
// Normalize the embeddings vector
|
||||
// https://github.com/ggerganov/llama.cpp/blob/2891c8aa9af17f4ff636ff3868bc34ff72b56e25/examples/embedding/embedding.cpp#L92
|
||||
Normalize(embeddings);
|
||||
|
||||
float[] GetEmbeddingsArray()
|
||||
{
|
||||
var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
|
||||
if (embeddings == null)
|
||||
return Array.Empty<float>();
|
||||
return embeddings.ToArray();
|
||||
}
|
||||
return embeddings;
|
||||
}
|
||||
|
||||
/// <summary>
/// Read the embeddings exposed by <c>NativeApi.llama_get_embeddings</c> for this
/// embedder's context and copy them into a new array.
/// </summary>
/// <returns>
/// A copy of the embeddings, or an empty array when the native call yields null.
/// </returns>
private float[] GetEmbeddingsArray()
{
    var native = NativeApi.llama_get_embeddings(Context.NativeHandle);

    // Nothing available: hand back an empty array instead of null
    return native == null
        ? Array.Empty<float>()
        : native.ToArray();
}
|
||||
|
||||
/// <summary>
/// Scale a vector in place so that its Euclidean (L2) length becomes 1.
/// </summary>
/// <remarks>
/// A vector whose length is zero (empty, or all components zero) is left unchanged.
/// Dividing by a zero length would turn every component into NaN.
/// </remarks>
/// <param name="embeddings">Vector to normalize; its contents are overwritten</param>
private static void Normalize(Span<float> embeddings)
{
    // Calculate length (squares are summed into a double accumulator)
    var lengthSqr = 0.0;
    foreach (var value in embeddings)
        lengthSqr += value * value;
    var length = (float)Math.Sqrt(lengthSqr);

    // A zero-length vector has no direction to preserve: skip the division so the
    // result stays all-zero instead of becoming all-NaN (0 / 0).
    if (length <= 0f)
        return;

    // Normalize
    for (var i = 0; i < embeddings.Length; i++)
        embeddings[i] /= length;
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public void Dispose()
|
||||
{
|
||||
Context.Dispose();
|
||||
|
|
Loading…
Reference in New Issue