diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml index 2aeed8920..a7b532bd1 100644 --- a/backend-comparison/Cargo.toml +++ b/backend-comparison/Cargo.toml @@ -14,6 +14,7 @@ default = ["std"] std = [] candle-cpu = ["burn/candle"] candle-cuda = ["burn/candle-cuda"] +candle-accelerate = ["burn/candle-accelerate"] ndarray = ["burn/ndarray"] ndarray-blas-accelerate = ["burn/ndarray-blas-accelerate"] ndarray-blas-netlib = ["burn/ndarray-blas-netlib"] diff --git a/burn-candle/Cargo.toml b/burn-candle/Cargo.toml index bad485f5f..8b7e3f251 100644 --- a/burn-candle/Cargo.toml +++ b/burn-candle/Cargo.toml @@ -12,6 +12,7 @@ version = "0.11.0" [features] cuda = ["candle-core/cuda"] +accelerate = ["candle-core/accelerate"] [dependencies] derive-new = { workspace = true } diff --git a/burn-candle/README.md b/burn-candle/README.md index e2e559726..79677b644 100644 --- a/burn-candle/README.md +++ b/burn-candle/README.md @@ -4,4 +4,11 @@ This crate provides a backend for [Burn](https://github.com/burn-rs/burn) based It is still in alpha stage, not all operations are supported. It is usable for some use cases, like for inference. -It can be used with CPU or CUDA. \ No newline at end of file +It can be used with CPU or CUDA. On macOS computations can be accelerated by using the Accelerate framework. + +## Feature Flags + +The following features are supported: + +- `cuda` - Cuda GPU device (NVIDIA only) +- `accelerate` - Accelerate framework (macOS only) diff --git a/burn-candle/src/ops/tensor.rs b/burn-candle/src/ops/tensor.rs index e3ab04b94..3880f622b 100644 --- a/burn-candle/src/ops/tensor.rs +++ b/burn-candle/src/ops/tensor.rs @@ -137,7 +137,8 @@ impl TensorOps for Candle, rhs: FloatTensor, ) -> FloatTensor { - CandleTensor::new(lhs.tensor.broadcast_matmul(&rhs.tensor).unwrap()) + let rhs_contiguous = rhs.tensor.contiguous().unwrap(); + CandleTensor::new(lhs.tensor.broadcast_matmul(&rhs_contiguous).unwrap()) } fn swap_dims( diff --git a/burn-core/Cargo.toml b/burn-core/Cargo.toml index b692a2a18..14aa52e4b 100644 --- a/burn-core/Cargo.toml +++ b/burn-core/Cargo.toml @@ -55,6 +55,7 @@ tch = ["burn-tch"] candle = ["burn-candle"] candle-cuda = ["candle", "burn-candle/cuda"] +candle-accelerate = ["candle", "burn-candle/accelerate"] # Serialization formats experimental-named-tensor = ["burn-tensor/experimental-named-tensor"] diff --git a/burn/Cargo.toml b/burn/Cargo.toml index 84ace64a7..99d0e218e 100644 --- a/burn/Cargo.toml +++ b/burn/Cargo.toml @@ -51,6 +51,7 @@ wgpu = ["burn-core/wgpu"] tch = ["burn-core/tch"] candle = ["burn-core/candle"] candle-cuda = ["burn-core/candle-cuda"] +candle-accelerate = ["burn-core/candle-accelerate"] # Experimental experimental-named-tensor = ["burn-core/experimental-named-tensor"] diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index 6cc410519..1c95467ba 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -249,6 +249,13 @@ fn burn_dataset_features_std() { cargo_doc(["-p", "burn-dataset", "--all-features"].into()); } +// Test burn-candle with accelerate (macOS only) +// Leverages the macOS Accelerate framework: https://developer.apple.com/documentation/accelerate +#[cfg(target_os = "macos")] +fn burn_candle_accelerate() { + cargo_test(["-p", "burn-candle", "--features", "accelerate"].into()); +} + fn std_checks() { // Set RUSTDOCFLAGS environment variable to treat warnings as errors // for the documentation build @@ -284,6 +291,10 @@ fn std_checks() { // Test each workspace cargo_test(["--workspace"].into()); + // Test burn-candle with accelerate (macOS only) + #[cfg(target_os = "macos")] + burn_candle_accelerate(); + // Test burn-dataset features burn_dataset_features_std();