2023-05-13 02:08:03 +08:00
|
|
|
|
using LLama.Native;
|
|
|
|
|
using System;
|
|
|
|
|
using LLama.Exceptions;
|
2023-08-06 06:44:54 +08:00
|
|
|
|
using LLama.Abstractions;
|
2023-05-13 02:08:03 +08:00
|
|
|
|
|
2023-06-11 05:44:21 +08:00
|
|
|
|
namespace LLama
|
2023-05-13 02:08:03 +08:00
|
|
|
|
{
|
2023-06-20 02:38:57 +08:00
|
|
|
|
/// <summary>
/// The embedder for LLama, which supports getting embeddings from text.
/// </summary>
public sealed class LLamaEmbedder
    : IDisposable
{
    private readonly LLamaContext _ctx;

    /// <summary>
    /// Dimension of embedding vectors
    /// </summary>
    public int EmbeddingSize => _ctx.EmbeddingSize;

    /// <summary>
    /// Load the model described by <paramref name="params"/> and create a context for computing embeddings.
    /// </summary>
    /// <param name="params">
    /// Model parameters. <c>EmbeddingMode</c> is set to <c>true</c> on this instance (the caller's object is mutated).
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="params"/> is null.</exception>
    public LLamaEmbedder(IModelParams @params)
    {
        if (@params == null)
        {
            throw new ArgumentNullException(nameof(@params));
        }

        @params.EmbeddingMode = true;

        // Only the context is kept; the weights loaded here are disposed when the constructor returns.
        // NOTE(review): this relies on the context keeping the native model alive on its own — confirm.
        using var weights = LLamaWeights.LoadFromFile(@params);
        _ctx = weights.CreateContext(@params);
    }

    /// <summary>
    /// Create an embedder that uses already-loaded model weights.
    /// </summary>
    /// <param name="weights">Loaded model weights. They are not disposed by this embedder.</param>
    /// <param name="params">
    /// Context parameters. <c>EmbeddingMode</c> is set to <c>true</c> on this instance (the caller's object is mutated).
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="weights"/> or <paramref name="params"/> is null.</exception>
    public LLamaEmbedder(LLamaWeights weights, IModelParams @params)
    {
        if (weights == null)
        {
            throw new ArgumentNullException(nameof(weights));
        }
        if (@params == null)
        {
            throw new ArgumentNullException(nameof(@params));
        }

        // The other constructor forces embedding mode before creating its context; do the same here so
        // both construction paths yield a context that GetEmbeddings can read embeddings from.
        @params.EmbeddingMode = true;
        _ctx = weights.CreateContext(@params);
    }

    /// <summary>
    /// Get the embeddings of the text.
    /// </summary>
    /// <param name="text">The text to embed.</param>
    /// <param name="threads">unused</param>
    /// <param name="addBos">Add bos to the text.</param>
    /// <param name="encoding">unused</param>
    /// <returns>The embedding vector, or an empty array if no embedding is available.</returns>
    /// <exception cref="RuntimeError"></exception>
    [Obsolete("'threads' and 'encoding' parameters are no longer used")]
    // ReSharper disable once MethodOverloadWithOptionalParameter
    public float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8")
    {
        return GetEmbeddings(text, addBos);
    }

    /// <summary>
    /// Get the embeddings of the text, prepending a BOS token.
    /// </summary>
    /// <param name="text">The text to embed.</param>
    /// <returns>The embedding vector, or an empty array if no embedding is available.</returns>
    /// <exception cref="RuntimeError"></exception>
    public float[] GetEmbeddings(string text)
    {
        return GetEmbeddings(text, true);
    }

    /// <summary>
    /// Get the embeddings of the text.
    /// </summary>
    /// <param name="text">The text to embed.</param>
    /// <param name="addBos">Add bos to the text.</param>
    /// <returns>
    /// An array of <see cref="EmbeddingSize"/> floats. Returns an empty array when the text produced no tokens
    /// or the native library returned no embedding buffer.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
    /// <exception cref="RuntimeError"></exception>
    public float[] GetEmbeddings(string text, bool addBos)
    {
        if (text == null)
        {
            throw new ArgumentNullException(nameof(text));
        }

        var tokens = _ctx.Tokenize(text, addBos);

        // TODO(Rinne): deal with log of prompt

        // Nothing to evaluate (e.g. empty text with addBos == false). Reading the embedding buffer without a
        // fresh evaluation would return whatever the previous call left there, so report "no embedding".
        if (tokens.Length == 0)
        {
            return Array.Empty<float>();
        }

        _ctx.Eval(tokens, 0);

        unsafe
        {
            var embeddings = NativeApi.llama_get_embeddings(_ctx.NativeHandle);
            if (embeddings == null)
            {
                return Array.Empty<float>();
            }

            // Copy out of the native buffer so the result stays valid after the next evaluation.
            return new Span<float>(embeddings, EmbeddingSize).ToArray();
        }
    }

    /// <summary>
    /// Dispose the context owned by this embedder.
    /// </summary>
    public void Dispose()
    {
        _ctx.Dispose();
    }
}
|
|
|
|
|
}
|