mirror of https://github.com/tracel-ai/burn.git
Burn WGPU Backend
This crate provides a WGPU backend for Burn built on the `wgpu` crate.
The backend supports Vulkan, Metal, DirectX 11/12, OpenGL, and WebGPU.
Usage Example
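```rust
// Compile this module only when the `wgpu` Cargo feature is enabled.
#[cfg(feature = "wgpu")]
mod wgpu {
    use burn_autodiff::Autodiff;
    use burn_wgpu::{Wgpu, WgpuDevice};
    use mnist::training;

    pub fn run() {
        // Use the default WGPU device.
        let device = WgpuDevice::default();
        // Train with autodiff on top of the WGPU backend (f32 floats, i32 integers).
        training::run::<Autodiff<Wgpu<f32, i32>>>(device);
    }
}
```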
Configuration
You can set the `BURN_WGPU_MAX_TASKS` environment variable to a positive integer that determines how many compute tasks are batched together before being submitted to the graphics API.
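As a rough illustration of what this setting controls, the sketch below parses a positive integer from `BURN_WGPU_MAX_TASKS` and groups queued tasks into submissions of that size. It is a standalone mock, not burn-wgpu's implementation: the fallback value of 16 and the helpers `parse_max_tasks`, `max_tasks_from_env`, and `batch_tasks` are hypothetical.

```rust
use std::env;

/// Hypothetical fallback used when the variable is unset or invalid.
/// The real default inside burn-wgpu may differ.
const FALLBACK_MAX_TASKS: usize = 16;

/// Accepts only a positive integer; zero, negative numbers,
/// and non-numeric strings fall back to FALLBACK_MAX_TASKS.
fn parse_max_tasks(raw: Option<&str>) -> usize {
    raw.and_then(|s| s.trim().parse::<usize>().ok())
        .filter(|&n| n > 0)
        .unwrap_or(FALLBACK_MAX_TASKS)
}

/// Reads `BURN_WGPU_MAX_TASKS` from the process environment.
fn max_tasks_from_env() -> usize {
    let raw = env::var("BURN_WGPU_MAX_TASKS").ok();
    parse_max_tasks(raw.as_deref())
}

/// Groups queued tasks into submissions of at most `max_tasks` each,
/// so the graphics API is called once per batch instead of once per task.
fn batch_tasks<T: Clone>(tasks: &[T], max_tasks: usize) -> Vec<Vec<T>> {
    assert!(max_tasks > 0, "batch size must be positive");
    tasks.chunks(max_tasks).map(|chunk| chunk.to_vec()).collect()
}

fn main() {
    let max_tasks = max_tasks_from_env();
    let tasks: Vec<u32> = (0..40).collect();
    let batches = batch_tasks(&tasks, max_tasks);
    println!("max_tasks = {max_tasks}, submissions = {}", batches.len());
}
```

With the variable unset, this sketch splits 40 tasks into 3 submissions (16 + 16 + 8); with `BURN_WGPU_MAX_TASKS=32` it makes 2 (32 + 8). Larger batches generally mean fewer submission calls, while smaller ones may let the GPU start on work sooner.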
Alternative SPIR-V backend
When targeting Vulkan, you can enable the `spirv` feature flag to use the SPIR-V compiler backend, which can perform significantly better than WGSL. This is especially true for matrix multiplication, where SPIR-V can make use of Tensor Cores and run at `f16` precision, neither of which WGSL currently supports.
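The compiler can also be selected in code by setting the corresponding generic parameter to either `SpirV` or `Wgsl`. Because it is a type parameter, the choice is fixed for each backend type when the program is compiled.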
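To show how a compiler choice can live in a generic parameter, here is a self-contained mock. The `ShaderCompiler` trait, `MockBackend` type, and `DefaultBackend` alias are hypothetical stand-ins, not burn-wgpu's actual types; only the marker names `Wgsl` and `SpirV` and the `spirv` feature name come from the text above. The `f16` flag simply encodes the README's statement that SPIR-V can run at `f16` precision while WGSL cannot.

```rust
use std::marker::PhantomData;

/// Hypothetical stand-in for a shader compiler; not burn's actual trait.
trait ShaderCompiler {
    const NAME: &'static str;
    /// Whether this compiler can run kernels at f16 precision (per the README).
    const SUPPORTS_F16: bool;
}

/// Marker type mirroring the README's `Wgsl` option.
struct Wgsl;
/// Marker type mirroring the README's `SpirV` option.
struct SpirV;

impl ShaderCompiler for Wgsl {
    const NAME: &'static str = "wgsl";
    const SUPPORTS_F16: bool = false;
}

impl ShaderCompiler for SpirV {
    const NAME: &'static str = "spirv";
    const SUPPORTS_F16: bool = true;
}

/// Hypothetical backend generic over its compiler: the choice is part of the type.
struct MockBackend<C: ShaderCompiler> {
    _compiler: PhantomData<C>,
}

impl<C: ShaderCompiler> MockBackend<C> {
    fn new() -> Self {
        Self { _compiler: PhantomData }
    }

    fn compiler_name(&self) -> &'static str {
        C::NAME
    }

    /// Picks a matmul precision based on what the compiler supports.
    fn matmul_precision(&self) -> &'static str {
        if C::SUPPORTS_F16 { "f16" } else { "f32" }
    }
}

// Default compiler chosen by a Cargo feature, mimicking the `spirv` flag.
#[cfg(feature = "spirv")]
type DefaultBackend = MockBackend<SpirV>;
#[cfg(not(feature = "spirv"))]
type DefaultBackend = MockBackend<Wgsl>;

fn main() {
    let default = DefaultBackend::new();
    let explicit = MockBackend::<SpirV>::new();
    println!("default: {}, explicit: {}", default.compiler_name(), explicit.compiler_name());
}
```

Because `C` is resolved by the type checker, code using `MockBackend<SpirV>` and code using `MockBackend<Wgsl>` are compiled separately, with no runtime dispatch on the compiler.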
Platform Support
Graphics API | CPU | GPU | Linux | macOS | Windows | Android | iOS | WASM |
---|---|---|---|---|---|---|---|---|
Metal | No | Yes | No | Yes | No | No | Yes | No |
Vulkan | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No |
OpenGL | No | Yes | Yes | Yes | Yes | Yes | Yes | No |
WebGPU | No | Yes | No | No | No | No | No | Yes |
DirectX 11/12 | No | Yes | No | No | Yes | No | No | No |
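The table can be read as a lookup from (graphics API, platform) to support. The sketch below transcribes it into hypothetical `supported`, `apis_for`, and `runs_on_cpu` helpers, for example to list which APIs to try on a given target. It is not an API of burn-wgpu or wgpu; it only encodes the rows above.

```rust
/// Graphics APIs from the platform-support table.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Api { Metal, Vulkan, OpenGl, WebGpu, Dx11Dx12 }

/// Target platforms from the table's columns.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Platform { Linux, MacOs, Windows, Android, Ios, Wasm }

const ALL_APIS: [Api; 5] = [Api::Metal, Api::Vulkan, Api::OpenGl, Api::WebGpu, Api::Dx11Dx12];

/// Returns true where the table's platform columns say "Yes".
fn supported(api: Api, platform: Platform) -> bool {
    use Api::*;
    use Platform::*;
    match (api, platform) {
        (Metal, MacOs | Ios) => true,
        (Vulkan | OpenGl, Linux | MacOs | Windows | Android | Ios) => true,
        (WebGpu, Wasm) => true,
        (Dx11Dx12, Windows) => true,
        _ => false,
    }
}

/// Per the CPU column, only Vulkan is listed as usable on a CPU device.
fn runs_on_cpu(api: Api) -> bool {
    matches!(api, Api::Vulkan)
}

/// Lists the APIs the table marks as supported on `platform`, in table order.
fn apis_for(platform: Platform) -> Vec<Api> {
    ALL_APIS.iter().copied().filter(|&api| supported(api, platform)).collect()
}

fn main() {
    println!("macOS: {:?}", apis_for(Platform::MacOs));
    println!("WASM: {:?}", apis_for(Platform::Wasm));
}
```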