127 lines
4.3 KiB
C#
127 lines
4.3 KiB
C#
using LLama.Native;
|
|
using System;
|
|
using LLama.Exceptions;
|
|
using LLama.Abstractions;
|
|
using Microsoft.Extensions.Logging;
|
|
using System.Threading;
|
|
using System.Threading.Tasks;
|
|
|
|
namespace LLama
|
|
{
|
|
/// <summary>
|
|
/// The embedder for LLama, which supports getting embeddings from text.
|
|
/// </summary>
|
|
public sealed class LLamaEmbedder
|
|
: IDisposable
|
|
{
|
|
/// <summary>
|
|
/// Dimension of embedding vectors
|
|
/// </summary>
|
|
public int EmbeddingSize => Context.EmbeddingSize;
|
|
|
|
/// <summary>
|
|
/// LLama Context
|
|
/// </summary>
|
|
public LLamaContext Context { get; }
|
|
|
|
/// <summary>
|
|
/// Create a new embedder, using the given LLamaWeights
|
|
/// </summary>
|
|
/// <param name="weights"></param>
|
|
/// <param name="params"></param>
|
|
/// <param name="logger"></param>
|
|
public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
|
|
{
|
|
if (!@params.EmbeddingMode)
|
|
throw new ArgumentException("EmbeddingMode must be true", nameof(@params));
|
|
|
|
Context = weights.CreateContext(@params, logger);
|
|
}
|
|
|
|
/// <summary>
|
|
/// Get the embeddings of the text.
|
|
/// </summary>
|
|
/// <param name="text"></param>
|
|
/// <param name="cancellationToken"></param>
|
|
/// <returns></returns>
|
|
/// <exception cref="RuntimeError"></exception>
|
|
public Task<float[]> GetEmbeddings(string text, CancellationToken cancellationToken = default)
|
|
{
|
|
return GetEmbeddings(text, true, cancellationToken);
|
|
}
|
|
|
|
/// <summary>
|
|
/// Get the embeddings of the text.
|
|
/// </summary>
|
|
/// <param name="text"></param>
|
|
/// <param name="addBos">Add bos to the text.</param>
|
|
/// <param name="cancellationToken"></param>
|
|
/// <returns></returns>
|
|
/// <exception cref="RuntimeError"></exception>
|
|
public async Task<float[]> GetEmbeddings(string text, bool addBos, CancellationToken cancellationToken = default)
|
|
{
|
|
var tokens = Context.Tokenize(text, addBos);
|
|
if (tokens.Length > Context.ContextSize)
|
|
throw new ArgumentException($"Embedding prompt is longer than the context window ({tokens.Length} > {Context.ContextSize})", nameof(text));
|
|
|
|
// Evaluate prompt in batch-size chunks
|
|
var n_past = 0;
|
|
var batch = new LLamaBatch();
|
|
var batchSize = (int)Context.Params.BatchSize;
|
|
for (var i = 0; i < tokens.Length; i += batchSize)
|
|
{
|
|
var n_eval = tokens.Length - i;
|
|
if (n_eval > batchSize)
|
|
n_eval = batchSize;
|
|
|
|
batch.Clear();
|
|
batch.AddRange(tokens.AsSpan(i, n_eval), n_past, LLamaSeqId.Zero, false);
|
|
n_past += n_eval;
|
|
|
|
var returnCode = await Context.DecodeAsync(batch, cancellationToken);
|
|
if (returnCode != 0)
|
|
throw new LLamaDecodeError(returnCode);
|
|
}
|
|
|
|
var embeddings = GetEmbeddingsArray();
|
|
|
|
// Remove everything we just evaluated from the context cache
|
|
Context.NativeHandle.KvCacheClear();
|
|
|
|
// Normalize the embeddings vector
|
|
// https://github.com/ggerganov/llama.cpp/blob/2891c8aa9af17f4ff636ff3868bc34ff72b56e25/examples/embedding/embedding.cpp#L92
|
|
Normalize(embeddings);
|
|
|
|
return embeddings;
|
|
}
|
|
|
|
private float[] GetEmbeddingsArray()
|
|
{
|
|
var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
|
|
if (embeddings == null)
|
|
return Array.Empty<float>();
|
|
return embeddings.ToArray();
|
|
}
|
|
|
|
private static void Normalize(Span<float> embeddings)
|
|
{
|
|
// Calculate length
|
|
var lengthSqr = 0.0;
|
|
foreach (var value in embeddings)
|
|
lengthSqr += value * value;
|
|
var length = (float)Math.Sqrt(lengthSqr);
|
|
|
|
// Normalize
|
|
for (var i = 0; i < embeddings.Length; i++)
|
|
embeddings[i] /= length;
|
|
}
|
|
|
|
/// <inheritdoc />
|
|
public void Dispose()
|
|
{
|
|
Context.Dispose();
|
|
}
|
|
|
|
}
|
|
}
|