- Added `LLamaWeights.LoadFromFileAsync`.
- Async loading supports cancellation through a `CancellationToken`. If loading is cancelled an `OperationCanceledException` is thrown; if it fails for another reason a `LoadWeightsFailedException` is thrown.
- Updated examples to use `LoadFromFileAsync`.
This commit is contained in:
parent
b47ed9258f
commit
00df7c1516
|
@@ -19,7 +19,7 @@ public class BatchedExecutorFork
|
|||
string modelPath = UserSettings.GetModelPath();
|
||||
|
||||
var parameters = new ModelParams(modelPath);
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
|
||||
var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");
|
||||
|
||||
|
|
|
@@ -19,7 +19,7 @@ public class BatchedExecutorGuidance
|
|||
string modelPath = UserSettings.GetModelPath();
|
||||
|
||||
var parameters = new ModelParams(modelPath);
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
|
||||
var positivePrompt = AnsiConsole.Ask("Positive Prompt (or ENTER for default):", "My favourite colour is").Trim();
|
||||
var negativePrompt = AnsiConsole.Ask("Negative Prompt (or ENTER for default):", "I hate the colour red. My favourite colour is").Trim();
|
||||
|
|
|
@@ -20,7 +20,7 @@ public class BatchedExecutorRewind
|
|||
string modelPath = UserSettings.GetModelPath();
|
||||
|
||||
var parameters = new ModelParams(modelPath);
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
|
||||
var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");
|
||||
|
||||
|
|
|
@@ -18,7 +18,7 @@ public class BatchedExecutorSaveAndLoad
|
|||
string modelPath = UserSettings.GetModelPath();
|
||||
|
||||
var parameters = new ModelParams(modelPath);
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
|
||||
var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");
|
||||
|
||||
|
|
|
@@ -31,7 +31,7 @@ public class ChatChineseGB2312
|
|||
GpuLayerCount = 5,
|
||||
Encoding = Encoding.UTF8
|
||||
};
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
using var context = model.CreateContext(parameters);
|
||||
var executor = new InteractiveExecutor(context);
|
||||
|
||||
|
|
|
@@ -15,11 +15,11 @@ public class ChatSessionStripRoleName
|
|||
Seed = 1337,
|
||||
GpuLayerCount = 5
|
||||
};
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
using var context = model.CreateContext(parameters);
|
||||
var executor = new InteractiveExecutor(context);
|
||||
|
||||
var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
|
||||
var chatHistoryJson = await File.ReadAllTextAsync("Assets/chat-with-bob.json");
|
||||
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
|
||||
|
||||
ChatSession session = new(executor, chatHistory);
|
||||
|
|
|
@@ -13,7 +13,7 @@ public class ChatSessionWithHistory
|
|||
Seed = 1337,
|
||||
GpuLayerCount = 5
|
||||
};
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
using var context = model.CreateContext(parameters);
|
||||
var executor = new InteractiveExecutor(context);
|
||||
|
||||
|
|
|
@@ -13,11 +13,11 @@ public class ChatSessionWithRestart
|
|||
Seed = 1337,
|
||||
GpuLayerCount = 5
|
||||
};
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
using var context = model.CreateContext(parameters);
|
||||
var executor = new InteractiveExecutor(context);
|
||||
|
||||
var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
|
||||
var chatHistoryJson = await File.ReadAllTextAsync("Assets/chat-with-bob.json");
|
||||
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
|
||||
ChatSession prototypeSession =
|
||||
await ChatSession.InitializeSessionFromHistoryAsync(executor, chatHistory);
|
||||
|
|
|
@@ -13,11 +13,11 @@ public class ChatSessionWithRoleName
|
|||
Seed = 1337,
|
||||
GpuLayerCount = 5
|
||||
};
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
using var context = model.CreateContext(parameters);
|
||||
var executor = new InteractiveExecutor(context);
|
||||
|
||||
var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
|
||||
var chatHistoryJson = await File.ReadAllTextAsync("Assets/chat-with-bob.json");
|
||||
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
|
||||
|
||||
ChatSession session = new(executor, chatHistory);
|
||||
|
|
|
@@ -29,7 +29,7 @@
|
|||
{
|
||||
ContextSize = 4096
|
||||
};
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
using var context = model.CreateContext(parameters);
|
||||
var executor = new InstructExecutor(context, InstructionPrefix, InstructionSuffix, null);
|
||||
|
||||
|
|
|
@@ -9,7 +9,7 @@ namespace LLama.Examples.Examples
|
|||
{
|
||||
string modelPath = UserSettings.GetModelPath();
|
||||
|
||||
var gbnf = File.ReadAllText("Assets/json.gbnf").Trim();
|
||||
var gbnf = (await File.ReadAllTextAsync("Assets/json.gbnf")).Trim();
|
||||
var grammar = Grammar.Parse(gbnf, "root");
|
||||
|
||||
var parameters = new ModelParams(modelPath)
|
||||
|
@@ -17,7 +17,7 @@ namespace LLama.Examples.Examples
|
|||
Seed = 1337,
|
||||
GpuLayerCount = 5
|
||||
};
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
var ex = new StatelessExecutor(model, parameters);
|
||||
|
||||
Console.ForegroundColor = ConsoleColor.Yellow;
|
||||
|
|
|
@@ -9,14 +9,14 @@ namespace LLama.Examples.Examples
|
|||
{
|
||||
string modelPath = UserSettings.GetModelPath();
|
||||
|
||||
var prompt = File.ReadAllText("Assets/dan.txt").Trim();
|
||||
var prompt = (await File.ReadAllTextAsync("Assets/dan.txt")).Trim();
|
||||
|
||||
var parameters = new ModelParams(modelPath)
|
||||
{
|
||||
Seed = 1337,
|
||||
GpuLayerCount = 5
|
||||
};
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
using var context = model.CreateContext(parameters);
|
||||
var executor = new InstructExecutor(context);
|
||||
|
||||
|
|
|
@@ -16,7 +16,7 @@ namespace LLama.Examples.Examples
|
|||
Seed = 1337,
|
||||
GpuLayerCount = 5
|
||||
};
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
using var context = model.CreateContext(parameters);
|
||||
var ex = new InteractiveExecutor(context);
|
||||
|
||||
|
|
|
@@ -20,7 +20,7 @@ namespace LLama.Examples.Examples
|
|||
|
||||
var parameters = new ModelParams(modelPath);
|
||||
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
using var context = model.CreateContext(parameters);
|
||||
|
||||
// Llava Init
|
||||
|
|
|
@@ -15,7 +15,7 @@ namespace LLama.Examples.Examples
|
|||
Seed = 1337,
|
||||
GpuLayerCount = 5
|
||||
};
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
using var context = model.CreateContext(parameters);
|
||||
var ex = new InteractiveExecutor(context);
|
||||
|
||||
|
|
|
@@ -16,7 +16,7 @@ namespace LLama.Examples.Examples
|
|||
Seed = 1337,
|
||||
GpuLayerCount = 5
|
||||
};
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
using var context = model.CreateContext(parameters);
|
||||
var ex = new InteractiveExecutor(context);
|
||||
|
||||
|
|
|
@@ -16,7 +16,7 @@ namespace LLama.Examples.Examples
|
|||
|
||||
// Load weights into memory
|
||||
var parameters = new ModelParams(modelPath);
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
var ex = new StatelessExecutor(model, parameters);
|
||||
|
||||
var chatGPT = new LLamaSharpChatCompletion(ex);
|
||||
|
|
|
@@ -23,7 +23,7 @@ namespace LLama.Examples.Examples
|
|||
Embeddings = true
|
||||
};
|
||||
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
var embedding = new LLamaEmbedder(model, parameters);
|
||||
|
||||
Console.WriteLine("====================================================");
|
||||
|
|
|
@@ -19,7 +19,7 @@ namespace LLama.Examples.Examples
|
|||
|
||||
// Load weights into memory
|
||||
var parameters = new ModelParams(modelPath);
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
var ex = new StatelessExecutor(model, parameters);
|
||||
|
||||
var builder = Kernel.CreateBuilder();
|
||||
|
|
|
@@ -14,7 +14,7 @@ namespace LLama.Examples.Examples
|
|||
Seed = 1337,
|
||||
GpuLayerCount = 5
|
||||
};
|
||||
using var model = LLamaWeights.LoadFromFile(parameters);
|
||||
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
|
||||
var ex = new StatelessExecutor(model, parameters);
|
||||
|
||||
Console.ForegroundColor = ConsoleColor.Yellow;
|
||||
|
|
|
@@ -12,7 +12,7 @@ namespace LLama.Examples.Examples
|
|||
|
||||
// Load weights into memory
|
||||
var @params = new ModelParams(modelPath);
|
||||
using var weights = LLamaWeights.LoadFromFile(@params);
|
||||
using var weights = await LLamaWeights.LoadFromFileAsync(@params);
|
||||
|
||||
// Create 2 contexts sharing the same weights
|
||||
using var aliceCtx = weights.CreateContext(@params);
|
||||
|
|
|
@@ -1,7 +1,10 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Text;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using LLama.Abstractions;
|
||||
using LLama.Exceptions;
|
||||
using LLama.Extensions;
|
||||
using LLama.Native;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
@@ -84,6 +87,83 @@ namespace LLama
|
|||
return new LLamaWeights(weights);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Load weights into memory
|
||||
/// </summary>
|
||||
/// <param name="params">Parameters to use to load the model</param>
|
||||
/// <param name="token">A cancellation token that can interrupt model loading</param>
|
||||
/// <returns></returns>
|
||||
/// <exception cref="LoadWeightsFailedException">Thrown if weights failed to load for any reason. e.g. Invalid file format or loading cancelled.</exception>
|
||||
/// <exception cref="OperationCanceledException">Thrown if the cancellation token is cancelled.</exception>
|
||||
public static async Task<LLamaWeights> LoadFromFileAsync(IModelParams @params, CancellationToken token = default)
|
||||
{
|
||||
// don't touch the @params object inside the task, it might be changed
|
||||
// externally! Save a copy of everything that we need later.
|
||||
var modelPath = @params.ModelPath;
|
||||
var loraBase = @params.LoraBase;
|
||||
var loraAdapters = @params.LoraAdapters.ToArray();
|
||||
|
||||
using (@params.ToLlamaModelParams(out var lparams))
|
||||
{
|
||||
#if !NETSTANDARD2_0
|
||||
// Overwrite the progress callback with one which polls the cancellation token
|
||||
//if (token.CanBeCanceled)
|
||||
{
|
||||
var internalCallback = lparams.progress_callback;
|
||||
lparams.progress_callback = (progress, ctx) =>
|
||||
{
|
||||
// If the user set a call in the model params, first call that and see if we should cancel
|
||||
if (internalCallback != null && !internalCallback(progress, ctx))
|
||||
return false;
|
||||
|
||||
// Check the cancellation token
|
||||
if (token.IsCancellationRequested)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
};
|
||||
}
|
||||
#endif
|
||||
|
||||
var model = await Task.Run(() =>
|
||||
{
|
||||
try
|
||||
{
|
||||
var weights = SafeLlamaModelHandle.LoadFromFile(modelPath, lparams);
|
||||
foreach (var adapter in loraAdapters)
|
||||
{
|
||||
// Interrupt applying LoRAs if the token is cancelled
|
||||
if (token.IsCancellationRequested)
|
||||
{
|
||||
weights.Dispose();
|
||||
token.ThrowIfCancellationRequested();
|
||||
}
|
||||
|
||||
// Don't apply invalid adapters
|
||||
if (string.IsNullOrEmpty(adapter.Path))
|
||||
continue;
|
||||
if (adapter.Scale <= 0)
|
||||
continue;
|
||||
|
||||
weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, loraBase);
|
||||
}
|
||||
|
||||
return new LLamaWeights(weights);
|
||||
}
|
||||
catch (LoadWeightsFailedException)
|
||||
{
|
||||
// Convert a LoadWeightsFailedException into a cancellation exception if possible.
|
||||
token.ThrowIfCancellationRequested();
|
||||
|
||||
// Ok the weights failed to load for some reason other than cancellation.
|
||||
throw;
|
||||
}
|
||||
}, token);
|
||||
|
||||
return model;
|
||||
}
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public void Dispose()
|
||||
{
|
||||
|
|
|
@@ -8,6 +8,8 @@ namespace LLama.Native
|
|||
/// </summary>
|
||||
/// <param name="progress"></param>
|
||||
/// <param name="ctx"></param>
|
||||
/// <returns>If the provided progress_callback returns true, model loading continues.
|
||||
/// If it returns false, model loading is immediately aborted.</returns>
|
||||
/// <remarks>llama_progress_callback</remarks>
|
||||
public delegate bool LlamaProgressCallback(float progress, IntPtr ctx);
|
||||
|
||||
|
|
|
@@ -38,7 +38,7 @@ namespace LLama.Native
|
|||
// as NET Framework 4.8 does not play nice with the LlamaProgressCallback type
|
||||
public IntPtr progress_callback;
|
||||
#else
|
||||
public LlamaProgressCallback progress_callback;
|
||||
public LlamaProgressCallback? progress_callback;
|
||||
#endif
|
||||
|
||||
/// <summary>
|
||||
|
|
|
@@ -120,8 +120,11 @@ namespace LLama.Native
|
|||
if (!fs.CanRead)
|
||||
throw new InvalidOperationException($"Model file '{modelPath}' is not readable");
|
||||
|
||||
return llama_load_model_from_file(modelPath, lparams)
|
||||
?? throw new LoadWeightsFailedException(modelPath);
|
||||
var handle = llama_load_model_from_file(modelPath, lparams);
|
||||
if (handle.IsInvalid)
|
||||
throw new LoadWeightsFailedException(modelPath);
|
||||
|
||||
return handle;
|
||||
}
|
||||
|
||||
#region native API
|
||||
|
|
Loading…
Reference in New Issue