mirror of https://github.com/tracel-ai/burn.git
Enable candle cuda (#887)
This commit is contained in:
parent
80fe58c604
commit
86db5dc392
|
@ -12,7 +12,8 @@ version = "0.10.0"
|
|||
[features]
|
||||
default = ["std"]
|
||||
std = []
|
||||
candle = ["burn/candle"]
|
||||
candle-cpu = ["burn/candle"]
|
||||
candle-cuda = ["burn/candle-cuda"]
|
||||
ndarray = ["burn/ndarray"]
|
||||
ndarray-blas-accelerate = ["burn/ndarray-blas-accelerate"]
|
||||
ndarray-blas-netlib = ["burn/ndarray-blas-netlib"]
|
||||
|
|
|
@ -41,7 +41,7 @@ macro_rules! bench_on_backend {
|
|||
bench::<NdArrayBackend>(&device);
|
||||
}
|
||||
|
||||
#[cfg(feature = "candle")]
|
||||
#[cfg(feature = "candle-cpu")]
|
||||
{
|
||||
use burn::backend::candle::CandleDevice;
|
||||
use burn::backend::CandleBackend;
|
||||
|
@ -49,5 +49,14 @@ macro_rules! bench_on_backend {
|
|||
let device = CandleDevice::Cpu;
|
||||
bench::<CandleBackend>(&device);
|
||||
}
|
||||
|
||||
#[cfg(feature = "candle-cuda")]
|
||||
{
|
||||
use burn::backend::candle::CandleDevice;
|
||||
use burn::backend::CandleBackend;
|
||||
|
||||
let device = CandleDevice::Cuda(0);
|
||||
bench::<CandleBackend>(&device);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ repository = "https://github.com/burn-rs/burn/tree/main/burn-candle"
|
|||
version = "0.10.0"
|
||||
|
||||
[features]
|
||||
cuda = ["candle-core/cuda"]
|
||||
|
||||
[dependencies]
|
||||
derive-new = { workspace = true }
|
||||
|
|
|
@ -53,6 +53,7 @@ wgpu = ["burn-wgpu/default"]
|
|||
tch = ["burn-tch"]
|
||||
|
||||
candle = ["burn-candle"]
|
||||
candle-cuda = ["candle", "burn-candle/cuda"]
|
||||
|
||||
# Serialization formats
|
||||
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
|
||||
|
|
|
@ -46,6 +46,7 @@ ndarray-blas-openblas-system = ["burn-core/ndarray-blas-openblas-system"]
|
|||
wgpu = ["burn-core/wgpu"]
|
||||
tch = ["burn-core/tch"]
|
||||
candle = ["burn-core/candle"]
|
||||
candle-cuda = ["burn-core/candle-cuda"]
|
||||
|
||||
# Experimental
|
||||
experimental-named-tensor = ["burn-core/experimental-named-tensor"]
|
||||
|
|
Loading…
Reference in New Issue