Remove GraphicsAPI generic for WgpuRuntime (#1888)

Arthur Brussee 2024-06-17 14:04:25 +01:00 committed by GitHub
parent eead748e90
commit ac9f942a46
17 changed files with 78 additions and 63 deletions

View File

@@ -62,14 +62,9 @@ macro_rules! bench_on_backend {
#[cfg(feature = "wgpu")]
{
-use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
+use burn::backend::wgpu::{Wgpu, WgpuDevice};
-bench::<Wgpu<AutoGraphicsApi, f32, i32>>(
-    &WgpuDevice::default(),
-    feature_name,
-    url,
-    token,
-);
+bench::<Wgpu<f32, i32>>(&WgpuDevice::default(), feature_name, url, token);
}
#[cfg(feature = "tch-gpu")]
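For downstream code the migration is mechanical: drop the graphics-API parameter from the `Wgpu` alias. Code that previously used the generic to force a particular API must now initialize the runtime explicitly instead. A minimal sketch, mirroring the `init_sync` doc example this commit adds (`Vulkan` is just one of the available API markers):

```rust
use burn::backend::wgpu::{init_sync, Vulkan, WgpuDevice};

fn main() {
    let device = WgpuDevice::default();
    // Previously the API was chosen through the type, e.g. `Wgpu<Vulkan, f32, i32>`.
    // Now the runtime for a device is configured explicitly before first use.
    init_sync::<Vulkan>(&device, Default::default());
    // From here on, `Wgpu<f32, i32>` runs on Vulkan for this device.
}
```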

View File

@@ -198,7 +198,7 @@ the raw `WgpuBackend` type.
```rust, ignore
/// Implement our custom backend trait for the existing backend `WgpuBackend`.
-impl<G: GraphicsApi, F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime<G>, F, I> {
+impl<F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime, F, I> {
fn fused_matmul_add_relu<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,

View File

@@ -11,13 +11,13 @@ entrypoint of our program, namely the `main` function defined in `src/main.rs`.
#
use crate::{model::ModelConfig, training::TrainingConfig};
use burn::{
-backend::{wgpu::AutoGraphicsApi, Autodiff, Wgpu},
+backend::{Autodiff, Wgpu},
# data::dataset::Dataset,
optim::AdamConfig,
};
fn main() {
-type MyBackend = Wgpu<AutoGraphicsApi, f32, i32>;
+type MyBackend = Wgpu<f32, i32>;
type MyAutodiffBackend = Autodiff<MyBackend>;
let device = burn::backend::wgpu::WgpuDevice::default();
@@ -32,10 +32,9 @@ fn main() {
In this example, we use the `Wgpu` backend which is compatible with any operating system and will
use the GPU. For other options, see the Burn README. This backend type takes the graphics API, the
-float type and the int type as generic arguments that will be used during the training. By leaving
-the graphics API as `AutoGraphicsApi`, it should automatically use an API available on your machine.
-The autodiff backend is simply the same backend, wrapped within the `Autodiff` struct which imparts
-differentiability to any backend.
+float type and the int type as generic arguments that will be used during the training. The autodiff
+backend is simply the same backend, wrapped within the `Autodiff` struct which imparts differentiability
+to any backend.
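Both remaining generic parameters also have defaults (`f32` and `i32`, per the `Wgpu` alias definition later in this commit), so the backend type can be written even more compactly. A sketch:

```rust
use burn::backend::{Autodiff, Wgpu};

// `Wgpu` alone resolves to `Wgpu<f32, i32>` thanks to the default parameters.
type MyBackend = Wgpu;
type MyAutodiffBackend = Autodiff<MyBackend>;
```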
We call the `train` function defined earlier with a directory for artifacts, the configuration of
the model (the number of digit classes is 10 and the hidden dimension is 512), the optimizer

View File

@@ -56,13 +56,13 @@ Add the call to `infer` to the `main.rs` file after the `train` function call:
#
# use crate::{model::ModelConfig, training::TrainingConfig};
# use burn::{
-# backend::{wgpu::AutoGraphicsApi, Autodiff, Wgpu},
+# backend::{Autodiff, Wgpu},
# data::dataset::Dataset,
# optim::AdamConfig,
# };
#
# fn main() {
-# type MyBackend = Wgpu<AutoGraphicsApi, f32, i32>;
+# type MyBackend = Wgpu<f32, i32>;
# type MyAutodiffBackend = Autodiff<MyBackend>;
#
# let device = burn::backend::wgpu::WgpuDevice::default();

View File

@@ -16,12 +16,12 @@ The backend supports Vulkan, Metal, DirectX11/12, OpenGL, WebGPU.
#[cfg(feature = "wgpu")]
mod wgpu {
use burn_autodiff::Autodiff;
-use burn_wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
+use burn_wgpu::{Wgpu, WgpuDevice};
use mnist::training;
pub fn run() {
let device = WgpuDevice::default();
-training::run::<Autodiff<Wgpu<AutoGraphicsApi, f32, i32>>>(device);
+training::run::<Autodiff<Wgpu<f32, i32>>>(device);
}
}
```

View File

@@ -37,6 +37,21 @@ pub use burn_jit::{tensor::JitTensor, JitBackend};
/// - [Metal] on Apple hardware.
/// - [WebGPU](crate::WebGpu) on supported browsers and `wasm` runtimes.
///
+/// To configure the wgpu backend, e.g. to select which graphics API or memory strategy to use,
+/// you have to initialize the runtime manually. For example:
+///
+/// ```rust, ignore
+/// fn custom_init() {
+///     let device = Default::default();
+///     burn::backend::wgpu::init_sync::<burn::backend::wgpu::Vulkan>(
+///         &device,
+///         Default::default(),
+///     );
+/// }
+/// ```
+/// This initializes the given device (in this case the default) to use Vulkan as its graphics API.
+/// It's also possible to reuse an existing wgpu device via `init_existing_device`.
+///
/// # Notes
///
/// This version of the [wgpu] backend uses [burn_fusion] to compile and optimize streams of tensor
@@ -44,8 +59,7 @@ pub use burn_jit::{tensor::JitTensor, JitBackend};
///
/// You can disable the `fusion` feature flag to remove that functionality, which might be
/// necessary on `wasm` for now.
-pub type Wgpu<G = AutoGraphicsApi, F = f32, I = i32> =
-    burn_fusion::Fusion<JitBackend<WgpuRuntime<G>, F, I>>;
+pub type Wgpu<F = f32, I = i32> = burn_fusion::Fusion<JitBackend<WgpuRuntime, F, I>>;
#[cfg(not(feature = "fusion"))]
/// Tensor backend that uses the [wgpu] crate for executing GPU compute shaders.
@@ -57,6 +71,21 @@ pub type Wgpu<G = AutoGraphicsApi, F = f32, I = i32> =
/// - [Metal] on Apple hardware.
/// - [WebGPU](crate::WebGpu) on supported browsers and `wasm` runtimes.
///
+/// To configure the wgpu backend, e.g. to select which graphics API or memory strategy to use,
+/// you have to initialize the runtime manually. For example:
+///
+/// ```rust, ignore
+/// fn custom_init() {
+///     let device = Default::default();
+///     burn::backend::wgpu::init_sync::<burn::backend::wgpu::Vulkan>(
+///         &device,
+///         Default::default(),
+///     );
+/// }
+/// ```
+/// This initializes the given device (in this case the default) to use Vulkan as its graphics API.
+/// It's also possible to reuse an existing wgpu device via `init_existing_device`.
+///
/// # Notes
///
/// This version of the [wgpu] backend doesn't use [burn_fusion] to compile and optimize streams of tensor
@@ -64,13 +93,13 @@ pub type Wgpu<G = AutoGraphicsApi, F = f32, I = i32> =
///
/// You can enable the `fusion` feature flag to add that functionality, which might improve
/// performance.
-pub type Wgpu<G = AutoGraphicsApi, F = f32, I = i32> = JitBackend<WgpuRuntime<G>, F, I>;
+pub type Wgpu<F = f32, I = i32> = JitBackend<WgpuRuntime, F, I>;
#[cfg(test)]
mod tests {
use super::*;
-pub type TestRuntime = crate::WgpuRuntime<AutoGraphicsApi>;
+pub type TestRuntime = crate::WgpuRuntime;
burn_jit::testgen_all!();
burn_cube::testgen_all!();
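A consequence of the non-fusion alias above, visible in the custom-wgpu-kernel example later in this diff: with the `fusion` feature disabled, `Wgpu<f32, i32>` and `JitBackend<WgpuRuntime, f32, i32>` name the same type, so custom kernels are implemented against the `JitBackend` spelling. A sketch:

```rust
use burn::backend::wgpu::{JitBackend, WgpuRuntime};

// With `fusion` disabled this is exactly `Wgpu<f32, i32>`;
// custom-kernel trait impls target this form directly.
type MyBackend = JitBackend<WgpuRuntime, f32, i32>;
```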

View File

@@ -1,7 +1,7 @@
use crate::{
compiler::wgsl,
compute::{WgpuServer, WgpuStorage},
-GraphicsApi, WgpuDevice,
+AutoGraphicsApi, GraphicsApi, WgpuDevice,
};
use alloc::sync::Arc;
use burn_common::stub::RwLock;
@@ -15,21 +15,16 @@ use burn_compute::{
use burn_cube::Runtime;
use burn_jit::JitRuntime;
use burn_tensor::backend::{DeviceId, DeviceOps};
-use std::{
-    marker::PhantomData,
-    sync::atomic::{AtomicBool, Ordering},
-};
+use std::sync::atomic::{AtomicBool, Ordering};
use wgpu::{AdapterInfo, DeviceDescriptor};
-/// Runtime that uses the [wgpu] crate with the wgsl compiler.
-///
-/// The [graphics api](GraphicsApi) type is passed as generic.
+/// Runtime that uses the [wgpu] crate with the wgsl compiler. This is used in the Wgpu backend.
+/// For advanced configuration, use [`init_sync`] to pass in runtime options or to select a
+/// specific graphics API.
#[derive(Debug)]
-pub struct WgpuRuntime<G: GraphicsApi> {
-    _g: PhantomData<G>,
-}
+pub struct WgpuRuntime {}
-impl<G: GraphicsApi> JitRuntime for WgpuRuntime<G> {
+impl JitRuntime for WgpuRuntime {
type JitDevice = WgpuDevice;
type JitServer = WgpuServer<SimpleMemoryManagement<WgpuStorage>>;
}
@@ -42,7 +37,7 @@ type Server = WgpuServer<SimpleMemoryManagement<WgpuStorage>>;
static SUBGROUP: AtomicBool = AtomicBool::new(false);
-impl<G: GraphicsApi> Runtime for WgpuRuntime<G> {
+impl Runtime for WgpuRuntime {
type Compiler = wgsl::WgslCompiler;
type Server = WgpuServer<SimpleMemoryManagement<WgpuStorage>>;
@@ -51,7 +46,8 @@ impl<G: GraphicsApi> Runtime for WgpuRuntime<G> {
fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel> {
RUNTIME.client(device, move || {
-let (adapter, device_wgpu, queue) = pollster::block_on(create_wgpu_setup::<G>(device));
+let (adapter, device_wgpu, queue) =
+    pollster::block_on(create_wgpu_setup::<AutoGraphicsApi>(device));
create_client(adapter, device_wgpu, queue, RuntimeOptions::default())
})
}
@@ -125,14 +121,13 @@ pub fn init_existing_device(
device_id
}
-/// Init the client sync, useful to configure the runtime options.
+/// Initialize a client on the given device with the given options, e.g. to configure the runtime
+/// or to pick a different graphics API. On wasm, it is necessary to use [`init_async`] instead.
pub fn init_sync<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) {
-let (adapter, device_wgpu, queue) = pollster::block_on(create_wgpu_setup::<G>(device));
-let client = create_client(adapter, device_wgpu, queue, options);
-RUNTIME.register(device, client)
+pollster::block_on(init_async::<G>(device, options));
}
-/// Init the client async, necessary for wasm.
+/// Like [`init_sync`], but async; necessary on wasm.
pub async fn init_async<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) {
let (adapter, device_wgpu, queue) = create_wgpu_setup::<G>(device).await;
let client = create_client(adapter, device_wgpu, queue, options);
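Since `init_sync` now simply blocks on `init_async`, the async path is the single source of truth. On wasm, where blocking is impossible, callers await it directly; a usage sketch based on imports shown in a later file in this diff (`init_async` together with `AutoGraphicsApi`):

```rust
use burn::backend::wgpu::{init_async, AutoGraphicsApi, WgpuDevice};

async fn init_for_wasm() {
    let device = WgpuDevice::default();
    // Must be awaited: there is no way to block on wasm.
    init_async::<AutoGraphicsApi>(&device, Default::default()).await;
}
```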

View File

@@ -1,5 +1,5 @@
use burn::{
-backend::wgpu::{AutoGraphicsApi, WgpuRuntime},
+backend::wgpu::WgpuRuntime,
tensor::{Distribution, Tensor},
};
use custom_wgpu_kernel::{
@@ -71,7 +71,7 @@ fn autodiff<B: AutodiffBackend>(device: &B::Device) {
}
fn main() {
-type MyBackend = burn::backend::wgpu::JitBackend<WgpuRuntime<AutoGraphicsApi>, f32, i32>;
+type MyBackend = burn::backend::wgpu::JitBackend<WgpuRuntime, f32, i32>;
type MyAutodiffBackend = burn::backend::Autodiff<MyBackend>;
let device = Default::default();
inference::<MyBackend>(&device);

View File

@@ -9,15 +9,12 @@ use burn::{
ops::{broadcast_shape, Backward, Ops, OpsKind},
Autodiff, NodeID,
},
-wgpu::{FloatElement, GraphicsApi, IntElement, JitBackend, WgpuRuntime},
+wgpu::{FloatElement, IntElement, JitBackend, WgpuRuntime},
},
tensor::Shape,
};
-impl<G: GraphicsApi, F: FloatElement, I: IntElement> AutodiffBackend
-    for Autodiff<JitBackend<WgpuRuntime<G>, F, I>>
-{
-}
+impl<F: FloatElement, I: IntElement> AutodiffBackend for Autodiff<JitBackend<WgpuRuntime, F, I>> {}
// Implement our custom backend trait for any backend that also implements our custom backend trait.
//

View File

@@ -3,8 +3,8 @@ use crate::FloatTensor;
use super::Backend;
use burn::{
backend::wgpu::{
-build_info, into_contiguous, kernel_wgsl, CubeCount, CubeDim, FloatElement, GraphicsApi,
-IntElement, JitBackend, JitTensor, KernelSource, SourceKernel, SourceTemplate, WgpuRuntime,
+build_info, into_contiguous, kernel_wgsl, CubeCount, CubeDim, FloatElement, IntElement,
+JitBackend, JitTensor, KernelSource, SourceKernel, SourceTemplate, WgpuRuntime,
},
tensor::Shape,
};
@@ -36,7 +36,7 @@ impl<E: FloatElement> KernelSource for FusedMatmulAddRelu<E> {
}
/// Implement our custom backend trait for the existing backend `WgpuBackend`.
-impl<G: GraphicsApi, F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime<G>, F, I> {
+impl<F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime, F, I> {
fn fused_matmul_add_relu<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,

View File

@@ -5,13 +5,13 @@ mod training;
use crate::{model::ModelConfig, training::TrainingConfig};
use burn::{
-backend::{wgpu::AutoGraphicsApi, Autodiff, Wgpu},
+backend::{Autodiff, Wgpu},
data::dataset::Dataset,
optim::AdamConfig,
};
fn main() {
-type MyBackend = Wgpu<AutoGraphicsApi, f32, i32>;
+type MyBackend = Wgpu<f32, i32>;
type MyAutodiffBackend = Autodiff<MyBackend>;
let device = burn::backend::wgpu::WgpuDevice::default();

View File

@@ -34,7 +34,7 @@ pub enum ModelType {
WithNdArrayBackend(Model<NdArray<f32>>),
/// The model is loaded to the Wgpu backend
-WithWgpuBackend(Model<Wgpu<AutoGraphicsApi, f32, i32>>),
+WithWgpuBackend(Model<Wgpu<f32, i32>>),
}
/// The image is 224x224 pixels with 3 channels (RGB)

View File

@@ -8,7 +8,7 @@ use burn::{
use burn::backend::wgpu::{init_async, AutoGraphicsApi, Wgpu, WgpuDevice};
#[cfg(feature = "wgpu")]
-pub type Backend = Wgpu<AutoGraphicsApi, f32, i32>;
+pub type Backend = Wgpu<f32, i32>;
#[cfg(all(feature = "ndarray", not(feature = "wgpu")))]
pub type Backend = burn::backend::ndarray::NdArray<f32>;

View File

@@ -74,10 +74,10 @@ mod tch_cpu {
#[cfg(feature = "wgpu")]
mod wgpu {
use crate::{launch, ElemType};
-use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
+use burn::backend::wgpu::{Wgpu, WgpuDevice};
pub fn run() {
-launch::<Wgpu<AutoGraphicsApi, ElemType, i32>>(WgpuDevice::default());
+launch::<Wgpu<ElemType, i32>>(WgpuDevice::default());
}
}

View File

@@ -85,12 +85,12 @@ mod tch_cpu {
mod wgpu {
use crate::{launch, ElemType};
use burn::backend::{
-wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice},
+wgpu::{Wgpu, WgpuDevice},
Autodiff,
};
pub fn run() {
-launch::<Autodiff<Wgpu<AutoGraphicsApi, ElemType, i32>>>(vec![WgpuDevice::default()]);
+launch::<Autodiff<Wgpu<ElemType, i32>>>(vec![WgpuDevice::default()]);
}
}

View File

@@ -82,14 +82,14 @@ mod tch_cpu {
#[cfg(feature = "wgpu")]
mod wgpu {
use burn::backend::{
-wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice},
+wgpu::{Wgpu, WgpuDevice},
Autodiff,
};
use crate::{launch, ElemType};
pub fn run() {
-launch::<Autodiff<Wgpu<AutoGraphicsApi, ElemType, i32>>>(WgpuDevice::default());
+launch::<Autodiff<Wgpu<ElemType, i32>>>(WgpuDevice::default());
}
}

View File

@@ -81,14 +81,14 @@ mod tch_cpu {
#[cfg(feature = "wgpu")]
mod wgpu {
use burn::backend::{
-wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice},
+wgpu::{Wgpu, WgpuDevice},
Autodiff,
};
use crate::{launch, ElemType};
pub fn run() {
-launch::<Autodiff<Wgpu<AutoGraphicsApi, ElemType, i32>>>(vec![WgpuDevice::default()]);
+launch::<Autodiff<Wgpu<ElemType, i32>>>(vec![WgpuDevice::default()]);
}
}