`SetDllImportResolver`-based loading (#603)

- Modified library loading to be based on `SetDllImportResolver`. This replaces the built-in loading system and ensures that two different native libraries can never end up loaded at once (a minimal sketch of the pattern is shown after this list).
 - llava and llama are loaded separately, as needed.
 - All of the previous loading logic is still used, but it now runs inside the `SetDllImportResolver` callback.
 - Split out the CUDA, AVX and macOS paths into separate helper methods.
 - `Description` now specifies whether it is for `llama` or `llava`.
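
For readers unfamiliar with the mechanism, here is a minimal, standalone sketch of the resolver pattern this commit adopts; `NativeLoader` and `PickBestLibraryPath` are hypothetical names standing in for the real detection logic in `NativeApi`:

```csharp
using System;
using System.Runtime.InteropServices;

static class NativeLoader
{
    // Cache the handle so every resolution returns the same library,
    // which is what prevents two copies being loaded side by side.
    private static IntPtr _llamaHandle;

    public static void Install()
    {
        NativeLibrary.SetDllImportResolver(typeof(NativeLoader).Assembly, (name, assembly, searchPath) =>
        {
            if (name != "llama")
                return IntPtr.Zero; // defer to the default resolver for anything else

            if (_llamaHandle == IntPtr.Zero)
                _llamaHandle = NativeLibrary.Load(PickBestLibraryPath()); // hypothetical selection logic

            return _llamaHandle;
        });
    }

    // Stand-in for the CPU feature / CUDA detection performed by the real loader.
    private static string PickBestLibraryPath() => "runtimes/linux-x64/native/avx2/libllama.so";
}
```

The resolver in this commit follows the same shape for both `llama` and `llava_shared`, delegating the actual search to `TryLoadLibraries`.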
Martin Evans 2024-03-17 19:54:20 +00:00 committed by GitHub
parent 6ddd45baa3
commit 024787225b
3 changed files with 305 additions and 167 deletions

@ -16,11 +16,14 @@ AnsiConsole.MarkupLineInterpolated(
""");
// Configure native library to use
NativeLibraryConfig
.Instance
.WithCuda()
.WithLogs(LLamaLogLevel.Warning);
.WithLogs(LLamaLogLevel.Info);
// Calling this method forces loading to occur now.
NativeApi.llama_empty_call();
await ExampleRunner.Run();
await ExampleRunner.Run();

@ -1,11 +1,9 @@
using LLama.Exceptions;
using Microsoft.Extensions.Logging;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Runtime.InteropServices;
using System.Text.Json;
using System.Collections.Generic;
namespace LLama.Native
{
@ -13,9 +11,14 @@ namespace LLama.Native
{
static NativeApi()
{
// Try to load a preferred library, based on CPU feature detection
TryLoadLibrary();
// Overwrite the Dll import resolver for this assembly. The resolver gets
// called by the runtime every time that a call into a DLL is required. The
// resolver returns the loaded DLL handle. This allows us to take control of
// which llama.dll is used.
SetDllImportResolver();
// Immediately make a call which requires loading the llama DLL. This method call
// can't fail unless the DLL hasn't been loaded.
try
{
llama_empty_call();
@ -30,39 +33,97 @@ namespace LLama.Native
"4. Try to compile llama.cpp yourself to generate a libllama library, then use `LLama.Native.NativeLibraryConfig.WithLibrary` " +
"to specify it at the very beginning of your code. For more informations about compilation, please refer to LLamaSharp repo on github.\n");
}
// Init llama.cpp backend
llama_backend_init();
}
private static void Log(string message, LogLevel level)
#if NET5_0_OR_GREATER
private static IntPtr _loadedLlamaHandle;
private static IntPtr _loadedLlavaSharedHandle;
#endif
private static void SetDllImportResolver()
{
// NativeLibrary is not available on older runtimes. We'll have to depend on
// the normal runtime dll resolution there.
#if NET5_0_OR_GREATER
NativeLibrary.SetDllImportResolver(typeof(NativeApi).Assembly, (name, _, _) =>
{
if (name == "llama")
{
// If we've already loaded llama return the handle that was loaded last time.
if (_loadedLlamaHandle != IntPtr.Zero)
return _loadedLlamaHandle;
// Try to load a preferred library, based on CPU feature detection
_loadedLlamaHandle = TryLoadLibraries(LibraryName.Llama);
return _loadedLlamaHandle;
}
if (name == "llava_shared")
{
// If we've already loaded llava return the handle that was loaded last time.
if (_loadedLlavaSharedHandle != IntPtr.Zero)
return _loadedLlavaSharedHandle;
// Try to load a preferred library, based on CPU feature detection
_loadedLlavaSharedHandle = TryLoadLibraries(LibraryName.LlavaShared);
return _loadedLlavaSharedHandle;
}
// Return null pointer to indicate that nothing was loaded.
return IntPtr.Zero;
});
#endif
}
private static void Log(string message, LLamaLogLevel level)
{
if (!enableLogging)
return;
if ((int)level < (int)logLevel)
if ((int)level > (int)logLevel)
return;
ConsoleColor color;
string levelPrefix;
if (level == LogLevel.Information)
var fg = Console.ForegroundColor;
var bg = Console.BackgroundColor;
try
{
color = ConsoleColor.Green;
levelPrefix = "[Info]";
ConsoleColor color;
string levelPrefix;
if (level == LLamaLogLevel.Debug)
{
color = ConsoleColor.Cyan;
levelPrefix = "[Debug]";
}
else if (level == LLamaLogLevel.Info)
{
color = ConsoleColor.Green;
levelPrefix = "[Info]";
}
else if (level == LLamaLogLevel.Error)
{
color = ConsoleColor.Red;
levelPrefix = "[Error]";
}
else
{
color = ConsoleColor.Yellow;
levelPrefix = "[UNK]";
}
Console.ForegroundColor = color;
Console.WriteLine($"{loggingPrefix} {levelPrefix} {message}");
}
else if (level == LogLevel.Error)
finally
{
color = ConsoleColor.Red;
levelPrefix = "[Error]";
Console.ForegroundColor = fg;
Console.BackgroundColor = bg;
}
else
{
color = ConsoleColor.Yellow;
levelPrefix = "[Error]";
}
Console.ForegroundColor = color;
Console.WriteLine($"{loggingPrefix} {levelPrefix} {message}");
Console.ResetColor();
}
#region CUDA version
private static int GetCudaMajorVersion()
{
string? cudaPath;
@ -131,65 +192,33 @@ namespace LLama.Native
return string.Empty;
}
}
#endregion
#if NET6_0_OR_GREATER
private static string GetAvxLibraryPath(NativeLibraryConfig.AvxLevel avxLevel, string prefix, string suffix, string libraryNamePrefix)
private static IEnumerable<string> GetLibraryTryOrder(NativeLibraryConfig.Description configuration)
{
var avxStr = NativeLibraryConfig.AvxLevelToString(avxLevel);
if (!string.IsNullOrEmpty(avxStr))
{
avxStr += "/";
}
return $"{prefix}{avxStr}{libraryNamePrefix}{libraryName}{suffix}";
}
var loadingName = configuration.Library.GetLibraryName();
Log($"Loading library: '{loadingName}'", LLamaLogLevel.Debug);
private static List<string> GetLibraryTryOrder(NativeLibraryConfig.Description configuration)
{
OSPlatform platform;
string prefix, suffix, libraryNamePrefix;
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
platform = OSPlatform.Windows;
prefix = "runtimes/win-x64/native/";
suffix = ".dll";
libraryNamePrefix = "";
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
platform = OSPlatform.Linux;
prefix = "runtimes/linux-x64/native/";
suffix = ".so";
libraryNamePrefix = "lib";
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
{
platform = OSPlatform.OSX;
suffix = ".dylib";
// Get platform specific parts of the path (e.g. .so/.dll/.dylib, libName prefix or not)
GetPlatformPathParts(out var platform, out var os, out var ext, out var libPrefix);
Log($"Detected OS Platform: '{platform}'", LLamaLogLevel.Info);
Log($"Detected OS string: '{os}'", LLamaLogLevel.Debug);
Log($"Detected extension string: '{ext}'", LLamaLogLevel.Debug);
Log($"Detected prefix string: '{libPrefix}'", LLamaLogLevel.Debug);
prefix = System.Runtime.Intrinsics.Arm.ArmBase.Arm64.IsSupported
? "runtimes/osx-arm64/native/"
: "runtimes/osx-x64/native/";
libraryNamePrefix = "lib";
}
else
if (configuration.UseCuda && (platform == OSPlatform.Windows || platform == OSPlatform.Linux))
{
throw new RuntimeError("Your system plarform is not supported, please open an issue in LLamaSharp.");
}
Log($"Detected OS Platform: {platform}", LogLevel.Information);
var cudaVersion = GetCudaMajorVersion();
Log($"Detected cuda major version {cudaVersion}.", LLamaLogLevel.Info);
List<string> result = new();
if (configuration.UseCuda && (platform == OSPlatform.Windows || platform == OSPlatform.Linux)) // no cuda on macos
{
int cudaVersion = GetCudaMajorVersion();
// TODO: load cuda library with avx
if (cudaVersion == -1 && !configuration.AllowFallback)
{
// if check skipped, we just try to load cuda libraries one by one.
if (configuration.SkipCheck)
{
result.Add($"{prefix}cuda12/{libraryNamePrefix}{libraryName}{suffix}");
result.Add($"{prefix}cuda11/{libraryNamePrefix}{libraryName}{suffix}");
yield return GetCudaLibraryPath(loadingName, "cuda12");
yield return GetCudaLibraryPath(loadingName, "cuda11");
}
else
{
@ -198,121 +227,167 @@ namespace LLama.Native
}
else if (cudaVersion == 11)
{
Log($"Detected cuda major version {cudaVersion}.", LogLevel.Information);
result.Add($"{prefix}cuda11/{libraryNamePrefix}{libraryName}{suffix}");
yield return GetCudaLibraryPath(loadingName, "cuda11");
}
else if (cudaVersion == 12)
{
Log($"Detected cuda major version {cudaVersion}.", LogLevel.Information);
result.Add($"{prefix}cuda12/{libraryNamePrefix}{libraryName}{suffix}");
yield return GetCudaLibraryPath(loadingName, "cuda12");
}
else if (cudaVersion > 0)
{
throw new RuntimeError($"Cuda version {cudaVersion} hasn't been supported by LLamaSharp, please open an issue for it.");
}
// otherwise no cuda detected but allow fallback
}
// use cpu (or mac possibly with metal)
if (!configuration.AllowFallback && platform != OSPlatform.OSX)
{
result.Add(GetAvxLibraryPath(configuration.AvxLevel, prefix, suffix, libraryNamePrefix));
}
else if (platform != OSPlatform.OSX) // in macos there's absolutely no avx
{
if (configuration.AvxLevel >= NativeLibraryConfig.AvxLevel.Avx512)
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx512, prefix, suffix, libraryNamePrefix));
if (configuration.AvxLevel >= NativeLibraryConfig.AvxLevel.Avx2)
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx2, prefix, suffix, libraryNamePrefix));
if (configuration.AvxLevel >= NativeLibraryConfig.AvxLevel.Avx)
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx, prefix, suffix, libraryNamePrefix));
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.None, prefix, suffix, libraryNamePrefix));
}
// Add the CPU/Metal libraries
if (platform == OSPlatform.OSX)
{
result.Add($"{prefix}{libraryNamePrefix}{libraryName}{suffix}");
result.Add($"{prefix}{libraryNamePrefix}{llavaLibraryName}{suffix}");
// On Mac it's very simple, there's no AVX to consider.
yield return GetMacLibraryPath(loadingName);
}
else
{
if (configuration.AllowFallback)
{
// Try all of the AVX levels we can support.
if (configuration.AvxLevel >= NativeLibraryConfig.AvxLevel.Avx512)
yield return GetAvxLibraryPath(loadingName, NativeLibraryConfig.AvxLevel.Avx512);
return result;
if (configuration.AvxLevel >= NativeLibraryConfig.AvxLevel.Avx2)
yield return GetAvxLibraryPath(loadingName, NativeLibraryConfig.AvxLevel.Avx2);
if (configuration.AvxLevel >= NativeLibraryConfig.AvxLevel.Avx)
yield return GetAvxLibraryPath(loadingName, NativeLibraryConfig.AvxLevel.Avx);
yield return GetAvxLibraryPath(loadingName, NativeLibraryConfig.AvxLevel.None);
}
else
{
// Fallback is not allowed - use the exact specified AVX level
yield return GetAvxLibraryPath(loadingName, configuration.AvxLevel);
}
}
}
private static string GetMacLibraryPath(string libraryName)
{
GetPlatformPathParts(out _, out var os, out var fileExtension, out var libPrefix);
return $"runtimes/{os}/native/{libPrefix}{libraryName}{fileExtension}";
}
/// <summary>
/// Given a CUDA version and some path parts, create a complete path to the library file
/// </summary>
/// <param name="libraryName">Library being loaded (e.g. "llama")</param>
/// <param name="cuda">CUDA version (e.g. "cuda11")</param>
/// <returns></returns>
private static string GetCudaLibraryPath(string libraryName, string cuda)
{
GetPlatformPathParts(out _, out var os, out var fileExtension, out var libPrefix);
return $"runtimes/{os}/native/{cuda}/{libPrefix}{libraryName}{fileExtension}";
}
/// <summary>
/// Given an AVX level and some path parts, create a complete path to the library file
/// </summary>
/// <param name="libraryName">Library being loaded (e.g. "llama")</param>
/// <param name="avx"></param>
/// <returns></returns>
private static string GetAvxLibraryPath(string libraryName, NativeLibraryConfig.AvxLevel avx)
{
GetPlatformPathParts(out _, out var os, out var fileExtension, out var libPrefix);
var avxStr = NativeLibraryConfig.AvxLevelToString(avx);
if (!string.IsNullOrEmpty(avxStr))
avxStr += "/";
return $"runtimes/{os}/native/{avxStr}{libPrefix}{libraryName}{fileExtension}";
}
private static void GetPlatformPathParts(out OSPlatform platform, out string os, out string fileExtension, out string libPrefix)
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
platform = OSPlatform.Windows;
os = "win-x64";
fileExtension = ".dll";
libPrefix = "";
return;
}
if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
platform = OSPlatform.Linux;
os = "linux-x64";
fileExtension = ".so";
libPrefix = "lib";
return;
}
if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
{
platform = OSPlatform.OSX;
fileExtension = ".dylib";
os = System.Runtime.Intrinsics.Arm.ArmBase.Arm64.IsSupported
? "osx-arm64"
: "osx-x64";
libPrefix = "lib";
}
else
{
throw new RuntimeError("Your operating system is not supported, please open an issue in LLamaSharp.");
}
}
#endif
/// <summary>
/// Try to load libllama, using CPU feature detection to try and load a more specialised DLL if possible
/// Try to load libllama/llava_shared, using CPU feature detection to try and load a more specialised DLL if possible
/// </summary>
/// <returns>The library handle to unload later, or IntPtr.Zero if no library was loaded</returns>
private static IntPtr TryLoadLibrary()
private static IntPtr TryLoadLibraries(LibraryName lib)
{
#if NET6_0_OR_GREATER
var configuration = NativeLibraryConfig.CheckAndGatherDescription();
var configuration = NativeLibraryConfig.CheckAndGatherDescription(lib);
enableLogging = configuration.Logging;
logLevel = configuration.LogLevel;
// We move the flag to avoid loading library when the variable is called else where.
NativeLibraryConfig.LibraryHasLoaded = true;
Log(configuration.ToString(), LogLevel.Information);
// Set the flag to ensure the NativeLibraryConfig can no longer be modified
NativeLibraryConfig.LibraryHasLoaded = true;
// Show the configuration we're working with
Log(configuration.ToString(), LLamaLogLevel.Info);
// If a specific path is requested, load that or immediately fail
if (!string.IsNullOrEmpty(configuration.Path))
{
// When loading the user specified library, there's no fallback.
var success = NativeLibrary.TryLoad(configuration.Path, out var result);
if (!success)
{
if (!NativeLibrary.TryLoad(configuration.Path, out var handle))
throw new RuntimeError($"Failed to load the native library [{configuration.Path}] you specified.");
}
Log($"Successfully loaded the library [{configuration.Path}] specified by user", LogLevel.Information);
return result;
Log($"Successfully loaded the library [{configuration.Path}] specified by user", LLamaLogLevel.Info);
return handle;
}
// Get a list of locations to try loading (in order of preference)
var libraryTryLoadOrder = GetLibraryTryOrder(configuration);
var preferredPaths = configuration.SearchDirectories;
var possiblePathPrefix = new[] {
AppDomain.CurrentDomain.BaseDirectory,
Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location) ?? ""
};
string TryFindPath(string filename)
{
foreach (var path in preferredPaths)
{
if (File.Exists(Path.Combine(path, filename)))
{
return Path.Combine(path, filename);
}
}
foreach (var path in possiblePathPrefix)
{
if (File.Exists(Path.Combine(path, filename)))
{
return Path.Combine(path, filename);
}
}
return filename;
}
foreach (var libraryPath in libraryTryLoadOrder)
{
var fullPath = TryFindPath(libraryPath);
var result = TryLoad(fullPath, true);
if (result is not null && result != IntPtr.Zero)
Log($"Trying '{fullPath}'", LLamaLogLevel.Debug);
var result = TryLoad(fullPath);
if (result != IntPtr.Zero)
{
Log($"{fullPath} is selected and loaded successfully.", LogLevel.Information);
// One we have clear the detection and that llama loads successfully we load LLaVa if exist on the
// same path.
TryLoad( libraryPath.Replace("llama", "llava_shared"), true);
return (IntPtr)result;
Log($"Loaded '{fullPath}'", LLamaLogLevel.Info);
return result;
}
Log($"Tried to load {fullPath} but failed.", LogLevel.Information);
Log($"Failed Loading '{fullPath}'", LLamaLogLevel.Info);
}
if (!configuration.AllowFallback)
@ -325,20 +400,45 @@ namespace LLama.Native
#endif
Log($"No library was loaded before calling native apis. " +
$"This is not an error under netstandard2.0 but needs attention with net6 or higher.", LogLevel.Warning);
$"This is not an error under netstandard2.0 but needs attention with net6 or higher.", LLamaLogLevel.Warning);
return IntPtr.Zero;
#if NET6_0_OR_GREATER
// Try to load a DLL from the path if supported. Returns null if nothing is loaded.
static IntPtr? TryLoad(string path, bool supported = true)
// Try to load a DLL from the path.
// Returns null if nothing is loaded.
static IntPtr TryLoad(string path)
{
if (!supported)
return null;
if (NativeLibrary.TryLoad(path, out var handle))
return handle;
return null;
return IntPtr.Zero;
}
// Try to find the given file in any of the possible search paths
string TryFindPath(string filename)
{
// Try the configured search directories in the configuration
foreach (var path in configuration.SearchDirectories)
{
var candidate = Path.Combine(path, filename);
if (File.Exists(candidate))
return candidate;
}
// Try a few other possible paths
var possiblePathPrefix = new[] {
AppDomain.CurrentDomain.BaseDirectory,
Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location) ?? ""
};
foreach (var path in possiblePathPrefix)
{
var candidate = Path.Combine(path, filename);
if (File.Exists(candidate))
return candidate;
}
return filename;
}
#endif
}
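
The new `GetPlatformPathParts`, `GetAvxLibraryPath` and `GetCudaLibraryPath` helpers all compose paths under the same `runtimes/{os}/native/...` layout. A rough standalone illustration of that scheme (the `BuildPath` helper and the example variant strings are illustrative, not the library's API):

```csharp
// Mirrors the layout produced by the helpers above:
//   runtimes/{os}/native/{variant}/{prefix}{name}{ext}
// where {variant} is a CUDA folder ("cuda11", "cuda12"), an AVX folder, or empty.
static string BuildPath(string os, string variant, string prefix, string name, string ext)
{
    var variantDir = string.IsNullOrEmpty(variant) ? "" : variant + "/";
    return $"runtimes/{os}/native/{variantDir}{prefix}{name}{ext}";
}

// BuildPath("linux-x64", "cuda12", "lib", "llama", ".so")
//   => "runtimes/linux-x64/native/cuda12/libllama.so"
// BuildPath("win-x64", "avx2", "", "llava_shared", ".dll")   // assuming AvxLevelToString returns "avx2"
//   => "runtimes/win-x64/native/avx2/llava_shared.dll"
```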

@ -11,19 +11,19 @@ namespace LLama.Native
/// </summary>
public sealed class NativeLibraryConfig
{
private static readonly Lazy<NativeLibraryConfig> _instance = new(() => new NativeLibraryConfig());
/// <summary>
/// Get the config instance
/// </summary>
public static NativeLibraryConfig Instance => _instance.Value;
public static NativeLibraryConfig Instance { get; } = new();
/// <summary>
/// Whether there's already a config for native library.
/// Check if the native library has already been loaded. Configuration cannot be modified if this is true.
/// </summary>
public static bool LibraryHasLoaded { get; internal set; } = false;
private string _libraryPath = string.Empty;
private string? _libraryPath;
private string? _libraryPathLLava;
private bool _useCuda = true;
private AvxLevel _avxLevel;
private bool _allowFallback = true;
@ -42,17 +42,20 @@ namespace LLama.Native
throw new InvalidOperationException("NativeLibraryConfig must be configured before using **any** other LLamaSharp methods!");
}
#region configurators
/// <summary>
/// Load a specified native library as backend for LLamaSharp.
/// When this method is called, all the other configurations will be ignored.
/// </summary>
/// <param name="libraryPath"></param>
/// <param name="llamaPath">The full path to the llama library to load.</param>
/// <param name="llavaPath">The full path to the llava library to load.</param>
/// <exception cref="InvalidOperationException">Thrown if `LibraryHasLoaded` is true.</exception>
public NativeLibraryConfig WithLibrary(string libraryPath)
public NativeLibraryConfig WithLibrary(string? llamaPath, string? llavaPath)
{
ThrowIfLoaded();
_libraryPath = libraryPath;
_libraryPath = llamaPath;
_libraryPathLLava = llavaPath;
return this;
}
@ -172,14 +175,23 @@ namespace LLama.Native
_searchDirectories.Add(directory);
return this;
}
#endregion
internal static Description CheckAndGatherDescription()
internal static Description CheckAndGatherDescription(LibraryName library)
{
if (Instance._allowFallback && Instance._skipCheck)
throw new ArgumentException("Cannot skip the check when fallback is allowed.");
var path = library switch
{
LibraryName.Llama => Instance._libraryPath,
LibraryName.LlavaShared => Instance._libraryPathLLava,
_ => throw new ArgumentException($"Unknown library name '{library}'", nameof(library)),
};
return new Description(
Instance._libraryPath,
path,
library,
Instance._useCuda,
Instance._avxLevel,
Instance._allowFallback,
@ -267,7 +279,7 @@ namespace LLama.Native
Avx512,
}
internal record Description(string Path, bool UseCuda, AvxLevel AvxLevel, bool AllowFallback, bool SkipCheck, bool Logging, LLamaLogLevel LogLevel, string[] SearchDirectories)
internal record Description(string? Path, LibraryName Library, bool UseCuda, AvxLevel AvxLevel, bool AllowFallback, bool SkipCheck, bool Logging, LLamaLogLevel LogLevel, string[] SearchDirectories)
{
public override string ToString()
{
@ -283,7 +295,8 @@ namespace LLama.Native
string searchDirectoriesString = "{ " + string.Join(", ", SearchDirectories) + " }";
return $"NativeLibraryConfig Description:\n" +
$"- Path: {Path}\n" +
$"- LibraryName: {Library}\n" +
$"- Path: '{Path}'\n" +
$"- PreferCuda: {UseCuda}\n" +
$"- PreferredAvxLevel: {avxLevelString}\n" +
$"- AllowFallback: {AllowFallback}\n" +
@ -295,4 +308,26 @@ namespace LLama.Native
}
}
#endif
internal enum LibraryName
{
Llama,
LlavaShared
}
internal static class LibraryNameExtensions
{
public static string GetLibraryName(this LibraryName name)
{
switch (name)
{
case LibraryName.Llama:
return NativeApi.libraryName;
case LibraryName.LlavaShared:
return NativeApi.llavaLibraryName;
default:
throw new ArgumentOutOfRangeException(nameof(name), name, null);
}
}
}
}
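
Given the `WithLibrary` signature change above, a caller pinning specific native binaries now passes both paths up front; the paths below are purely illustrative:

```csharp
// Must run before any other LLamaSharp call, because configuration is
// locked once the resolver sets NativeLibraryConfig.LibraryHasLoaded.
NativeLibraryConfig
    .Instance
    .WithLibrary(
        llamaPath: "/opt/native/libllama.so",         // illustrative path
        llavaPath: "/opt/native/libllava_shared.so"); // illustrative path

// Passing null for either argument keeps automatic selection for that
// library, since CheckAndGatherDescription only uses the path matching
// the LibraryName being resolved.
```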