Merge pull request #702 from martindevans/interruptible_async_model_load

Interruptible Async Model Loading With Progress Monitoring
Martin Evans 2024-04-27 16:06:40 +01:00 committed by GitHub
commit 84bb5a36ab
25 changed files with 135 additions and 29 deletions


@@ -19,7 +19,7 @@ public class BatchedExecutorFork
     string modelPath = UserSettings.GetModelPath();
     var parameters = new ModelParams(modelPath);
-    using var model = LLamaWeights.LoadFromFile(parameters);
+    using var model = await LLamaWeights.LoadFromFileAsync(parameters);
     var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");


@@ -19,7 +19,7 @@ public class BatchedExecutorGuidance
     string modelPath = UserSettings.GetModelPath();
     var parameters = new ModelParams(modelPath);
-    using var model = LLamaWeights.LoadFromFile(parameters);
+    using var model = await LLamaWeights.LoadFromFileAsync(parameters);
     var positivePrompt = AnsiConsole.Ask("Positive Prompt (or ENTER for default):", "My favourite colour is").Trim();
     var negativePrompt = AnsiConsole.Ask("Negative Prompt (or ENTER for default):", "I hate the colour red. My favourite colour is").Trim();


@@ -20,7 +20,7 @@ public class BatchedExecutorRewind
     string modelPath = UserSettings.GetModelPath();
    var parameters = new ModelParams(modelPath);
-    using var model = LLamaWeights.LoadFromFile(parameters);
+    using var model = await LLamaWeights.LoadFromFileAsync(parameters);
     var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");


@@ -18,7 +18,7 @@ public class BatchedExecutorSaveAndLoad
     string modelPath = UserSettings.GetModelPath();
     var parameters = new ModelParams(modelPath);
-    using var model = LLamaWeights.LoadFromFile(parameters);
+    using var model = await LLamaWeights.LoadFromFileAsync(parameters);
     var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");


@@ -31,7 +31,7 @@ public class ChatChineseGB2312
         GpuLayerCount = 5,
         Encoding = Encoding.UTF8
     };
-    using var model = LLamaWeights.LoadFromFile(parameters);
+    using var model = await LLamaWeights.LoadFromFileAsync(parameters);
     using var context = model.CreateContext(parameters);
     var executor = new InteractiveExecutor(context);


@@ -15,11 +15,11 @@ public class ChatSessionStripRoleName
         Seed = 1337,
         GpuLayerCount = 5
     };
-    using var model = LLamaWeights.LoadFromFile(parameters);
+    using var model = await LLamaWeights.LoadFromFileAsync(parameters);
     using var context = model.CreateContext(parameters);
     var executor = new InteractiveExecutor(context);
-    var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
+    var chatHistoryJson = await File.ReadAllTextAsync("Assets/chat-with-bob.json");
     ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
     ChatSession session = new(executor, chatHistory);


@@ -13,7 +13,7 @@ public class ChatSessionWithHistory
         Seed = 1337,
         GpuLayerCount = 5
     };
-    using var model = LLamaWeights.LoadFromFile(parameters);
+    using var model = await LLamaWeights.LoadFromFileAsync(parameters);
     using var context = model.CreateContext(parameters);
     var executor = new InteractiveExecutor(context);


@@ -13,11 +13,11 @@ public class ChatSessionWithRestart
         Seed = 1337,
         GpuLayerCount = 5
     };
-    using var model = LLamaWeights.LoadFromFile(parameters);
+    using var model = await LLamaWeights.LoadFromFileAsync(parameters);
     using var context = model.CreateContext(parameters);
     var executor = new InteractiveExecutor(context);
-    var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
+    var chatHistoryJson = await File.ReadAllTextAsync("Assets/chat-with-bob.json");
     ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
     ChatSession prototypeSession =
         await ChatSession.InitializeSessionFromHistoryAsync(executor, chatHistory);


@@ -13,11 +13,11 @@ public class ChatSessionWithRoleName
         Seed = 1337,
         GpuLayerCount = 5
     };
-    using var model = LLamaWeights.LoadFromFile(parameters);
+    using var model = await LLamaWeights.LoadFromFileAsync(parameters);
     using var context = model.CreateContext(parameters);
     var executor = new InteractiveExecutor(context);
-    var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
+    var chatHistoryJson = await File.ReadAllTextAsync("Assets/chat-with-bob.json");
     ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
     ChatSession session = new(executor, chatHistory);


@@ -29,7 +29,7 @@
     {
         ContextSize = 4096
     };
-    using var model = LLamaWeights.LoadFromFile(parameters);
+    using var model = await LLamaWeights.LoadFromFileAsync(parameters);
     using var context = model.CreateContext(parameters);
     var executor = new InstructExecutor(context, InstructionPrefix, InstructionSuffix, null);


@@ -9,7 +9,7 @@ namespace LLama.Examples.Examples
     {
         string modelPath = UserSettings.GetModelPath();
-        var gbnf = File.ReadAllText("Assets/json.gbnf").Trim();
+        var gbnf = (await File.ReadAllTextAsync("Assets/json.gbnf")).Trim();
         var grammar = Grammar.Parse(gbnf, "root");
         var parameters = new ModelParams(modelPath)
@@ -17,7 +17,7 @@ namespace LLama.Examples.Examples
             Seed = 1337,
             GpuLayerCount = 5
         };
-        using var model = LLamaWeights.LoadFromFile(parameters);
+        using var model = await LLamaWeights.LoadFromFileAsync(parameters);
         var ex = new StatelessExecutor(model, parameters);
         Console.ForegroundColor = ConsoleColor.Yellow;


@@ -9,14 +9,14 @@ namespace LLama.Examples.Examples
     {
         string modelPath = UserSettings.GetModelPath();
-        var prompt = File.ReadAllText("Assets/dan.txt").Trim();
+        var prompt = (await File.ReadAllTextAsync("Assets/dan.txt")).Trim();
         var parameters = new ModelParams(modelPath)
         {
             Seed = 1337,
             GpuLayerCount = 5
         };
-        using var model = LLamaWeights.LoadFromFile(parameters);
+        using var model = await LLamaWeights.LoadFromFileAsync(parameters);
         using var context = model.CreateContext(parameters);
         var executor = new InstructExecutor(context);


@@ -16,7 +16,7 @@ namespace LLama.Examples.Examples
         Seed = 1337,
         GpuLayerCount = 5
     };
-    using var model = LLamaWeights.LoadFromFile(parameters);
+    using var model = await LLamaWeights.LoadFromFileAsync(parameters);
     using var context = model.CreateContext(parameters);
     var ex = new InteractiveExecutor(context);


@@ -20,7 +20,7 @@ namespace LLama.Examples.Examples
     var parameters = new ModelParams(modelPath);
-    using var model = LLamaWeights.LoadFromFile(parameters);
+    using var model = await LLamaWeights.LoadFromFileAsync(parameters);
     using var context = model.CreateContext(parameters);
 
     // Llava Init


@@ -15,7 +15,7 @@ namespace LLama.Examples.Examples
         Seed = 1337,
         GpuLayerCount = 5
     };
-    using var model = LLamaWeights.LoadFromFile(parameters);
+    using var model = await LLamaWeights.LoadFromFileAsync(parameters);
     using var context = model.CreateContext(parameters);
     var ex = new InteractiveExecutor(context);


@@ -16,7 +16,7 @@ namespace LLama.Examples.Examples
         Seed = 1337,
         GpuLayerCount = 5
     };
-    using var model = LLamaWeights.LoadFromFile(parameters);
+    using var model = await LLamaWeights.LoadFromFileAsync(parameters);
     using var context = model.CreateContext(parameters);
     var ex = new InteractiveExecutor(context);


@@ -16,7 +16,7 @@ namespace LLama.Examples.Examples
     // Load weights into memory
     var parameters = new ModelParams(modelPath);
-    using var model = LLamaWeights.LoadFromFile(parameters);
+    using var model = await LLamaWeights.LoadFromFileAsync(parameters);
     var ex = new StatelessExecutor(model, parameters);
     var chatGPT = new LLamaSharpChatCompletion(ex);


@@ -23,7 +23,7 @@ namespace LLama.Examples.Examples
         Embeddings = true
     };
-    using var model = LLamaWeights.LoadFromFile(parameters);
+    using var model = await LLamaWeights.LoadFromFileAsync(parameters);
     var embedding = new LLamaEmbedder(model, parameters);
     Console.WriteLine("====================================================");


@@ -19,7 +19,7 @@ namespace LLama.Examples.Examples
     // Load weights into memory
     var parameters = new ModelParams(modelPath);
-    using var model = LLamaWeights.LoadFromFile(parameters);
+    using var model = await LLamaWeights.LoadFromFileAsync(parameters);
     var ex = new StatelessExecutor(model, parameters);
     var builder = Kernel.CreateBuilder();


@@ -14,7 +14,7 @@ namespace LLama.Examples.Examples
         Seed = 1337,
         GpuLayerCount = 5
     };
-    using var model = LLamaWeights.LoadFromFile(parameters);
+    using var model = await LLamaWeights.LoadFromFileAsync(parameters);
     var ex = new StatelessExecutor(model, parameters);
     Console.ForegroundColor = ConsoleColor.Yellow;


@@ -12,7 +12,7 @@ namespace LLama.Examples.Examples
     // Load weights into memory
     var @params = new ModelParams(modelPath);
-    using var weights = LLamaWeights.LoadFromFile(@params);
+    using var weights = await LLamaWeights.LoadFromFileAsync(@params);
 
     // Create 2 contexts sharing the same weights
     using var aliceCtx = weights.CreateContext(@params);


@@ -1,7 +1,10 @@
 using System;
 using System.Collections.Generic;
 using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
 using LLama.Abstractions;
+using LLama.Exceptions;
 using LLama.Extensions;
 using LLama.Native;
 using Microsoft.Extensions.Logging;
@@ -84,6 +87,104 @@ namespace LLama
         return new LLamaWeights(weights);
     }
 
+    /// <summary>
+    /// Load weights into memory
+    /// </summary>
+    /// <param name="params">Parameters to use to load the model</param>
+    /// <param name="token">A cancellation token that can interrupt model loading</param>
+    /// <param name="progressReporter">Receives progress updates as the model loads (0 to 1)</param>
+    /// <returns></returns>
+    /// <exception cref="LoadWeightsFailedException">Thrown if weights failed to load for any reason, e.g. an invalid file format or the load being cancelled.</exception>
+    /// <exception cref="OperationCanceledException">Thrown if the cancellation token is cancelled.</exception>
+    public static async Task<LLamaWeights> LoadFromFileAsync(IModelParams @params, CancellationToken token = default, IProgress<float>? progressReporter = null)
+    {
+        // Don't touch the @params object inside the task, it might be changed
+        // externally! Save a copy of everything that we need later.
+        var modelPath = @params.ModelPath;
+        var loraBase = @params.LoraBase;
+        var loraAdapters = @params.LoraAdapters.ToArray();
+
+        // Determine the range to report for model loading. llama.cpp reports 0-1, but we'll remap that into a
+        // slightly smaller range to allow some space for reporting LoRA loading too.
+        var modelLoadProgressRange = 1f;
+        if (loraAdapters.Length > 0)
+            modelLoadProgressRange = 0.9f;
+
+        using (@params.ToLlamaModelParams(out var lparams))
+        {
+#if !NETSTANDARD2_0
+            // Overwrite the progress callback with one which polls the cancellation token and updates the progress object
+            if (token.CanBeCanceled || progressReporter != null)
+            {
+                var internalCallback = lparams.progress_callback;
+                lparams.progress_callback = (progress, ctx) =>
+                {
+                    // Update the progress reporter (remapping the value into the smaller range).
+                    progressReporter?.Report(Math.Clamp(progress, 0, 1) * modelLoadProgressRange);
+
+                    // If the user set a callback in the model params, call that and see if we should cancel
+                    if (internalCallback != null && !internalCallback(progress, ctx))
+                        return false;
+
+                    // Check the cancellation token
+                    if (token.IsCancellationRequested)
+                        return false;
+
+                    return true;
+                };
+            }
+#endif
+
+            var model = await Task.Run(() =>
+            {
+                try
+                {
+                    // Load the model
+                    var weights = SafeLlamaModelHandle.LoadFromFile(modelPath, lparams);
+
+                    // Apply the LoRA adapters
+                    for (var i = 0; i < loraAdapters.Length; i++)
+                    {
+                        // Interrupt applying LoRAs if the token is cancelled
+                        if (token.IsCancellationRequested)
+                        {
+                            weights.Dispose();
+                            token.ThrowIfCancellationRequested();
+                        }
+
+                        // Don't apply invalid adapters
+                        var adapter = loraAdapters[i];
+                        if (string.IsNullOrEmpty(adapter.Path))
+                            continue;
+                        if (adapter.Scale <= 0)
+                            continue;
+
+                        weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, loraBase);
+
+                        // Report progress. Model loading reported progress from 0 -> 0.9, use
+                        // the last 0.1 to represent all of the LoRA adapters being applied.
+                        progressReporter?.Report(0.9f + (0.1f / loraAdapters.Length) * (i + 1));
+                    }
+
+                    // Update progress reporter to indicate completion
+                    progressReporter?.Report(1);
+
+                    return new LLamaWeights(weights);
+                }
+                catch (LoadWeightsFailedException)
+                {
+                    // Convert a LoadWeightsFailedException into a cancellation exception if possible.
+                    token.ThrowIfCancellationRequested();
+
+                    // Ok the weights failed to load for some reason other than cancellation.
+                    throw;
+                }
+            }, token);
+
+            return model;
+        }
+    }
+
     /// <inheritdoc />
     public void Dispose()
     {
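
Taken together, the new overload can be driven from application code roughly as follows. This is a minimal sketch rather than code from the commit: the model path is a placeholder and the 30-second timeout is arbitrary.

using System;
using System.Threading;
using LLama;
using LLama.Common;
using LLama.Exceptions;

// Hypothetical caller: load with a timeout, console progress reporting,
// and distinct handling for cancellation vs. load failure.
var parameters = new ModelParams("path/to/model.gguf"); // placeholder path

using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));
var progress = new Progress<float>(p => Console.WriteLine($"Loading: {p:P0}"));

try
{
    // Loading runs on a background task; cancelling the token aborts it mid-load.
    using var model = await LLamaWeights.LoadFromFileAsync(parameters, cts.Token, progress);
    Console.WriteLine("Model loaded.");
}
catch (OperationCanceledException)
{
    Console.WriteLine("Model loading was cancelled.");
}
catch (LoadWeightsFailedException)
{
    Console.WriteLine("Weights failed to load (e.g. invalid file format).");
}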


@@ -8,6 +8,8 @@ namespace LLama.Native
     /// </summary>
     /// <param name="progress"></param>
     /// <param name="ctx"></param>
+    /// <returns>If the provided progress_callback returns true, model loading continues.
+    /// If it returns false, model loading is immediately aborted.</returns>
     /// <remarks>llama_progress_callback</remarks>
     public delegate bool LlamaProgressCallback(float progress, IntPtr ctx);
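
For illustration (not part of the commit), a callback matching this delegate simply receives the native loader's 0-1 progress and decides whether to continue; this is the same hook LoadFromFileAsync wraps to honour its CancellationToken.

using System;
using LLama.Native;

// Hypothetical callback: log native load progress and abort past 50%.
LlamaProgressCallback callback = (progress, ctx) =>
{
    Console.WriteLine($"load progress: {progress:P0}");
    return progress < 0.5f; // returning false aborts the load
};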


@@ -38,7 +38,7 @@ namespace LLama.Native
     // as NET Framework 4.8 does not play nice with the LlamaProgressCallback type
     public IntPtr progress_callback;
 #else
-    public LlamaProgressCallback progress_callback;
+    public LlamaProgressCallback? progress_callback;
 #endif
 
     /// <summary>


@@ -120,8 +120,11 @@ namespace LLama.Native
     if (!fs.CanRead)
         throw new InvalidOperationException($"Model file '{modelPath}' is not readable");
 
-    return llama_load_model_from_file(modelPath, lparams)
-        ?? throw new LoadWeightsFailedException(modelPath);
+    var handle = llama_load_model_from_file(modelPath, lparams);
+    if (handle.IsInvalid)
+        throw new LoadWeightsFailedException(modelPath);
+    return handle;
 }
 
 #region native API