Merge pull request #365 from Onkitova/preview

feat: using CUDA while decoupling from the CUDA Toolkit as a hard-dependency
This commit is contained in:
Rinne 2023-12-15 08:52:55 +08:00 committed by GitHub
commit b79387fd76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 67 additions and 2 deletions

View File

@ -7,6 +7,7 @@ using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text.Json;
using System.Text.RegularExpressions;
namespace LLama.Native
{
@ -69,9 +70,12 @@ namespace LLama.Native
cudaPath = Environment.GetEnvironmentVariable("CUDA_PATH");
if (cudaPath is null)
{
return -1;
version = GetCudaVersionFromDriverUtils_windows();
}
else
{
version = GetCudaVersionFromPath(cudaPath);
}
version = GetCudaVersionFromPath(cudaPath);
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
@ -115,6 +119,67 @@ namespace LLama.Native
}
}
private static string GetCudaVersionFromDriverUtils_windows()
{
try
{
var psi = new ProcessStartInfo
{
FileName = "nvidia-smi",
RedirectStandardOutput = true,
UseShellExecute = false,
CreateNoWindow = true
};
using (var process = Process.Start(psi))
{
if (process != null)
{
using (StreamReader reader = process.StandardOutput)
{
string output = reader.ReadToEnd();
process.WaitForExit();
string cudaVersion = GetNvidiaSmiValue(output, "CUDA Version");
string pattern = @":\s(\d+\.\d+)";
Match match = Regex.Match(cudaVersion, pattern);
string extractedValue = string.Empty;
if (match.Success && match.Groups.Count > 1)
{
extractedValue = match.Groups[1].Value;
}
return extractedValue;
}
}
else
{
return string.Empty;
}
}
}
catch (Exception)
{
return string.Empty;
}
}
static string GetNvidiaSmiValue(string nvidiaSmiOutput, string key)
{
int startIndex = nvidiaSmiOutput.IndexOf(key);
if (startIndex == -1)
{
return "N/A";
}
startIndex += key.Length;
int endIndex = nvidiaSmiOutput.IndexOf('\n', startIndex);
if (endIndex == -1)
{
endIndex = nvidiaSmiOutput.Length;
}
string value = nvidiaSmiOutput.Substring(startIndex, endIndex - startIndex).Trim();
return value;
}
private static string GetCudaVersionFromPath(string cudaPath)
{
try