Merge branch 'master' into master
This commit is contained in:
commit
df9a549e64
|
@ -1,19 +1,16 @@
|
|||
using LLama.Common;
|
||||
using LLama.Native;
|
||||
using LLama.Common;
|
||||
using Xunit.Abstractions;
|
||||
using Xunit.Sdk;
|
||||
|
||||
namespace LLama.Unittest;
|
||||
|
||||
public sealed class LLamaEmbedderTests
|
||||
: IDisposable
|
||||
{
|
||||
private readonly ITestOutputHelper _testOutputHelper;
|
||||
private readonly LLamaEmbedder _embedder;
|
||||
|
||||
public LLamaEmbedderTests(ITestOutputHelper testOutputHelper)
|
||||
{
|
||||
_testOutputHelper = testOutputHelper;
|
||||
|
||||
var @params = new ModelParams(Constants.EmbeddingModelPath)
|
||||
{
|
||||
ContextSize = 4096,
|
||||
|
@ -36,17 +33,24 @@ public sealed class LLamaEmbedderTests
|
|||
return a.Zip(b, (x, y) => x * y).Sum();
|
||||
}
|
||||
|
||||
|
||||
[Fact]
|
||||
public async Task EmbedCompare()
|
||||
private async Task CompareEmbeddings(string modelPath)
|
||||
{
|
||||
var cat = await _embedder.GetEmbeddings("The cat is cute");
|
||||
var @params = new ModelParams(modelPath)
|
||||
{
|
||||
ContextSize = 8,
|
||||
Threads = 4,
|
||||
Embeddings = true,
|
||||
};
|
||||
using var weights = LLamaWeights.LoadFromFile(@params);
|
||||
using var embedder = new LLamaEmbedder(weights, @params);
|
||||
|
||||
var cat = await embedder.GetEmbeddings("The cat is cute");
|
||||
Assert.DoesNotContain(float.NaN, cat);
|
||||
|
||||
var kitten = await _embedder.GetEmbeddings("The kitten is kawaii");
|
||||
var kitten = await embedder.GetEmbeddings("The kitten is kawaii");
|
||||
Assert.DoesNotContain(float.NaN, kitten);
|
||||
|
||||
var spoon = await _embedder.GetEmbeddings("The spoon is not real");
|
||||
var spoon = await embedder.GetEmbeddings("The spoon is not real");
|
||||
Assert.DoesNotContain(float.NaN, spoon);
|
||||
|
||||
_testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
|
||||
|
@ -62,4 +66,16 @@ public sealed class LLamaEmbedderTests
|
|||
|
||||
Assert.True(close < far);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task EmbedCompareEmbeddingModel()
|
||||
{
|
||||
await CompareEmbeddings(Constants.EmbeddingModelPath);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task EmbedCompareGenerateModel()
|
||||
{
|
||||
await CompareEmbeddings(Constants.GenerativeModelPath);
|
||||
}
|
||||
}
|
|
@ -97,11 +97,18 @@ namespace LLama
|
|||
|
||||
private float[] GetEmbeddingsArray()
|
||||
{
|
||||
var embeddings = NativeApi.llama_get_embeddings_seq(Context.NativeHandle, LLamaSeqId.Zero);
|
||||
if (embeddings.Length == 0)
|
||||
return Array.Empty<float>();
|
||||
unsafe
|
||||
{
|
||||
var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
|
||||
|
||||
return embeddings.ToArray();
|
||||
if (embeddings == null)
|
||||
embeddings = NativeApi.llama_get_embeddings_seq(Context.NativeHandle, LLamaSeqId.Zero);
|
||||
|
||||
if (embeddings == null)
|
||||
return Array.Empty<float>();
|
||||
|
||||
return new Span<float>(embeddings, Context.EmbeddingSize).ToArray();
|
||||
}
|
||||
}
|
||||
|
||||
private static void Normalize(Span<float> embeddings)
|
||||
|
@ -112,6 +119,7 @@ namespace LLama
|
|||
lengthSqr += value * value;
|
||||
var length = (float)Math.Sqrt(lengthSqr);
|
||||
|
||||
// Do not divide by length if it is zero
|
||||
if (length <= float.Epsilon)
|
||||
return;
|
||||
|
||||
|
|
|
@ -137,41 +137,17 @@ namespace LLama.Native
|
|||
/// Get the embeddings for a specific sequence.
|
||||
/// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
|
||||
/// </summary>
|
||||
/// <returns></returns>
|
||||
public static Span<float> llama_get_embeddings_seq(SafeLLamaContextHandle ctx, LLamaSeqId id)
|
||||
{
|
||||
unsafe
|
||||
{
|
||||
var ptr = llama_get_embeddings_seq_native(ctx, id);
|
||||
if (ptr == null)
|
||||
return Array.Empty<float>();
|
||||
|
||||
return new Span<float>(ptr, ctx.EmbeddingSize);
|
||||
}
|
||||
|
||||
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings_seq")]
|
||||
static extern unsafe float* llama_get_embeddings_seq_native(SafeLLamaContextHandle ctx, LLamaSeqId id);
|
||||
}
|
||||
/// <returns>A pointer to the first float in an embedding, length = ctx.EmbeddingSize</returns>
|
||||
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
|
||||
public static extern unsafe float* llama_get_embeddings_seq(SafeLLamaContextHandle ctx, LLamaSeqId id);
|
||||
|
||||
/// <summary>
|
||||
/// Get the embeddings for the ith sequence.
|
||||
/// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
|
||||
/// </summary>
|
||||
/// <returns></returns>
|
||||
public static Span<float> llama_get_embeddings_ith(SafeLLamaContextHandle ctx, int i)
|
||||
{
|
||||
unsafe
|
||||
{
|
||||
var ptr = llama_get_embeddings_ith_native(ctx, i);
|
||||
if (ptr == null)
|
||||
return Array.Empty<float>();
|
||||
|
||||
return new Span<float>(ptr, ctx.EmbeddingSize);
|
||||
}
|
||||
|
||||
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings_ith")]
|
||||
static extern unsafe float* llama_get_embeddings_ith_native(SafeLLamaContextHandle ctx, int i);
|
||||
}
|
||||
/// <returns>A pointer to the first float in an embedding, length = ctx.EmbeddingSize</returns>
|
||||
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
|
||||
public static extern unsafe float* llama_get_embeddings_ith(SafeLLamaContextHandle ctx, int i);
|
||||
|
||||
/// <summary>
|
||||
/// Get all output token embeddings.
|
||||
|
@ -182,20 +158,8 @@ namespace LLama.Native
|
|||
/// </summary>
|
||||
/// <param name="ctx"></param>
|
||||
/// <returns></returns>
|
||||
public static Span<float> llama_get_embeddings(SafeLLamaContextHandle ctx)
|
||||
{
|
||||
unsafe
|
||||
{
|
||||
var ptr = llama_get_embeddings_native(ctx);
|
||||
if (ptr == null)
|
||||
return Array.Empty<float>();
|
||||
|
||||
return new Span<float>(ptr, ctx.EmbeddingSize);
|
||||
}
|
||||
|
||||
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings")]
|
||||
static extern unsafe float* llama_get_embeddings_native(SafeLLamaContextHandle ctx);
|
||||
}
|
||||
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
|
||||
public static extern unsafe float* llama_get_embeddings(SafeLLamaContextHandle ctx);
|
||||
|
||||
/// <summary>
|
||||
/// Apply chat template. Inspired by hf apply_chat_template() on python.
|
||||
|
|
Loading…
Reference in New Issue