Merge pull request #702 from martindevans/interruptible_async_model_load
Interruptible Async Model Loading With Progress Monitoring
This commit is contained in:
commit
84bb5a36ab
|
@ -19,7 +19,7 @@ public class BatchedExecutorFork
|
|||
string modelPath = UserSettings.GetModelPath();
|
||||
|
||||
var parameters = new ModelParams(modelPath);
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
|
||||
var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ public class BatchedExecutorGuidance
|
|||
string modelPath = UserSettings.GetModelPath();
|
||||
|
||||
var parameters = new ModelParams(modelPath);
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
|
||||
var positivePrompt = AnsiConsole.Ask("Positive Prompt (or ENTER for default):", "My favourite colour is").Trim();
|
||||
var negativePrompt = AnsiConsole.Ask("Negative Prompt (or ENTER for default):", "I hate the colour red. My favourite colour is").Trim();
|
||||
|
|
|
@ -20,7 +20,7 @@ public class BatchedExecutorRewind
|
|||
string modelPath = UserSettings.GetModelPath();
|
||||
|
||||
var parameters = new ModelParams(modelPath);
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
|
||||
var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ public class BatchedExecutorSaveAndLoad
|
|||
string modelPath = UserSettings.GetModelPath();
|
||||
|
||||
var parameters = new ModelParams(modelPath);
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
|
||||
var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ public class ChatChineseGB2312
|
|||
GpuLayerCount = 5,
|
||||
Encoding = Encoding.UTF8
|
||||
};
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
using var context = model.CreateContext(parameters);
|
||||
var executor = new InteractiveExecutor(context);
|
||||
|
||||
|
|
|
@ -15,11 +15,11 @@ public class ChatSessionStripRoleName
|
|||
Seed = 1337,
|
||||
GpuLayerCount = 5
|
||||
};
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
using var context = model.CreateContext(parameters);
|
||||
var executor = new InteractiveExecutor(context);
|
||||
|
||||
var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
|
||||
var chatHistoryJson = await File.ReadAllTextAsync("Assets/chat-with-bob.json");
|
||||
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
|
||||
|
||||
ChatSession session = new(executor, chatHistory);
|
||||
|
|
|
@ -13,7 +13,7 @@ public class ChatSessionWithHistory
|
|||
Seed = 1337,
|
||||
GpuLayerCount = 5
|
||||
};
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
using var context = model.CreateContext(parameters);
|
||||
var executor = new InteractiveExecutor(context);
|
||||
|
||||
|
|
|
@ -13,11 +13,11 @@ public class ChatSessionWithRestart
|
|||
Seed = 1337,
|
||||
GpuLayerCount = 5
|
||||
};
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
using var context = model.CreateContext(parameters);
|
||||
var executor = new InteractiveExecutor(context);
|
||||
|
||||
var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
|
||||
var chatHistoryJson = await File.ReadAllTextAsync("Assets/chat-with-bob.json");
|
||||
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
|
||||
ChatSession prototypeSession =
|
||||
await ChatSession.InitializeSessionFromHistoryAsync(executor, chatHistory);
|
||||
|
|
|
@ -13,11 +13,11 @@ public class ChatSessionWithRoleName
|
|||
Seed = 1337,
|
||||
GpuLayerCount = 5
|
||||
};
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
using var context = model.CreateContext(parameters);
|
||||
var executor = new InteractiveExecutor(context);
|
||||
|
||||
var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
|
||||
var chatHistoryJson = await File.ReadAllTextAsync("Assets/chat-with-bob.json");
|
||||
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
|
||||
|
||||
ChatSession session = new(executor, chatHistory);
|
||||
|
|
|
@ -29,7 +29,7 @@
|
|||
{
|
||||
ContextSize = 4096
|
||||
};
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
using var context = model.CreateContext(parameters);
|
||||
var executor = new InstructExecutor(context, InstructionPrefix, InstructionSuffix, null);
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ namespace LLama.Examples.Examples
|
|||
{
|
||||
string modelPath = UserSettings.GetModelPath();
|
||||
|
||||
var gbnf = File.ReadAllText("Assets/json.gbnf").Trim();
|
||||
var gbnf = (await File.ReadAllTextAsync("Assets/json.gbnf")).Trim();
|
||||
var grammar = Grammar.Parse(gbnf, "root");
|
||||
|
||||
var parameters = new ModelParams(modelPath)
|
||||
|
@ -17,7 +17,7 @@ namespace LLama.Examples.Examples
|
|||
Seed = 1337,
|
||||
GpuLayerCount = 5
|
||||
};
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
var ex = new StatelessExecutor(model, parameters);
|
||||
|
||||
Console.ForegroundColor = ConsoleColor.Yellow;
|
||||
|
|
|
@ -9,14 +9,14 @@ namespace LLama.Examples.Examples
|
|||
{
|
||||
string modelPath = UserSettings.GetModelPath();
|
||||
|
||||
var prompt = File.ReadAllText("Assets/dan.txt").Trim();
|
||||
var prompt = (await File.ReadAllTextAsync("Assets/dan.txt")).Trim();
|
||||
|
||||
var parameters = new ModelParams(modelPath)
|
||||
{
|
||||
Seed = 1337,
|
||||
GpuLayerCount = 5
|
||||
};
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
using var context = model.CreateContext(parameters);
|
||||
var executor = new InstructExecutor(context);
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ namespace LLama.Examples.Examples
|
|||
Seed = 1337,
|
||||
GpuLayerCount = 5
|
||||
};
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
using var context = model.CreateContext(parameters);
|
||||
var ex = new InteractiveExecutor(context);
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ namespace LLama.Examples.Examples
|
|||
|
||||
var parameters = new ModelParams(modelPath);
|
||||
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
using var context = model.CreateContext(parameters);
|
||||
|
||||
// Llava Init
|
||||
|
|
|
@ -15,7 +15,7 @@ namespace LLama.Examples.Examples
|
|||
Seed = 1337,
|
||||
GpuLayerCount = 5
|
||||
};
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
using var context = model.CreateContext(parameters);
|
||||
var ex = new InteractiveExecutor(context);
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ namespace LLama.Examples.Examples
|
|||
Seed = 1337,
|
||||
GpuLayerCount = 5
|
||||
};
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
using var context = model.CreateContext(parameters);
|
||||
var ex = new InteractiveExecutor(context);
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ namespace LLama.Examples.Examples
|
|||
|
||||
// Load weights into memory
|
||||
var parameters = new ModelParams(modelPath);
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
var ex = new StatelessExecutor(model, parameters);
|
||||
|
||||
var chatGPT = new LLamaSharpChatCompletion(ex);
|
||||
|
|
|
@ -23,7 +23,7 @@ namespace LLama.Examples.Examples
|
|||
Embeddings = true
|
||||
};
|
||||
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
var embedding = new LLamaEmbedder(model, parameters);
|
||||
|
||||
Console.WriteLine("====================================================");
|
||||
|
|
|
@ -19,7 +19,7 @@ namespace LLama.Examples.Examples
|
|||
|
||||
// Load weights into memory
|
||||
var parameters = new ModelParams(modelPath);
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
var ex = new StatelessExecutor(model, parameters);
|
||||
|
||||
var builder = Kernel.CreateBuilder();
|
||||
|
|
|
@ -14,7 +14,7 @@ namespace LLama.Examples.Examples
|
|||
Seed = 1337,
|
||||
GpuLayerCount = 5
|
||||
};
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
var ex = new StatelessExecutor(model, parameters);
|
||||
|
||||
Console.ForegroundColor = ConsoleColor.Yellow;
|
||||
|
|
|
@ -12,7 +12,7 @@ namespace LLama.Examples.Examples
|
|||
|
||||
// Load weights into memory
|
||||
var @params = new ModelParams(modelPath);
|
||||
using var weights = LLamaWeights.LoadFromFile(@params);
|
||||
using var weights = await LLamaWeights.LoadFromFileAsync(@params);
|
||||
|
||||
// Create 2 contexts sharing the same weights
|
||||
using var aliceCtx = weights.CreateContext(@params);
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Text;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using LLama.Abstractions;
|
||||
using LLama.Exceptions;
|
||||
using LLama.Extensions;
|
||||
using LLama.Native;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
@ -84,6 +87,104 @@ namespace LLama
|
|||
return new LLamaWeights(weights);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Load weights into memory
|
||||
/// </summary>
|
||||
/// <param name="params">Parameters to use to load the model</param>
|
||||
/// <param name="token">A cancellation token that can interrupt model loading</param>
|
||||
/// <param name="progressReporter">Receives progress updates as the model loads (0 to 1)</param>
|
||||
/// <returns></returns>
|
||||
/// <exception cref="LoadWeightsFailedException">Thrown if weights failed to load for any reason. e.g. Invalid file format or loading cancelled.</exception>
|
||||
/// <exception cref="OperationCanceledException">Thrown if the cancellation token is cancelled.</exception>
|
||||
public static async Task<LLamaWeights> LoadFromFileAsync(IModelParams @params, CancellationToken token = default, IProgress<float>? progressReporter = null)
|
||||
{
|
||||
// don't touch the @params object inside the task, it might be changed
|
||||
// externally! Save a copy of everything that we need later.
|
||||
var modelPath = @params.ModelPath;
|
||||
var loraBase = @params.LoraBase;
|
||||
var loraAdapters = @params.LoraAdapters.ToArray();
|
||||
|
||||
// Determine the range to report for model loading. llama.cpp reports 0-1, but we'll remap that into a
|
||||
// slightly smaller range to allow some space for reporting LoRA loading too.
|
||||
var modelLoadProgressRange = 1f;
|
||||
if (loraAdapters.Length > 0)
|
||||
modelLoadProgressRange = 0.9f;
|
||||
|
||||
using (@params.ToLlamaModelParams(out var lparams))
|
||||
{
|
||||
#if !NETSTANDARD2_0
|
||||
// Overwrite the progress callback with one which polls the cancellation token and updates the progress object
|
||||
if (token.CanBeCanceled || progressReporter != null)
|
||||
{
|
||||
var internalCallback = lparams.progress_callback;
|
||||
lparams.progress_callback = (progress, ctx) =>
|
||||
{
|
||||
// Update the progress reporter (remapping the value into the smaller range).
|
||||
progressReporter?.Report(Math.Clamp(progress, 0, 1) * modelLoadProgressRange);
|
||||
|
||||
// If the user set a callback in the model params, call that and see if we should cancel
|
||||
if (internalCallback != null && !internalCallback(progress, ctx))
|
||||
return false;
|
||||
|
||||
// Check the cancellation token
|
||||
if (token.IsCancellationRequested)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
};
|
||||
}
|
||||
#endif
|
||||
|
||||
var model = await Task.Run(() =>
|
||||
{
|
||||
try
|
||||
{
|
||||
// Load the model
|
||||
var weights = SafeLlamaModelHandle.LoadFromFile(modelPath, lparams);
|
||||
|
||||
// Apply the LoRA adapters
|
||||
for (var i = 0; i < loraAdapters.Length; i++)
|
||||
{
|
||||
// Interrupt applying LoRAs if the token is cancelled
|
||||
if (token.IsCancellationRequested)
|
||||
{
|
||||
weights.Dispose();
|
||||
token.ThrowIfCancellationRequested();
|
||||
}
|
||||
|
||||
// Don't apply invalid adapters
|
||||
var adapter = loraAdapters[i];
|
||||
if (string.IsNullOrEmpty(adapter.Path))
|
||||
continue;
|
||||
if (adapter.Scale <= 0)
|
||||
continue;
|
||||
|
||||
weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, loraBase);
|
||||
|
||||
// Report progress. Model loading reported progress from 0 -> 0.9, use
|
||||
// the last 0.1 to represent all of the LoRA adapters being applied.
|
||||
progressReporter?.Report(0.9f + (0.1f / loraAdapters.Length) * (i + 1));
|
||||
}
|
||||
|
||||
// Update progress reporter to indicate completion
|
||||
progressReporter?.Report(1);
|
||||
|
||||
return new LLamaWeights(weights);
|
||||
}
|
||||
catch (LoadWeightsFailedException)
|
||||
{
|
||||
// Convert a LoadWeightsFailedException into a cancellation exception if possible.
|
||||
token.ThrowIfCancellationRequested();
|
||||
|
||||
// Ok the weights failed to load for some reason other than cancellation.
|
||||
throw;
|
||||
}
|
||||
}, token);
|
||||
|
||||
return model;
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public void Dispose()
|
||||
{
|
||||
|
|
|
@ -8,6 +8,8 @@ namespace LLama.Native
|
|||
/// </summary>
|
||||
/// <param name="progress"></param>
|
||||
/// <param name="ctx"></param>
|
||||
/// <returns>If the provided progress_callback returns true, model loading continues.
|
||||
/// If it returns false, model loading is immediately aborted.</returns>
|
||||
/// <remarks>llama_progress_callback</remarks>
|
||||
public delegate bool LlamaProgressCallback(float progress, IntPtr ctx);
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ namespace LLama.Native
|
|||
// as NET Framework 4.8 does not play nice with the LlamaProgressCallback type
|
||||
public IntPtr progress_callback;
|
||||
#else
|
||||
public LlamaProgressCallback progress_callback;
|
||||
public LlamaProgressCallback? progress_callback;
|
||||
#endif
|
||||
|
||||
/// <summary>
|
||||
|
|
|
@ -120,8 +120,11 @@ namespace LLama.Native
|
|||
if (!fs.CanRead)
|
||||
throw new InvalidOperationException($"Model file '{modelPath}' is not readable");
|
||||
|
||||
return llama_load_model_from_file(modelPath, lparams)
|
||||
?? throw new LoadWeightsFailedException(modelPath);
|
||||
var handle = llama_load_model_from_file(modelPath, lparams);
|
||||
if (handle.IsInvalid)
|
||||
throw new LoadWeightsFailedException(modelPath);
|
||||
|
||||
return handle;
|
||||
}
|
||||
|
||||
#region native API
|
||||
|
|
Loading…
Reference in New Issue