Fix import and type redefinitions in mnist example crate (#1100)

* Remove the double import for WgpuDevice
* Prioritize wgpu backend over the default ndarray when wgpu feature is set

This fixes `cargo bench --festures wgpu` as `--no-default-features` cannot be
used.
This commit is contained in:
Sylvain Benner 2024-01-02 12:47:44 -05:00 committed by GitHub
parent 40ec289a92
commit a4de93a39f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 3 deletions

View File

@ -1,6 +1,4 @@
use crate::model::Model;
#[cfg(feature = "wgpu")]
use burn::backend::wgpu::WgpuDevice;
use burn::module::Module;
use burn::record::BinBytesRecorder;
use burn::record::FullPrecisionSettings;
@ -12,7 +10,7 @@ use burn::backend::wgpu::{compute::init_async, AutoGraphicsApi, Wgpu, WgpuDevice
#[cfg(feature = "wgpu")]
pub type Backend = Wgpu<AutoGraphicsApi, f32, i32>;
#[cfg(feature = "ndarray")]
#[cfg(all(feature = "ndarray", not(feature = "wgpu")))]
pub type Backend = burn::backend::ndarray::NdArray<f32>;
static STATE_ENCODED: &[u8] = include_bytes!("../model.bin");