Enable candle cuda (#887)

This commit is contained in:
Nathaniel Simard 2023-10-23 11:00:54 -04:00 committed by GitHub
parent 80fe58c604
commit 86db5dc392
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 15 additions and 2 deletions

View File

@ -12,7 +12,8 @@ version = "0.10.0"
[features]
default = ["std"]
std = []
candle = ["burn/candle"]
candle-cpu = ["burn/candle"]
candle-cuda = ["burn/candle-cuda"]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray-blas-accelerate"]
ndarray-blas-netlib = ["burn/ndarray-blas-netlib"]

View File

@ -41,7 +41,7 @@ macro_rules! bench_on_backend {
bench::<NdArrayBackend>(&device);
}
#[cfg(feature = "candle")]
#[cfg(feature = "candle-cpu")]
{
use burn::backend::candle::CandleDevice;
use burn::backend::CandleBackend;
@ -49,5 +49,14 @@ macro_rules! bench_on_backend {
let device = CandleDevice::Cpu;
bench::<CandleBackend>(&device);
}
#[cfg(feature = "candle-cuda")]
{
use burn::backend::candle::CandleDevice;
use burn::backend::CandleBackend;
let device = CandleDevice::Cuda(0);
bench::<CandleBackend>(&device);
}
};
}

View File

@ -11,6 +11,7 @@ repository = "https://github.com/burn-rs/burn/tree/main/burn-candle"
version = "0.10.0"
[features]
cuda = ["candle-core/cuda"]
[dependencies]
derive-new = { workspace = true }

View File

@ -53,6 +53,7 @@ wgpu = ["burn-wgpu/default"]
tch = ["burn-tch"]
candle = ["burn-candle"]
candle-cuda = ["candle", "burn-candle/cuda"]
# Serialization formats
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]

View File

@ -46,6 +46,7 @@ ndarray-blas-openblas-system = ["burn-core/ndarray-blas-openblas-system"]
wgpu = ["burn-core/wgpu"]
tch = ["burn-core/tch"]
candle = ["burn-core/candle"]
candle-cuda = ["burn-core/candle-cuda"]
# Experimental
experimental-named-tensor = ["burn-core/experimental-named-tensor"]