From 86db5dc392b21a7721a69d494578be125fdb11a8 Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Mon, 23 Oct 2023 11:00:54 -0400 Subject: [PATCH] Enable candle cuda (#887) --- backend-comparison/Cargo.toml | 3 ++- backend-comparison/src/lib.rs | 11 ++++++++++- burn-candle/Cargo.toml | 1 + burn-core/Cargo.toml | 1 + burn/Cargo.toml | 1 + 5 files changed, 15 insertions(+), 2 deletions(-) diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml index 8e2928fb0..e2557817c 100644 --- a/backend-comparison/Cargo.toml +++ b/backend-comparison/Cargo.toml @@ -12,7 +12,8 @@ version = "0.10.0" [features] default = ["std"] std = [] -candle = ["burn/candle"] +candle-cpu = ["burn/candle"] +candle-cuda = ["burn/candle-cuda"] ndarray = ["burn/ndarray"] ndarray-blas-accelerate = ["burn/ndarray-blas-accelerate"] ndarray-blas-netlib = ["burn/ndarray-blas-netlib"] diff --git a/backend-comparison/src/lib.rs b/backend-comparison/src/lib.rs index 064a87cbe..efd606722 100644 --- a/backend-comparison/src/lib.rs +++ b/backend-comparison/src/lib.rs @@ -41,7 +41,7 @@ macro_rules! bench_on_backend { bench::(&device); } - #[cfg(feature = "candle")] + #[cfg(feature = "candle-cpu")] { use burn::backend::candle::CandleDevice; use burn::backend::CandleBackend; @@ -49,5 +49,14 @@ macro_rules! bench_on_backend { let device = CandleDevice::Cpu; bench::(&device); } + + #[cfg(feature = "candle-cuda")] + { + use burn::backend::candle::CandleDevice; + use burn::backend::CandleBackend; + + let device = CandleDevice::Cuda(0); + bench::(&device); + } }; } diff --git a/burn-candle/Cargo.toml b/burn-candle/Cargo.toml index 44a834dbf..f51a41f69 100644 --- a/burn-candle/Cargo.toml +++ b/burn-candle/Cargo.toml @@ -11,6 +11,7 @@ repository = "https://github.com/burn-rs/burn/tree/main/burn-candle" version = "0.10.0" [features] +cuda = ["candle-core/cuda"] [dependencies] derive-new = { workspace = true } diff --git a/burn-core/Cargo.toml b/burn-core/Cargo.toml index 4c1abca92..dd043b7e8 100644 --- a/burn-core/Cargo.toml +++ b/burn-core/Cargo.toml @@ -53,6 +53,7 @@ wgpu = ["burn-wgpu/default"] tch = ["burn-tch"] candle = ["burn-candle"] +candle-cuda = ["candle", "burn-candle/cuda"] # Serialization formats experimental-named-tensor = ["burn-tensor/experimental-named-tensor"] diff --git a/burn/Cargo.toml b/burn/Cargo.toml index 04dd15f6f..663d4785f 100644 --- a/burn/Cargo.toml +++ b/burn/Cargo.toml @@ -46,6 +46,7 @@ ndarray-blas-openblas-system = ["burn-core/ndarray-blas-openblas-system"] wgpu = ["burn-core/wgpu"] tch = ["burn-core/tch"] candle = ["burn-core/candle"] +candle-cuda = ["burn-core/candle-cuda"] # Experimental experimental-named-tensor = ["burn-core/experimental-named-tensor"]