110 lines
3.8 KiB
C#
110 lines
3.8 KiB
C#
using LLama;
|
|
using LLama.Common;
|
|
using Microsoft.KernelMemory;
|
|
using Microsoft.KernelMemory.AI;
|
|
|
|
namespace LLamaSharp.KernelMemory
|
|
{
|
|
/// <summary>
/// Provides text embedding generation for LLamaSharp through Kernel Memory's
/// <see cref="ITextEmbeddingGenerator"/> interface.
/// </summary>
/// <remarks>
/// Ownership of the native resources depends on the constructor used:
/// the config-only constructor owns both the weights and the embedder it creates,
/// the (config, weights) constructor owns only the embedder it creates,
/// and the embedder constructor owns nothing.
/// <see cref="Dispose"/> releases only what this instance owns.
/// </remarks>
public class LLamaSharpTextEmbeddingGenerator
    : ITextEmbeddingGenerator, IDisposable
{
    /// <summary>
    /// Context size used (and reported by <see cref="MaxTokens"/>) when the
    /// configuration does not specify one.
    /// </summary>
    private const int DefaultContextSize = 2048;

    private readonly LLamaSharpConfig? _config;
    private readonly LLamaWeights? _weights;
    private readonly LLamaEmbedder _embedder;
    private readonly bool _ownsEmbedder;
    private readonly bool _ownsWeights;
    private bool _disposed;

    /// <inheritdoc/>
    /// <remarks>
    /// Returns the configured context size, or 2048 when the configuration leaves it unset
    /// or when the instance was built from an existing <see cref="LLamaEmbedder"/> (no config).
    /// </remarks>
    public int MaxTokens => (int?)_config?.ContextSize ?? DefaultContextSize;

    /// <summary>
    /// Initializes a new instance of the <see cref="LLamaSharpTextEmbeddingGenerator"/> class,
    /// loading the model weights from <see cref="LLamaSharpConfig.ModelPath"/>.
    /// The new instance owns both the loaded weights and the embedder.
    /// </summary>
    /// <param name="config">The configuration for LLamaSharp.</param>
    /// <exception cref="ArgumentNullException"><paramref name="config"/> is <c>null</c>.</exception>
    public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config)
    {
        _config = config ?? throw new ArgumentNullException(nameof(config));
        var @params = CreateModelParams(config);

        _weights = LLamaWeights.LoadFromFile(@params);
        try
        {
            _embedder = new LLamaEmbedder(_weights, @params);
        }
        catch
        {
            // The weights were loaded here and no instance will exist to dispose them,
            // so release them before propagating the failure.
            _weights.Dispose();
            throw;
        }

        _ownsWeights = true;
        _ownsEmbedder = true;
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="LLamaSharpTextEmbeddingGenerator"/> class from reused weights.
    /// The caller keeps ownership of <paramref name="weights"/>; this instance owns only the embedder it creates.
    /// </summary>
    /// <param name="config">The configuration for LLamaSharp.</param>
    /// <param name="weights">A LLamaWeights object.</param>
    /// <exception cref="ArgumentNullException"><paramref name="config"/> or <paramref name="weights"/> is <c>null</c>.</exception>
    public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config, LLamaWeights weights)
    {
        _config = config ?? throw new ArgumentNullException(nameof(config));
        _weights = weights ?? throw new ArgumentNullException(nameof(weights));
        _embedder = new LLamaEmbedder(_weights, CreateModelParams(config));
        _ownsEmbedder = true;
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="LLamaSharpTextEmbeddingGenerator"/> class from reused embedder.
    /// The caller keeps ownership of <paramref name="embedder"/>; disposing this instance does not dispose it.
    /// </summary>
    /// <param name="embedder">A LLamaEmbedder object.</param>
    /// <exception cref="ArgumentNullException"><paramref name="embedder"/> is <c>null</c>.</exception>
    public LLamaSharpTextEmbeddingGenerator(LLamaEmbedder embedder)
    {
        _config = null;
        _weights = null;
        _embedder = embedder ?? throw new ArgumentNullException(nameof(embedder));
    }

    /// <summary>
    /// Builds the embedding-mode <see cref="ModelParams"/> shared by the config-based constructors.
    /// </summary>
    /// <param name="config">The configuration to read model path, context size and GPU settings from.</param>
    private static ModelParams CreateModelParams(LLamaSharpConfig config)
    {
        return new ModelParams(config.ModelPath)
        {
            // Use the same context size that MaxTokens reports, so callers that chunk
            // text by MaxTokens never exceed the real context.
            ContextSize = config.ContextSize ?? DefaultContextSize,
            Embeddings = true,
            MainGpu = config.MainGpu,
            SplitMode = config.SplitMode,
        };
    }

    /// <inheritdoc/>
    /// <remarks>
    /// Only resources owned by this instance are released; calling this more than once is a no-op.
    /// </remarks>
    public void Dispose()
    {
        if (_disposed)
        {
            return;
        }
        _disposed = true;

        // Release the dependent object first: the embedder was created from the weights.
        if (_ownsEmbedder)
        {
            _embedder.Dispose();
        }

        if (_ownsWeights)
        {
            _weights?.Dispose();
        }

        GC.SuppressFinalize(this);
    }

    /// <inheritdoc/>
    /// <remarks>
    /// Inputs are embedded one at a time, in order, using the single embedder held by this instance.
    /// </remarks>
    /// <exception cref="ArgumentNullException"><paramref name="data"/> is <c>null</c>.</exception>
    public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string> data, CancellationToken cancellationToken = default)
    {
        if (data is null)
        {
            throw new ArgumentNullException(nameof(data));
        }

        var results = new List<ReadOnlyMemory<float>>(data.Count);
        foreach (var text in data)
        {
            var embeddings = await _embedder.GetEmbeddings(text, cancellationToken).ConfigureAwait(false);
            results.Add(new ReadOnlyMemory<float>(embeddings));
        }

        return results;
    }

    /// <inheritdoc/>
    public async Task<Embedding> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default)
    {
        var embeddings = await _embedder.GetEmbeddings(text, cancellationToken).ConfigureAwait(false);
        return new Embedding(embeddings);
    }

    /// <inheritdoc/>
    /// <remarks>Tokenizes with the embedder's context, parsing special tokens.</remarks>
    public int CountTokens(string text) => _embedder.Context.Tokenize(text, special: true).Length;
}
|
|
}
|