make candle available (#886)

This commit is contained in:
Louis Fortier-Dubois 2023-10-23 10:00:39 -04:00 committed by GitHub
parent 07c0cf146d
commit e4d9d67526
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 34 additions and 7 deletions

View File

@ -6,6 +6,7 @@ resolver = "2"
members = [ members = [
"burn", "burn",
"burn-autodiff", "burn-autodiff",
"burn-candle",
"burn-common", "burn-common",
"burn-compute", "burn-compute",
"burn-core", "burn-core",

View File

@ -26,7 +26,7 @@ simplifying the process of experimenting, training, and deploying models.
[`no_std`](#support-for-no_std) compatibility, ensuring universal platform adaptability 👌 [`no_std`](#support-for-no_std) compatibility, ensuring universal platform adaptability 👌
- [WebGPU](https://github.com/burn-rs/burn/tree/main/burn-wgpu) backend, offering cross-platform, - [WebGPU](https://github.com/burn-rs/burn/tree/main/burn-wgpu) backend, offering cross-platform,
browser-inclusive, GPU-based computations 🌐 browser-inclusive, GPU-based computations 🌐
- [Candle](https://github.com/burn-rs/burn/tree/main/burn-candle) backend (alpha) 🕯️ - [Candle](https://github.com/burn-rs/burn/tree/main/burn-candle) backend 🕯️
- [Autodiff](https://github.com/burn-rs/burn/tree/main/burn-autodiff) backend that enables - [Autodiff](https://github.com/burn-rs/burn/tree/main/burn-autodiff) backend that enables
differentiability across all backends 🌟 differentiability across all backends 🌟
- [Dataset](https://github.com/burn-rs/burn/tree/main/burn-dataset) crate containing a diverse range - [Dataset](https://github.com/burn-rs/burn/tree/main/burn-dataset) crate containing a diverse range

View File

@ -12,6 +12,7 @@ version = "0.10.0"
[features] [features]
default = ["std"] default = ["std"]
std = [] std = []
candle = ["burn/candle"]
ndarray = ["burn/ndarray"] ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray-blas-accelerate"] ndarray-blas-accelerate = ["burn/ndarray-blas-accelerate"]
ndarray-blas-netlib = ["burn/ndarray-blas-netlib"] ndarray-blas-netlib = ["burn/ndarray-blas-netlib"]

View File

@ -40,5 +40,14 @@ macro_rules! bench_on_backend {
let device = NdArrayDevice::Cpu; let device = NdArrayDevice::Cpu;
bench::<NdArrayBackend>(&device); bench::<NdArrayBackend>(&device);
} }
#[cfg(feature = "candle")]
{
use burn::backend::candle::CandleDevice;
use burn::backend::CandleBackend;
let device = CandleDevice::Cpu;
bench::<CandleBackend>(&device);
}
}; };
} }

View File

@ -52,6 +52,8 @@ wgpu = ["burn-wgpu/default"]
tch = ["burn-tch"] tch = ["burn-tch"]
candle = ["burn-candle"]
# Serialization formats # Serialization formats
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"] experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
@ -72,6 +74,7 @@ burn-ndarray = { path = "../burn-ndarray", version = "0.10.0", optional = true,
burn-autodiff = { path = "../burn-autodiff", version = "0.10.0", optional = true } burn-autodiff = { path = "../burn-autodiff", version = "0.10.0", optional = true }
burn-wgpu = { path = "../burn-wgpu", version = "0.10.0", optional = true } burn-wgpu = { path = "../burn-wgpu", version = "0.10.0", optional = true }
burn-tch = { path = "../burn-tch", version = "0.10.0", optional = true } burn-tch = { path = "../burn-tch", version = "0.10.0", optional = true }
burn-candle = { path = "../burn-candle", version = "0.10.0", optional = true }
derive-new = { workspace = true } derive-new = { workspace = true }
libm = { workspace = true } libm = { workspace = true }

View File

@ -23,6 +23,18 @@ pub type WgpuBackend<G = wgpu::AutoGraphicsApi, F = f32, I = i32> = wgpu::WgpuBa
pub type WgpuAutodiffBackend<G = wgpu::AutoGraphicsApi, F = f32, I = i32> = pub type WgpuAutodiffBackend<G = wgpu::AutoGraphicsApi, F = f32, I = i32> =
crate::autodiff::ADBackendDecorator<WgpuBackend<G, F, I>>; crate::autodiff::ADBackendDecorator<WgpuBackend<G, F, I>>;
#[cfg(feature = "candle")]
/// Candle module.
pub use burn_candle as candle;
#[cfg(feature = "candle")]
/// A CandleBackend with a default type of f32/i64.
pub type CandleBackend = candle::CandleBackend<f32, i64>;
#[cfg(all(feature = "candle", feature = "autodiff"))]
/// A CandleBackend with autodiffing enabled.
pub type CandleAutodiffBackend = crate::autodiff::ADBackendDecorator<CandleBackend>;
#[cfg(feature = "tch")] #[cfg(feature = "tch")]
/// Tch module. /// Tch module.
pub use burn_tch as tch; pub use burn_tch as tch;

View File

@ -5,7 +5,7 @@
[![Current Crates.io Version](https://img.shields.io/crates/v/burn-wgpu.svg)](https://crates.io/crates/burn-wgpu) [![Current Crates.io Version](https://img.shields.io/crates/v/burn-wgpu.svg)](https://crates.io/crates/burn-wgpu)
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/burn-rs/burn-wgpu/blob/master/README.md) [![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/burn-rs/burn-wgpu/blob/master/README.md)
This crate provides a WGPU backend for [Burn](https://github.com/burn-rs/burn) utilizing the This crate provides a WGPU backend for [Burn](https://github.com/burn-rs/burn) using the
[wgpu](https://github.com/gfx-rs/wgpu). [wgpu](https://github.com/gfx-rs/wgpu).
The backend supports Vulkan, Metal, DirectX11/12, OpenGL, WebGPU. The backend supports Vulkan, Metal, DirectX11/12, OpenGL, WebGPU.

View File

@ -45,6 +45,7 @@ ndarray-blas-openblas-system = ["burn-core/ndarray-blas-openblas-system"]
wgpu = ["burn-core/wgpu"] wgpu = ["burn-core/wgpu"]
tch = ["burn-core/tch"] tch = ["burn-core/tch"]
candle = ["burn-core/candle"]
# Experimental # Experimental
experimental-named-tensor = ["burn-core/experimental-named-tensor"] experimental-named-tensor = ["burn-core/experimental-named-tensor"]