From 00df7c151628f42b1412284838635c27cf9957a4 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Sat, 27 Apr 2024 02:52:41 +0100 Subject: [PATCH 1/3] - Added `LLamaWeights.LoadFromFileAsync`. - Async loading supports cancellation through a `CancellationToken`. If loading is cancelled an `OperationCanceledException` is thrown. If it fails for another reason a `LoadWeightsFailedException` is thrown. - Updated examples to use `LoadFromFileAsync` --- .../Examples/BatchedExecutorFork.cs | 2 +- .../Examples/BatchedExecutorGuidance.cs | 2 +- .../Examples/BatchedExecutorRewind.cs | 2 +- .../Examples/BatchedExecutorSaveAndLoad.cs | 2 +- LLama.Examples/Examples/ChatChineseGB2312.cs | 2 +- .../Examples/ChatSessionStripRoleName.cs | 4 +- .../Examples/ChatSessionWithHistory.cs | 2 +- .../Examples/ChatSessionWithRestart.cs | 4 +- .../Examples/ChatSessionWithRoleName.cs | 4 +- LLama.Examples/Examples/CodingAssistant.cs | 2 +- .../Examples/GrammarJsonResponse.cs | 4 +- .../Examples/InstructModeExecute.cs | 4 +- .../Examples/InteractiveModeExecute.cs | 2 +- .../Examples/LlavaInteractiveModeExecute.cs | 2 +- LLama.Examples/Examples/LoadAndSaveSession.cs | 2 +- LLama.Examples/Examples/LoadAndSaveState.cs | 2 +- LLama.Examples/Examples/SemanticKernelChat.cs | 2 +- .../Examples/SemanticKernelMemory.cs | 2 +- .../Examples/SemanticKernelPrompt.cs | 2 +- .../Examples/StatelessModeExecute.cs | 2 +- LLama.Examples/Examples/TalkToYourself.cs | 2 +- LLama/LLamaWeights.cs | 80 +++++++++++++++++++ LLama/Native/LLamaContextParams.cs | 2 + LLama/Native/LLamaModelParams.cs | 2 +- LLama/Native/SafeLlamaModelHandle.cs | 7 +- 25 files changed, 114 insertions(+), 29 deletions(-) diff --git a/LLama.Examples/Examples/BatchedExecutorFork.cs b/LLama.Examples/Examples/BatchedExecutorFork.cs index febba5c3..2c401822 100644 --- a/LLama.Examples/Examples/BatchedExecutorFork.cs +++ b/LLama.Examples/Examples/BatchedExecutorFork.cs @@ -19,7 +19,7 @@ public class BatchedExecutorFork string modelPath = UserSettings.GetModelPath(); var parameters = new ModelParams(modelPath); - using var model = LLamaWeights.LoadFromFile(parameters); + using var model = await LLamaWeights.LoadFromFileAsync(parameters); var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that"); diff --git a/LLama.Examples/Examples/BatchedExecutorGuidance.cs b/LLama.Examples/Examples/BatchedExecutorGuidance.cs index 6f3eceab..b006c88b 100644 --- a/LLama.Examples/Examples/BatchedExecutorGuidance.cs +++ b/LLama.Examples/Examples/BatchedExecutorGuidance.cs @@ -19,7 +19,7 @@ public class BatchedExecutorGuidance string modelPath = UserSettings.GetModelPath(); var parameters = new ModelParams(modelPath); - using var model = LLamaWeights.LoadFromFile(parameters); + using var model = await LLamaWeights.LoadFromFileAsync(parameters); var positivePrompt = AnsiConsole.Ask("Positive Prompt (or ENTER for default):", "My favourite colour is").Trim(); var negativePrompt = AnsiConsole.Ask("Negative Prompt (or ENTER for default):", "I hate the colour red. My favourite colour is").Trim(); diff --git a/LLama.Examples/Examples/BatchedExecutorRewind.cs b/LLama.Examples/Examples/BatchedExecutorRewind.cs index 8aae92af..938b3106 100644 --- a/LLama.Examples/Examples/BatchedExecutorRewind.cs +++ b/LLama.Examples/Examples/BatchedExecutorRewind.cs @@ -20,7 +20,7 @@ public class BatchedExecutorRewind string modelPath = UserSettings.GetModelPath(); var parameters = new ModelParams(modelPath); - using var model = LLamaWeights.LoadFromFile(parameters); + using var model = await LLamaWeights.LoadFromFileAsync(parameters); var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that"); diff --git a/LLama.Examples/Examples/BatchedExecutorSaveAndLoad.cs b/LLama.Examples/Examples/BatchedExecutorSaveAndLoad.cs index af0dea52..0ec903eb 100644 --- a/LLama.Examples/Examples/BatchedExecutorSaveAndLoad.cs +++ b/LLama.Examples/Examples/BatchedExecutorSaveAndLoad.cs @@ -18,7 +18,7 @@ public class BatchedExecutorSaveAndLoad string modelPath = UserSettings.GetModelPath(); var parameters = new ModelParams(modelPath); - using var model = LLamaWeights.LoadFromFile(parameters); + using var model = await LLamaWeights.LoadFromFileAsync(parameters); var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that"); diff --git a/LLama.Examples/Examples/ChatChineseGB2312.cs b/LLama.Examples/Examples/ChatChineseGB2312.cs index a5db02cd..c59a522f 100644 --- a/LLama.Examples/Examples/ChatChineseGB2312.cs +++ b/LLama.Examples/Examples/ChatChineseGB2312.cs @@ -31,7 +31,7 @@ public class ChatChineseGB2312 GpuLayerCount = 5, Encoding = Encoding.UTF8 }; - using var model = LLamaWeights.LoadFromFile(parameters); + using var model = await LLamaWeights.LoadFromFileAsync(parameters); using var context = model.CreateContext(parameters); var executor = new InteractiveExecutor(context); diff --git a/LLama.Examples/Examples/ChatSessionStripRoleName.cs b/LLama.Examples/Examples/ChatSessionStripRoleName.cs index ff0b369d..5469aa8f 100644 --- a/LLama.Examples/Examples/ChatSessionStripRoleName.cs +++ b/LLama.Examples/Examples/ChatSessionStripRoleName.cs @@ -15,11 +15,11 @@ public class ChatSessionStripRoleName Seed = 1337, GpuLayerCount = 5 }; - using var model = LLamaWeights.LoadFromFile(parameters); + using var model = await LLamaWeights.LoadFromFileAsync(parameters); using var context = model.CreateContext(parameters); var executor = new InteractiveExecutor(context); - var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); + var chatHistoryJson = await File.ReadAllTextAsync("Assets/chat-with-bob.json"); ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory(); ChatSession session = new(executor, chatHistory); diff --git a/LLama.Examples/Examples/ChatSessionWithHistory.cs b/LLama.Examples/Examples/ChatSessionWithHistory.cs index da5e3ad0..af7d7eac 100644 --- a/LLama.Examples/Examples/ChatSessionWithHistory.cs +++ b/LLama.Examples/Examples/ChatSessionWithHistory.cs @@ -13,7 +13,7 @@ public class ChatSessionWithHistory Seed = 1337, GpuLayerCount = 5 }; - using var model = LLamaWeights.LoadFromFile(parameters); + using var model = await LLamaWeights.LoadFromFileAsync(parameters); using var context = model.CreateContext(parameters); var executor = new InteractiveExecutor(context); diff --git a/LLama.Examples/Examples/ChatSessionWithRestart.cs b/LLama.Examples/Examples/ChatSessionWithRestart.cs index 48754a81..c2bfb895 100644 --- a/LLama.Examples/Examples/ChatSessionWithRestart.cs +++ b/LLama.Examples/Examples/ChatSessionWithRestart.cs @@ -13,11 +13,11 @@ public class ChatSessionWithRestart Seed = 1337, GpuLayerCount = 5 }; - using var model = LLamaWeights.LoadFromFile(parameters); + using var model = await LLamaWeights.LoadFromFileAsync(parameters); using var context = model.CreateContext(parameters); var executor = new InteractiveExecutor(context); - var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); + var chatHistoryJson = await File.ReadAllTextAsync("Assets/chat-with-bob.json"); ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory(); ChatSession prototypeSession = await ChatSession.InitializeSessionFromHistoryAsync(executor, chatHistory); diff --git a/LLama.Examples/Examples/ChatSessionWithRoleName.cs b/LLama.Examples/Examples/ChatSessionWithRoleName.cs index 08f7666b..4e2befd9 100644 --- a/LLama.Examples/Examples/ChatSessionWithRoleName.cs +++ b/LLama.Examples/Examples/ChatSessionWithRoleName.cs @@ -13,11 +13,11 @@ public class ChatSessionWithRoleName Seed = 1337, GpuLayerCount = 5 }; - using var model = LLamaWeights.LoadFromFile(parameters); + using var model = await LLamaWeights.LoadFromFileAsync(parameters); using var context = model.CreateContext(parameters); var executor = new InteractiveExecutor(context); - var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); + var chatHistoryJson = await File.ReadAllTextAsync("Assets/chat-with-bob.json"); ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory(); ChatSession session = new(executor, chatHistory); diff --git a/LLama.Examples/Examples/CodingAssistant.cs b/LLama.Examples/Examples/CodingAssistant.cs index 808c3904..a2edf8be 100644 --- a/LLama.Examples/Examples/CodingAssistant.cs +++ b/LLama.Examples/Examples/CodingAssistant.cs @@ -29,7 +29,7 @@ { ContextSize = 4096 }; - using var model = LLamaWeights.LoadFromFile(parameters); + using var model = await LLamaWeights.LoadFromFileAsync(parameters); using var context = model.CreateContext(parameters); var executor = new InstructExecutor(context, InstructionPrefix, InstructionSuffix, null); diff --git a/LLama.Examples/Examples/GrammarJsonResponse.cs b/LLama.Examples/Examples/GrammarJsonResponse.cs index 139bd4ac..a5bb5486 100644 --- a/LLama.Examples/Examples/GrammarJsonResponse.cs +++ b/LLama.Examples/Examples/GrammarJsonResponse.cs @@ -9,7 +9,7 @@ namespace LLama.Examples.Examples { string modelPath = UserSettings.GetModelPath(); - var gbnf = File.ReadAllText("Assets/json.gbnf").Trim(); + var gbnf = (await File.ReadAllTextAsync("Assets/json.gbnf")).Trim(); var grammar = Grammar.Parse(gbnf, "root"); var parameters = new ModelParams(modelPath) @@ -17,7 +17,7 @@ namespace LLama.Examples.Examples Seed = 1337, GpuLayerCount = 5 }; - using var model = LLamaWeights.LoadFromFile(parameters); + using var model = await LLamaWeights.LoadFromFileAsync(parameters); var ex = new StatelessExecutor(model, parameters); Console.ForegroundColor = ConsoleColor.Yellow; diff --git a/LLama.Examples/Examples/InstructModeExecute.cs b/LLama.Examples/Examples/InstructModeExecute.cs index 73b5da79..4f65dd23 100644 --- a/LLama.Examples/Examples/InstructModeExecute.cs +++ b/LLama.Examples/Examples/InstructModeExecute.cs @@ -9,14 +9,14 @@ namespace LLama.Examples.Examples { string modelPath = UserSettings.GetModelPath(); - var prompt = File.ReadAllText("Assets/dan.txt").Trim(); + var prompt = (await File.ReadAllTextAsync("Assets/dan.txt")).Trim(); var parameters = new ModelParams(modelPath) { Seed = 1337, GpuLayerCount = 5 }; - using var model = LLamaWeights.LoadFromFile(parameters); + using var model = await LLamaWeights.LoadFromFileAsync(parameters); using var context = model.CreateContext(parameters); var executor = new InstructExecutor(context); diff --git a/LLama.Examples/Examples/InteractiveModeExecute.cs b/LLama.Examples/Examples/InteractiveModeExecute.cs index d7d364fb..15a9c94c 100644 --- a/LLama.Examples/Examples/InteractiveModeExecute.cs +++ b/LLama.Examples/Examples/InteractiveModeExecute.cs @@ -16,7 +16,7 @@ namespace LLama.Examples.Examples Seed = 1337, GpuLayerCount = 5 }; - using var model = LLamaWeights.LoadFromFile(parameters); + using var model = await LLamaWeights.LoadFromFileAsync(parameters); using var context = model.CreateContext(parameters); var ex = new InteractiveExecutor(context); diff --git a/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs b/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs index 112fe23f..170bab0c 100644 --- a/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs +++ b/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs @@ -20,7 +20,7 @@ namespace LLama.Examples.Examples var parameters = new ModelParams(modelPath); - using var model = LLamaWeights.LoadFromFile(parameters); + using var model = await LLamaWeights.LoadFromFileAsync(parameters); using var context = model.CreateContext(parameters); // Llava Init diff --git a/LLama.Examples/Examples/LoadAndSaveSession.cs b/LLama.Examples/Examples/LoadAndSaveSession.cs index d8a31017..68ed8aa3 100644 --- a/LLama.Examples/Examples/LoadAndSaveSession.cs +++ b/LLama.Examples/Examples/LoadAndSaveSession.cs @@ -15,7 +15,7 @@ namespace LLama.Examples.Examples Seed = 1337, GpuLayerCount = 5 }; - using var model = LLamaWeights.LoadFromFile(parameters); + using var model = await LLamaWeights.LoadFromFileAsync(parameters); using var context = model.CreateContext(parameters); var ex = new InteractiveExecutor(context); diff --git a/LLama.Examples/Examples/LoadAndSaveState.cs b/LLama.Examples/Examples/LoadAndSaveState.cs index 9cf93e7f..0fef49f1 100644 --- a/LLama.Examples/Examples/LoadAndSaveState.cs +++ b/LLama.Examples/Examples/LoadAndSaveState.cs @@ -16,7 +16,7 @@ namespace LLama.Examples.Examples Seed = 1337, GpuLayerCount = 5 }; - using var model = LLamaWeights.LoadFromFile(parameters); + using var model = await LLamaWeights.LoadFromFileAsync(parameters); using var context = model.CreateContext(parameters); var ex = new InteractiveExecutor(context); diff --git a/LLama.Examples/Examples/SemanticKernelChat.cs b/LLama.Examples/Examples/SemanticKernelChat.cs index 258ca86b..2631cc9b 100644 --- a/LLama.Examples/Examples/SemanticKernelChat.cs +++ b/LLama.Examples/Examples/SemanticKernelChat.cs @@ -16,7 +16,7 @@ namespace LLama.Examples.Examples // Load weights into memory var parameters = new ModelParams(modelPath); - using var model = LLamaWeights.LoadFromFile(parameters); + using var model = await LLamaWeights.LoadFromFileAsync(parameters); var ex = new StatelessExecutor(model, parameters); var chatGPT = new LLamaSharpChatCompletion(ex); diff --git a/LLama.Examples/Examples/SemanticKernelMemory.cs b/LLama.Examples/Examples/SemanticKernelMemory.cs index 46c9a17d..3fad5ae0 100644 --- a/LLama.Examples/Examples/SemanticKernelMemory.cs +++ b/LLama.Examples/Examples/SemanticKernelMemory.cs @@ -23,7 +23,7 @@ namespace LLama.Examples.Examples Embeddings = true }; - using var model = LLamaWeights.LoadFromFile(parameters); + using var model = await LLamaWeights.LoadFromFileAsync(parameters); var embedding = new LLamaEmbedder(model, parameters); Console.WriteLine("===================================================="); diff --git a/LLama.Examples/Examples/SemanticKernelPrompt.cs b/LLama.Examples/Examples/SemanticKernelPrompt.cs index fdf58b3a..63e848cb 100644 --- a/LLama.Examples/Examples/SemanticKernelPrompt.cs +++ b/LLama.Examples/Examples/SemanticKernelPrompt.cs @@ -19,7 +19,7 @@ namespace LLama.Examples.Examples // Load weights into memory var parameters = new ModelParams(modelPath); - using var model = LLamaWeights.LoadFromFile(parameters); + using var model = await LLamaWeights.LoadFromFileAsync(parameters); var ex = new StatelessExecutor(model, parameters); var builder = Kernel.CreateBuilder(); diff --git a/LLama.Examples/Examples/StatelessModeExecute.cs b/LLama.Examples/Examples/StatelessModeExecute.cs index 4d2edd19..806616e7 100644 --- a/LLama.Examples/Examples/StatelessModeExecute.cs +++ b/LLama.Examples/Examples/StatelessModeExecute.cs @@ -14,7 +14,7 @@ namespace LLama.Examples.Examples Seed = 1337, GpuLayerCount = 5 }; - using var model = LLamaWeights.LoadFromFile(parameters); + using var model = await LLamaWeights.LoadFromFileAsync(parameters); var ex = new StatelessExecutor(model, parameters); Console.ForegroundColor = ConsoleColor.Yellow; diff --git a/LLama.Examples/Examples/TalkToYourself.cs b/LLama.Examples/Examples/TalkToYourself.cs index bf72423f..f888209a 100644 --- a/LLama.Examples/Examples/TalkToYourself.cs +++ b/LLama.Examples/Examples/TalkToYourself.cs @@ -12,7 +12,7 @@ namespace LLama.Examples.Examples // Load weights into memory var @params = new ModelParams(modelPath); - using var weights = LLamaWeights.LoadFromFile(@params); + using var weights = await LLamaWeights.LoadFromFileAsync(@params); // Create 2 contexts sharing the same weights using var aliceCtx = weights.CreateContext(@params); diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index 2d8ea4d9..ad041aac 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -1,7 +1,10 @@ using System; using System.Collections.Generic; using System.Text; +using System.Threading; +using System.Threading.Tasks; using LLama.Abstractions; +using LLama.Exceptions; using LLama.Extensions; using LLama.Native; using Microsoft.Extensions.Logging; @@ -84,6 +87,83 @@ namespace LLama return new LLamaWeights(weights); } + /// + /// Load weights into memory + /// + /// Parameters to use to load the model + /// A cancellation token that can interrupt model loading + /// + /// Thrown if weights failed to load for any reason. e.g. Invalid file format or loading cancelled. + /// Thrown if the cancellation token is cancelled. + public static async Task LoadFromFileAsync(IModelParams @params, CancellationToken token = default) + { + // don't touch the @params object inside the task, it might be changed + // externally! Save a copy of everything that we need later. + var modelPath = @params.ModelPath; + var loraBase = @params.LoraBase; + var loraAdapters = @params.LoraAdapters.ToArray(); + + using (@params.ToLlamaModelParams(out var lparams)) + { +#if !NETSTANDARD2_0 + // Overwrite the progress callback with one which polls the cancellation token + //if (token.CanBeCanceled) + { + var internalCallback = lparams.progress_callback; + lparams.progress_callback = (progress, ctx) => + { + // If the user set a call in the model params, first call that and see if we should cancel + if (internalCallback != null && !internalCallback(progress, ctx)) + return false; + + // Check the cancellation token + if (token.IsCancellationRequested) + return false; + + return true; + }; + } +#endif + + var model = await Task.Run(() => + { + try + { + var weights = SafeLlamaModelHandle.LoadFromFile(modelPath, lparams); + foreach (var adapter in loraAdapters) + { + // Interrupt applying LoRAs if the token is cancelled + if (token.IsCancellationRequested) + { + weights.Dispose(); + token.ThrowIfCancellationRequested(); + } + + // Don't apply invalid adapters + if (string.IsNullOrEmpty(adapter.Path)) + continue; + if (adapter.Scale <= 0) + continue; + + weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, loraBase); + } + + return new LLamaWeights(weights); + } + catch (LoadWeightsFailedException) + { + // Convert a LoadWeightsFailedException into a cancellation exception if possible. + token.ThrowIfCancellationRequested(); + + // Ok the weights failed to load for some reason other than cancellation. + throw; + } + }, token); + + return model; + } + } + /// public void Dispose() { diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs index 8e3d7f74..1ea52e6b 100644 --- a/LLama/Native/LLamaContextParams.cs +++ b/LLama/Native/LLamaContextParams.cs @@ -8,6 +8,8 @@ namespace LLama.Native /// /// /// + /// If the provided progress_callback returns true, model loading continues. + /// If it returns false, model loading is immediately aborted. /// llama_progress_callback public delegate bool LlamaProgressCallback(float progress, IntPtr ctx); diff --git a/LLama/Native/LLamaModelParams.cs b/LLama/Native/LLamaModelParams.cs index 923b042c..6fca41fc 100644 --- a/LLama/Native/LLamaModelParams.cs +++ b/LLama/Native/LLamaModelParams.cs @@ -38,7 +38,7 @@ namespace LLama.Native // as NET Framework 4.8 does not play nice with the LlamaProgressCallback type public IntPtr progress_callback; #else - public LlamaProgressCallback progress_callback; + public LlamaProgressCallback? progress_callback; #endif /// diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index 23c1f767..2758c050 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -120,8 +120,11 @@ namespace LLama.Native if (!fs.CanRead) throw new InvalidOperationException($"Model file '{modelPath}' is not readable"); - return llama_load_model_from_file(modelPath, lparams) - ?? throw new LoadWeightsFailedException(modelPath); + var handle = llama_load_model_from_file(modelPath, lparams); + if (handle.IsInvalid) + throw new LoadWeightsFailedException(modelPath); + + return handle; } #region native API From 9867b4c85d4ff242be38fadbc103f64796b479ff Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Sat, 27 Apr 2024 02:55:35 +0100 Subject: [PATCH 2/3] Only setting callback if the token can be cancelled. --- LLama/LLamaWeights.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index ad041aac..e37de1e9 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -107,7 +107,7 @@ namespace LLama { #if !NETSTANDARD2_0 // Overwrite the progress callback with one which polls the cancellation token - //if (token.CanBeCanceled) + if (token.CanBeCanceled) { var internalCallback = lparams.progress_callback; lparams.progress_callback = (progress, ctx) => From 1ec0fee5ba04480b789aaf90440056d33e1e68d2 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Sat, 27 Apr 2024 15:04:54 +0100 Subject: [PATCH 3/3] Added optional `IProgress` parameter to `LoadFromFileAsync` --- LLama/LLamaWeights.cs | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index e37de1e9..ce712b72 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -92,10 +92,11 @@ namespace LLama /// /// Parameters to use to load the model /// A cancellation token that can interrupt model loading + /// Receives progress updates as the model loads (0 to 1) /// /// Thrown if weights failed to load for any reason. e.g. Invalid file format or loading cancelled. /// Thrown if the cancellation token is cancelled. - public static async Task LoadFromFileAsync(IModelParams @params, CancellationToken token = default) + public static async Task LoadFromFileAsync(IModelParams @params, CancellationToken token = default, IProgress? progressReporter = null) { // don't touch the @params object inside the task, it might be changed // externally! Save a copy of everything that we need later. @@ -103,16 +104,25 @@ namespace LLama var loraBase = @params.LoraBase; var loraAdapters = @params.LoraAdapters.ToArray(); + // Determine the range to report for model loading. llama.cpp reports 0-1, but we'll remap that into a + // slightly smaller range to allow some space for reporting LoRA loading too. + var modelLoadProgressRange = 1f; + if (loraAdapters.Length > 0) + modelLoadProgressRange = 0.9f; + using (@params.ToLlamaModelParams(out var lparams)) { #if !NETSTANDARD2_0 - // Overwrite the progress callback with one which polls the cancellation token - if (token.CanBeCanceled) + // Overwrite the progress callback with one which polls the cancellation token and updates the progress object + if (token.CanBeCanceled || progressReporter != null) { var internalCallback = lparams.progress_callback; lparams.progress_callback = (progress, ctx) => { - // If the user set a call in the model params, first call that and see if we should cancel + // Update the progress reporter (remapping the value into the smaller range). + progressReporter?.Report(Math.Clamp(progress, 0, 1) * modelLoadProgressRange); + + // If the user set a callback in the model params, call that and see if we should cancel if (internalCallback != null && !internalCallback(progress, ctx)) return false; @@ -129,8 +139,11 @@ namespace LLama { try { + // Load the model var weights = SafeLlamaModelHandle.LoadFromFile(modelPath, lparams); - foreach (var adapter in loraAdapters) + + // Apply the LoRA adapters + for (var i = 0; i < loraAdapters.Length; i++) { // Interrupt applying LoRAs if the token is cancelled if (token.IsCancellationRequested) @@ -140,14 +153,22 @@ namespace LLama } // Don't apply invalid adapters + var adapter = loraAdapters[i]; if (string.IsNullOrEmpty(adapter.Path)) continue; if (adapter.Scale <= 0) continue; weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, loraBase); + + // Report progress. Model loading reported progress from 0 -> 0.9, use + // the last 0.1 to represent all of the LoRA adapters being applied. + progressReporter?.Report(0.9f + (0.1f / loraAdapters.Length) * (i + 1)); } + // Update progress reporter to indicate completion + progressReporter?.Report(1); + return new LLamaWeights(weights); } catch (LoadWeightsFailedException)