burn/crates/burn-tch/build.rs

244 lines
8.8 KiB
Rust

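// The LIBTORCH environment variable can be used to specify the directory
// where libtorch has been installed. The libtorch directory is looked up in
// this order: DEP_TCH_LIBTORCH_LIB (exported by a dependency declaring
// `links = "tch"`), then LIBTORCH, then /usr on Linux when
// /usr/lib/libtorch.so exists, then OUT_DIR/libtorch. This script never
// downloads libtorch itself; when no directory is found it only compiles the
// fake CUDA dependency shim.
//
// LIBTORCH_INCLUDE and LIBTORCH_LIB override the include and library
// locations, and LIBTORCH_CXX11_ABI (default "1") sets _GLIBCXX_USE_CXX11_ABI.
// Setting LIBTORCH_USE_PYTORCH makes the script query the active Python's
// PyTorch installation for all of these instead.
//
// Note: the TORCH_CUDA_VERSION variable (e.g. 9.0, 90, or cu90) is not read by
// this script; it presumably only matters to whatever downloads libtorch
// (e.g. the torch-sys build).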
use std::path::{Path, PathBuf};
use std::{env, fs};
/// Python snippet run (via `python -c`) when `LIBTORCH_USE_PYTORCH` is set.
///
/// It prints one `KEY: value` line per setting; `SystemInfo` parses the
/// `LIBTORCH_CXX11`, `LIBTORCH_INCLUDE` and `LIBTORCH_LIB` lines
/// (`LIBTORCH_VERSION` is printed but not consumed by this script).
///
/// The loop bodies must stay indented: Python rejects an unindented `for` body
/// with an `IndentationError`, which would leave no output to parse.
const PYTHON_PRINT_PYTORCH_DETAILS: &str = r"
import torch
from torch.utils import cpp_extension
print('LIBTORCH_VERSION:', torch.__version__.split('+')[0])
print('LIBTORCH_CXX11:', torch._C._GLIBCXX_USE_CXX11_ABI)
for include_path in cpp_extension.include_paths():
    print('LIBTORCH_INCLUDE:', include_path)
for library_path in cpp_extension.library_paths():
    print('LIBTORCH_LIB:', library_path)
";
/// Target operating systems this build script supports, as derived from
/// `CARGO_CFG_TARGET_OS`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Os {
    /// `target_os = "linux"`
    Linux,
    /// `target_os = "macos"`
    Macos,
    /// `target_os = "windows"`
    Windows,
}
/// Everything needed to compile the CUDA dependency shim against libtorch.
// NOTE(review): every field appears to be read somewhere in this file, so this
// allow may be removable — confirm before dropping it.
#[allow(dead_code)]
#[derive(Clone, Debug)]
struct SystemInfo {
    /// Operating system being built for; selects compiler flags.
    os: Os,
    /// Value passed as `-D_GLIBCXX_USE_CXX11_ABI=` on Linux/macOS ("0" or "1").
    cxx11_abi: String,
    /// Directories added to the C++ include path.
    libtorch_include_dirs: Vec<PathBuf>,
    /// Directory holding the libtorch shared libraries.
    libtorch_lib_dir: PathBuf,
}
/// Reads the environment variable `name`, first instructing Cargo to rerun
/// this build script whenever that variable changes.
///
/// Returns the same result as [`env::var`].
fn env_var_rerun(name: &str) -> Result<String, env::VarError> {
    let directive = format!("cargo:rerun-if-env-changed={name}");
    println!("{directive}");
    env::var(name)
}
impl SystemInfo {
fn new() -> Option<Self> {
let os = match env::var("CARGO_CFG_TARGET_OS")
.expect("Unable to get TARGET_OS")
.as_str()
{
"linux" => Os::Linux,
"windows" => Os::Windows,
"macos" => Os::Macos,
os => panic!("unsupported TARGET_OS '{os}'"),
};
// Locate the currently active Python binary, similar to:
// https://github.com/PyO3/maturin/blob/243b8ec91d07113f97a6fe74d9b2dcb88086e0eb/src/target.rs#L547
let python_interpreter = match os {
Os::Windows => PathBuf::from("python.exe"),
Os::Linux | Os::Macos => {
if env::var_os("VIRTUAL_ENV").is_some() {
PathBuf::from("python")
} else {
PathBuf::from("python3")
}
}
};
let mut libtorch_include_dirs = vec![];
let mut libtorch_lib_dir = None;
let cxx11_abi = if env_var_rerun("LIBTORCH_USE_PYTORCH").is_ok() {
let output = std::process::Command::new(&python_interpreter)
.arg("-c")
.arg(PYTHON_PRINT_PYTORCH_DETAILS)
.output()
.expect("error running python interpreter");
let mut cxx11_abi = None;
for line in String::from_utf8_lossy(&output.stdout).lines() {
match line.strip_prefix("LIBTORCH_CXX11: ") {
Some("True") => cxx11_abi = Some("1".to_owned()),
Some("False") => cxx11_abi = Some("0".to_owned()),
_ => {}
}
if let Some(path) = line.strip_prefix("LIBTORCH_INCLUDE: ") {
libtorch_include_dirs.push(PathBuf::from(path))
}
if let Some(path) = line.strip_prefix("LIBTORCH_LIB: ") {
libtorch_lib_dir = Some(PathBuf::from(path))
}
}
match cxx11_abi {
Some(cxx11_abi) => cxx11_abi,
None => panic!("no cxx11 abi returned by python {output:?}"),
}
} else {
let libtorch = Self::prepare_libtorch_dir(os)?;
let includes = env_var_rerun("LIBTORCH_INCLUDE")
.map(PathBuf::from)
.unwrap_or_else(|_| libtorch.clone());
let lib = env_var_rerun("LIBTORCH_LIB")
.map(PathBuf::from)
.unwrap_or_else(|_| libtorch.clone());
libtorch_include_dirs.push(includes.join("include"));
libtorch_include_dirs.push(includes.join("include/torch/csrc/api/include"));
if lib.ends_with("lib") {
// DEP_TCH_LIBTORCH_LIB might already point to /lib
libtorch_lib_dir = Some(lib);
} else {
libtorch_lib_dir = Some(lib.join("lib"));
}
env_var_rerun("LIBTORCH_CXX11_ABI").unwrap_or_else(|_| "1".to_owned())
};
let libtorch_lib_dir = libtorch_lib_dir?;
Some(Self {
os,
cxx11_abi,
libtorch_include_dirs,
libtorch_lib_dir,
})
}
fn check_system_location(os: Os) -> Option<PathBuf> {
match os {
Os::Linux => Path::new("/usr/lib/libtorch.so")
.exists()
.then(|| PathBuf::from("/usr")),
_ => None,
}
}
fn prepare_libtorch_dir(os: Os) -> Option<PathBuf> {
if let Ok(libtorch) = env_var_rerun("DEP_TCH_LIBTORCH_LIB") {
Some(PathBuf::from(libtorch))
} else if let Ok(libtorch) = env_var_rerun("LIBTORCH") {
Some(PathBuf::from(libtorch))
} else if let Some(pathbuf) = Self::check_system_location(os) {
Some(pathbuf)
} else {
check_out_dir()
}
}
fn make(&self, use_cuda: bool, use_hip: bool) {
let cuda_dependency = if use_cuda || use_hip {
"src/cuda_hack/dummy_cuda_dependency.cpp"
} else {
"src/cuda_hack/fake_cuda_dependency.cpp"
};
println!("cargo:rerun-if-changed={cuda_dependency}");
match self.os {
Os::Linux | Os::Macos => {
cc::Build::new()
.cpp(true)
.pic(true)
.warnings(false)
.includes(&self.libtorch_include_dirs)
.flag(format!("-Wl,-rpath={}", self.libtorch_lib_dir.display()))
.flag("-std=c++17")
.flag(format!("-D_GLIBCXX_USE_CXX11_ABI={}", self.cxx11_abi))
.files(&[cuda_dependency])
.compile("burn-tch");
}
Os::Windows => {
cc::Build::new()
.cpp(true)
.pic(true)
.warnings(false)
.includes(&self.libtorch_include_dirs)
.flag("/std:c++17")
.files(&[cuda_dependency])
.compile("burn-tch");
}
};
}
fn make_cpu() {
let cuda_dependency = "src/cuda_hack/fake_cuda_dependency.cpp";
println!("cargo:rerun-if-changed={cuda_dependency}");
let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS");
match os.as_str() {
"windows" => {
cc::Build::new()
.cpp(true)
.pic(true)
.warnings(false)
.flag("/std:c++17")
.files(&[cuda_dependency])
.compile("burn-tch");
}
_ => {
cc::Build::new()
.cpp(true)
.pic(true)
.warnings(false)
.flag("-std=c++17")
.files(&[cuda_dependency])
.compile("tch");
}
};
}
}
fn check_out_dir() -> Option<PathBuf> {
let out_dir = env_var_rerun("OUT_DIR").ok()?;
let libtorch_dir = PathBuf::from(out_dir).join("libtorch");
libtorch_dir.exists().then_some(libtorch_dir)
}
/// Build-script entry point.
///
/// Compiles the CUDA dependency shim (against the detected libtorch when one is
/// found, otherwise the CPU-only "fake" variant). It then writes
/// `OUT_DIR/tch_gpu_check.rs` with one of two contents:
///
/// - a no-op `()` expression when CUDA or HIP libtorch libraries were found;
/// - otherwise, a `panic!` expression explaining why no GPU build is available.
///
/// # Panics
///
/// Panics if `OUT_DIR` is unset or the check file cannot be written.
fn main() {
    let system_info = SystemInfo::new();
    let out_dir = env_var_rerun("OUT_DIR").expect("Failed to get out dir");

    let gpu_found = match &system_info {
        Some(info) => {
            let lib_dir = &info.libtorch_lib_dir;
            // Each GPU backend is detected by its Linux (`.so`) or Windows (`.dll`) name.
            let has_any =
                |names: &[&str]| names.iter().any(|name| lib_dir.join(name).exists());
            let use_cuda = has_any(&["libtorch_cuda.so", "torch_cuda.dll"]);
            let use_hip = has_any(&["libtorch_hip.so", "torch_hip.dll"]);
            info.make(use_cuda, use_hip);
            use_cuda || use_hip
        }
        None => {
            SystemInfo::make_cpu();
            false
        }
    };

    let contents = if gpu_found {
        "#[allow(clippy::no_effect)]\n()".to_owned()
    } else {
        let message = if system_info.is_none() {
            r#"Could not find libtorch dir.
If you are trying to use the automatically downloaded version, the path is not directly available on Windows. Instead, try setting the `LIBTORCH` environment variable for the manual download instructions.
If the library has already been downloaded in the torch-sys OUT_DIR, you can point the variable to this path (or move the downloaded lib and point to it)."#
        } else {
            "No libtorch_cuda or libtorch_hip found. Download the GPU version of libtorch to use a GPU device"
        };
        // `{message:?}` emits a correctly escaped Rust string literal (quotes,
        // backslashes, newlines), and the `"{}"` format string keeps any braces in
        // the message from being parsed as placeholders by `panic!`.
        format!("panic!(\"{{}}\", {message:?})")
    };

    let check_file = PathBuf::from(out_dir).join("tch_gpu_check.rs");
    fs::write(&check_file, contents)
        .unwrap_or_else(|err| panic!("failed to write {}: {err}", check_file.display()));
}