mirror of https://github.com/tracel-ai/burn.git
make candle available (#886)
This commit is contained in:
parent
07c0cf146d
commit
e4d9d67526
|
@ -6,6 +6,7 @@ resolver = "2"
|
||||||
members = [
|
members = [
|
||||||
"burn",
|
"burn",
|
||||||
"burn-autodiff",
|
"burn-autodiff",
|
||||||
|
"burn-candle",
|
||||||
"burn-common",
|
"burn-common",
|
||||||
"burn-compute",
|
"burn-compute",
|
||||||
"burn-core",
|
"burn-core",
|
||||||
|
|
|
@ -26,7 +26,7 @@ simplifying the process of experimenting, training, and deploying models.
|
||||||
[`no_std`](#support-for-no_std) compatibility, ensuring universal platform adaptability 👌
|
[`no_std`](#support-for-no_std) compatibility, ensuring universal platform adaptability 👌
|
||||||
- [WebGPU](https://github.com/burn-rs/burn/tree/main/burn-wgpu) backend, offering cross-platform,
|
- [WebGPU](https://github.com/burn-rs/burn/tree/main/burn-wgpu) backend, offering cross-platform,
|
||||||
browser-inclusive, GPU-based computations 🌐
|
browser-inclusive, GPU-based computations 🌐
|
||||||
- [Candle](https://github.com/burn-rs/burn/tree/main/burn-candle) backend (alpha) 🕯️
|
- [Candle](https://github.com/burn-rs/burn/tree/main/burn-candle) backend 🕯️
|
||||||
- [Autodiff](https://github.com/burn-rs/burn/tree/main/burn-autodiff) backend that enables
|
- [Autodiff](https://github.com/burn-rs/burn/tree/main/burn-autodiff) backend that enables
|
||||||
differentiability across all backends 🌟
|
differentiability across all backends 🌟
|
||||||
- [Dataset](https://github.com/burn-rs/burn/tree/main/burn-dataset) crate containing a diverse range
|
- [Dataset](https://github.com/burn-rs/burn/tree/main/burn-dataset) crate containing a diverse range
|
||||||
|
|
|
@ -12,6 +12,7 @@ version = "0.10.0"
|
||||||
[features]
|
[features]
|
||||||
default = ["std"]
|
default = ["std"]
|
||||||
std = []
|
std = []
|
||||||
|
candle = ["burn/candle"]
|
||||||
ndarray = ["burn/ndarray"]
|
ndarray = ["burn/ndarray"]
|
||||||
ndarray-blas-accelerate = ["burn/ndarray-blas-accelerate"]
|
ndarray-blas-accelerate = ["burn/ndarray-blas-accelerate"]
|
||||||
ndarray-blas-netlib = ["burn/ndarray-blas-netlib"]
|
ndarray-blas-netlib = ["burn/ndarray-blas-netlib"]
|
||||||
|
|
|
@ -40,5 +40,14 @@ macro_rules! bench_on_backend {
|
||||||
let device = NdArrayDevice::Cpu;
|
let device = NdArrayDevice::Cpu;
|
||||||
bench::<NdArrayBackend>(&device);
|
bench::<NdArrayBackend>(&device);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "candle")]
|
||||||
|
{
|
||||||
|
use burn::backend::candle::CandleDevice;
|
||||||
|
use burn::backend::CandleBackend;
|
||||||
|
|
||||||
|
let device = CandleDevice::Cpu;
|
||||||
|
bench::<CandleBackend>(&device);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,6 +52,8 @@ wgpu = ["burn-wgpu/default"]
|
||||||
|
|
||||||
tch = ["burn-tch"]
|
tch = ["burn-tch"]
|
||||||
|
|
||||||
|
candle = ["burn-candle"]
|
||||||
|
|
||||||
# Serialization formats
|
# Serialization formats
|
||||||
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
|
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
|
||||||
|
|
||||||
|
@ -72,6 +74,7 @@ burn-ndarray = { path = "../burn-ndarray", version = "0.10.0", optional = true,
|
||||||
burn-autodiff = { path = "../burn-autodiff", version = "0.10.0", optional = true }
|
burn-autodiff = { path = "../burn-autodiff", version = "0.10.0", optional = true }
|
||||||
burn-wgpu = { path = "../burn-wgpu", version = "0.10.0", optional = true }
|
burn-wgpu = { path = "../burn-wgpu", version = "0.10.0", optional = true }
|
||||||
burn-tch = { path = "../burn-tch", version = "0.10.0", optional = true }
|
burn-tch = { path = "../burn-tch", version = "0.10.0", optional = true }
|
||||||
|
burn-candle = { path = "../burn-candle", version = "0.10.0", optional = true }
|
||||||
|
|
||||||
derive-new = { workspace = true }
|
derive-new = { workspace = true }
|
||||||
libm = { workspace = true }
|
libm = { workspace = true }
|
||||||
|
|
|
@ -23,6 +23,18 @@ pub type WgpuBackend<G = wgpu::AutoGraphicsApi, F = f32, I = i32> = wgpu::WgpuBa
|
||||||
pub type WgpuAutodiffBackend<G = wgpu::AutoGraphicsApi, F = f32, I = i32> =
|
pub type WgpuAutodiffBackend<G = wgpu::AutoGraphicsApi, F = f32, I = i32> =
|
||||||
crate::autodiff::ADBackendDecorator<WgpuBackend<G, F, I>>;
|
crate::autodiff::ADBackendDecorator<WgpuBackend<G, F, I>>;
|
||||||
|
|
||||||
|
#[cfg(feature = "candle")]
|
||||||
|
/// Candle module.
|
||||||
|
pub use burn_candle as candle;
|
||||||
|
|
||||||
|
#[cfg(feature = "candle")]
|
||||||
|
/// A CandleBackend with a default type of f32/i64.
|
||||||
|
pub type CandleBackend = candle::CandleBackend<f32, i64>;
|
||||||
|
|
||||||
|
#[cfg(all(feature = "candle", feature = "autodiff"))]
|
||||||
|
/// A CandleBackend with autodiffing enabled.
|
||||||
|
pub type CandleAutodiffBackend = crate::autodiff::ADBackendDecorator<CandleBackend>;
|
||||||
|
|
||||||
#[cfg(feature = "tch")]
|
#[cfg(feature = "tch")]
|
||||||
/// Tch module.
|
/// Tch module.
|
||||||
pub use burn_tch as tch;
|
pub use burn_tch as tch;
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
[![Current Crates.io Version](https://img.shields.io/crates/v/burn-wgpu.svg)](https://crates.io/crates/burn-wgpu)
|
[![Current Crates.io Version](https://img.shields.io/crates/v/burn-wgpu.svg)](https://crates.io/crates/burn-wgpu)
|
||||||
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/burn-rs/burn-wgpu/blob/master/README.md)
|
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/burn-rs/burn-wgpu/blob/master/README.md)
|
||||||
|
|
||||||
This crate provides a WGPU backend for [Burn](https://github.com/burn-rs/burn) utilizing the
|
This crate provides a WGPU backend for [Burn](https://github.com/burn-rs/burn) using the
|
||||||
[wgpu](https://github.com/gfx-rs/wgpu).
|
[wgpu](https://github.com/gfx-rs/wgpu).
|
||||||
|
|
||||||
The backend supports Vulkan, Metal, DirectX11/12, OpenGL, WebGPU.
|
The backend supports Vulkan, Metal, DirectX11/12, OpenGL, WebGPU.
|
||||||
|
|
|
@ -45,6 +45,7 @@ ndarray-blas-openblas-system = ["burn-core/ndarray-blas-openblas-system"]
|
||||||
|
|
||||||
wgpu = ["burn-core/wgpu"]
|
wgpu = ["burn-core/wgpu"]
|
||||||
tch = ["burn-core/tch"]
|
tch = ["burn-core/tch"]
|
||||||
|
candle = ["burn-core/candle"]
|
||||||
|
|
||||||
# Experimental
|
# Experimental
|
||||||
experimental-named-tensor = ["burn-core/experimental-named-tensor"]
|
experimental-named-tensor = ["burn-core/experimental-named-tensor"]
|
||||||
|
|
Loading…
Reference in New Issue