- Added `LoadFromFileAsync` method for `LLavaWeights`

- Fixed checking for invalid handles in `clip_model_load`
This commit is contained in:
Martin Evans 2024-04-27 23:31:07 +01:00
parent 84bb5a36ab
commit 377ebf3664
3 changed files with 23 additions and 7 deletions

View File

@ -24,7 +24,7 @@ namespace LLama.Examples.Examples
using var context = model.CreateContext(parameters); using var context = model.CreateContext(parameters);
// Llava Init // Llava Init
using var clipModel = LLavaWeights.LoadFromFile(multiModalProj); using var clipModel = await LLavaWeights.LoadFromFileAsync(multiModalProj);
var ex = new InteractiveExecutor(context, clipModel); var ex = new InteractiveExecutor(context, clipModel);

View File

@ -1,5 +1,7 @@
using System; using System;
using System.Threading;
using System.Threading.Tasks;
using LLama.Native; using LLama.Native;
namespace LLama; namespace LLama;
@ -15,7 +17,7 @@ public sealed class LLavaWeights : IDisposable
/// <remarks>Be careful how you use this!</remarks> /// <remarks>Be careful how you use this!</remarks>
public SafeLlavaModelHandle NativeHandle { get; } public SafeLlavaModelHandle NativeHandle { get; }
internal LLavaWeights(SafeLlavaModelHandle weights) private LLavaWeights(SafeLlavaModelHandle weights)
{ {
NativeHandle = weights; NativeHandle = weights;
} }
@ -31,6 +33,17 @@ public sealed class LLavaWeights : IDisposable
return new LLavaWeights(weights); return new LLavaWeights(weights);
} }
/// <summary>
/// Load weights into memory
/// </summary>
/// <param name="mmProject">path to the "mmproj" model file</param>
/// <param name="token"></param>
/// <returns></returns>
public static Task<LLavaWeights> LoadFromFileAsync(string mmProject, CancellationToken token = default)
{
return Task.Run(() => LoadFromFile(mmProject), token);
}
/// <summary> /// <summary>
/// Create the Image Embeddings from the bytes of an image. /// Create the Image Embeddings from the bytes of an image.
/// </summary> /// </summary>

View File

@ -39,8 +39,11 @@ namespace LLama.Native
if (!fs.CanRead) if (!fs.CanRead)
throw new InvalidOperationException($"Llava MMP Model file '{modelPath}' is not readable"); throw new InvalidOperationException($"Llava MMP Model file '{modelPath}' is not readable");
return clip_model_load(modelPath, verbosity) var handle = clip_model_load(modelPath, verbosity);
?? throw new LoadWeightsFailedException(modelPath); if (handle.IsInvalid)
throw new LoadWeightsFailedException(modelPath);
return handle;
} }
/// <summary> /// <summary>