219 lines
8.5 KiB
C#
219 lines
8.5 KiB
C#
using System;
|
|
using System.Collections.Generic;
|
|
using System.Text;
|
|
using System.Threading;
|
|
using System.Threading.Tasks;
|
|
using LLama.Abstractions;
|
|
using LLama.Exceptions;
|
|
using LLama.Extensions;
|
|
using LLama.Native;
|
|
using Microsoft.Extensions.Logging;
|
|
|
|
namespace LLama
|
|
{
|
|
/// <summary>
/// A set of model weights, loaded into memory.
/// </summary>
public sealed class LLamaWeights
    : IDisposable
{
    /// <summary>
    /// The native handle, which is used in the native APIs
    /// </summary>
    /// <remarks>Be careful how you use this!</remarks>
    public SafeLlamaModelHandle NativeHandle { get; }

    /// <summary>
    /// Total number of tokens in vocabulary of this model
    /// </summary>
    public int VocabCount => NativeHandle.VocabCount;

    /// <summary>
    /// Total number of tokens in the context
    /// </summary>
    public int ContextSize => NativeHandle.ContextSize;

    /// <summary>
    /// Get the size of this model in bytes
    /// </summary>
    public ulong SizeInBytes => NativeHandle.SizeInBytes;

    /// <summary>
    /// Get the number of parameters in this model
    /// </summary>
    public ulong ParameterCount => NativeHandle.ParameterCount;

    /// <summary>
    /// Dimension of embedding vectors
    /// </summary>
    public int EmbeddingSize => NativeHandle.EmbeddingSize;

    /// <summary>
    /// Get the special tokens of this model
    /// </summary>
    public SafeLlamaModelHandle.ModelTokens Tokens => NativeHandle.Tokens;

    /// <summary>
    /// All metadata key/value pairs of this model, as read from the native handle when the weights were loaded
    /// </summary>
    public IReadOnlyDictionary<string, string> Metadata { get; set; }

    /// <summary>
    /// Wrap an already loaded native model handle and read its metadata.
    /// </summary>
    /// <param name="weights">The native model handle. This object takes ownership of it and disposes it in <see cref="Dispose"/>.</param>
    private LLamaWeights(SafeLlamaModelHandle weights)
    {
        NativeHandle = weights;
        Metadata = weights.ReadMetadata();
    }

    /// <summary>
    /// Load weights into memory and apply every valid LoRA adapter listed in the parameters.
    /// </summary>
    /// <param name="params">Parameters to use to load the model (model path, LoRA adapters and LoRA base)</param>
    /// <returns>The loaded weights. The caller owns the returned object and must dispose it.</returns>
    /// <exception cref="ArgumentNullException">Thrown if <paramref name="params"/> is null.</exception>
    public static LLamaWeights LoadFromFile(IModelParams @params)
    {
        if (@params is null)
            throw new ArgumentNullException(nameof(@params));

        using var pin = @params.ToLlamaModelParams(out var lparams);
        var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);

        try
        {
            foreach (var adapter in @params.LoraAdapters)
            {
                // Don't apply invalid adapters
                if (string.IsNullOrEmpty(adapter.Path))
                    continue;
                if (adapter.Scale <= 0)
                    continue;

                weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase);
            }

            return new LLamaWeights(weights);
        }
        catch
        {
            // The model is already loaded at this point. Release it deterministically instead of
            // leaving the native memory to the finalizer if a LoRA or the metadata read fails.
            weights.Dispose();
            throw;
        }
    }

    /// <summary>
    /// Load weights into memory on a background task, optionally reporting progress and supporting cancellation.
    /// </summary>
    /// <param name="params">Parameters to use to load the model</param>
    /// <param name="token">A cancellation token that can interrupt model loading</param>
    /// <param name="progressReporter">Receives progress updates as the model loads (0 to 1)</param>
    /// <returns>The loaded weights. The caller owns the returned object and must dispose it.</returns>
    /// <exception cref="ArgumentNullException">Thrown if <paramref name="params"/> is null.</exception>
    /// <exception cref="LoadWeightsFailedException">Thrown if weights failed to load for a reason other than cancellation. e.g. Invalid file format.</exception>
    /// <exception cref="OperationCanceledException">Thrown if the cancellation token is cancelled.</exception>
    public static async Task<LLamaWeights> LoadFromFileAsync(IModelParams @params, CancellationToken token = default, IProgress<float>? progressReporter = null)
    {
        if (@params is null)
            throw new ArgumentNullException(nameof(@params));

        // don't touch the @params object inside the task, it might be changed
        // externally! Save a copy of everything that we need later.
        var modelPath = @params.ModelPath;
        var loraBase = @params.LoraBase;
        var loraAdapters = @params.LoraAdapters.ToArray();

        // Determine the range to report for model loading. llama.cpp reports 0-1, but we'll remap that into a
        // slightly smaller range to allow some space for reporting LoRA loading too. The LoRA share is derived
        // from the model share so that the two can never drift apart.
        var modelLoadProgressRange = loraAdapters.Length > 0 ? 0.9f : 1f;
        var loraProgressRange = 1f - modelLoadProgressRange;

        using (@params.ToLlamaModelParams(out var lparams))
        {
#if !NETSTANDARD2_0
            // Overwrite the progress callback with one which polls the cancellation token and updates the progress object
            if (token.CanBeCanceled || progressReporter != null)
            {
                var internalCallback = lparams.progress_callback;
                lparams.progress_callback = (progress, ctx) =>
                {
                    // Update the progress reporter (remapping the value into the smaller range).
                    progressReporter?.Report(Math.Clamp(progress, 0, 1) * modelLoadProgressRange);

                    // If the user set a callback in the model params, call that and see if we should cancel
                    if (internalCallback != null && !internalCallback(progress, ctx))
                        return false;

                    // Check the cancellation token
                    if (token.IsCancellationRequested)
                        return false;

                    return true;
                };
            }
#endif

            // Library code: there is no need to resume on the caller's synchronization context.
            return await Task.Run(() =>
            {
                try
                {
                    // Load the model
                    var weights = SafeLlamaModelHandle.LoadFromFile(modelPath, lparams);

                    try
                    {
                        // Apply the LoRA adapters
                        for (var i = 0; i < loraAdapters.Length; i++)
                        {
                            // Interrupt applying LoRAs if the token is cancelled
                            // (the handle is released by the catch block below).
                            token.ThrowIfCancellationRequested();

                            // Don't apply invalid adapters
                            var adapter = loraAdapters[i];
                            if (string.IsNullOrEmpty(adapter.Path))
                                continue;
                            if (adapter.Scale <= 0)
                                continue;

                            weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, loraBase);

                            // Report progress. Model loading used the first part of the range, the
                            // remainder is shared equally between all of the LoRA adapters.
                            progressReporter?.Report(modelLoadProgressRange + (loraProgressRange / loraAdapters.Length) * (i + 1));
                        }

                        // Update progress reporter to indicate completion
                        progressReporter?.Report(1);

                        return new LLamaWeights(weights);
                    }
                    catch
                    {
                        // Cancellation, a failed LoRA, a throwing progress reporter or a failed metadata
                        // read: in every case the loaded model must not be leaked.
                        weights.Dispose();
                        throw;
                    }
                }
                catch (LoadWeightsFailedException)
                {
                    // Convert a LoadWeightsFailedException into a cancellation exception if possible.
                    token.ThrowIfCancellationRequested();

                    // Ok the weights failed to load for some reason other than cancellation.
                    throw;
                }
            }, token).ConfigureAwait(false);
        }
    }

    /// <inheritdoc />
    public void Dispose()
    {
        NativeHandle.Dispose();
    }

    /// <summary>
    /// Create a llama_context using this model
    /// </summary>
    /// <param name="params">Parameters for the new context</param>
    /// <param name="logger">Optional logger handed to the new context</param>
    /// <returns>A new <see cref="LLamaContext"/> constructed from these weights</returns>
    public LLamaContext CreateContext(IContextParams @params, ILogger? logger = null)
    {
        return new LLamaContext(this, @params, logger);
    }

    /// <summary>
    /// Convert a string of text into tokens
    /// </summary>
    /// <param name="text">The text to tokenize</param>
    /// <param name="add_bos">Forwarded unchanged to the native handle's tokenizer</param>
    /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
    /// <param name="encoding">Text encoding, forwarded unchanged to the native handle's tokenizer</param>
    /// <returns>The tokens produced by the native handle's tokenizer</returns>
    public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
    {
        return NativeHandle.Tokenize(text, add_bos, special, encoding);
    }
}
|
|
}
|