mirror of https://github.com/tracel-ai/burn.git
Always derive Cube features from adapter (#1958)
This commit is contained in:
parent
fe0544b9ea
commit
0928a52eea
|
@ -43,15 +43,9 @@ impl Runtime for WgpuRuntime {
|
|||
|
||||
fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel> {
|
||||
RUNTIME.client(device, move || {
|
||||
let (adapter, device_wgpu, queue, features) =
|
||||
let (adapter, device_wgpu, queue) =
|
||||
pollster::block_on(create_wgpu_setup::<AutoGraphicsApi>(device));
|
||||
create_client(
|
||||
adapter,
|
||||
device_wgpu,
|
||||
queue,
|
||||
features,
|
||||
RuntimeOptions::default(),
|
||||
)
|
||||
create_client(adapter, device_wgpu, queue, RuntimeOptions::default())
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -102,11 +96,10 @@ pub fn init_existing_device(
|
|||
adapter: Arc<wgpu::Adapter>,
|
||||
device: Arc<wgpu::Device>,
|
||||
queue: Arc<wgpu::Queue>,
|
||||
features: Arc<FeatureSet>,
|
||||
options: RuntimeOptions,
|
||||
) -> WgpuDevice {
|
||||
let device_id = WgpuDevice::Existing(device.as_ref().global_id());
|
||||
let client = create_client(adapter, device, queue, features, options);
|
||||
let client = create_client(adapter, device, queue, options);
|
||||
RUNTIME.register(&device_id, client);
|
||||
device_id
|
||||
}
|
||||
|
@ -119,39 +112,28 @@ pub fn init_sync<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) {
|
|||
|
||||
/// Like [`init_sync`], but async, necessary for wasm.
|
||||
pub async fn init_async<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) {
|
||||
let (adapter, device_wgpu, queue, features) = create_wgpu_setup::<G>(device).await;
|
||||
let client = create_client(adapter, device_wgpu, queue, features, options);
|
||||
let (adapter, device_wgpu, queue) = create_wgpu_setup::<G>(device).await;
|
||||
let client = create_client(adapter, device_wgpu, queue, options);
|
||||
RUNTIME.register(device, client)
|
||||
}
|
||||
|
||||
async fn create_wgpu_setup<G: GraphicsApi>(
|
||||
device: &WgpuDevice,
|
||||
) -> (
|
||||
Arc<wgpu::Adapter>,
|
||||
Arc<wgpu::Device>,
|
||||
Arc<wgpu::Queue>,
|
||||
Arc<FeatureSet>,
|
||||
) {
|
||||
let (device_wgpu, queue, adapter, features) = select_device::<G>(device).await;
|
||||
) -> (Arc<wgpu::Adapter>, Arc<wgpu::Device>, Arc<wgpu::Queue>) {
|
||||
let (device_wgpu, queue, adapter) = select_device::<G>(device).await;
|
||||
|
||||
log::info!(
|
||||
"Created wgpu compute server on device {:?} => {:?}",
|
||||
device,
|
||||
adapter.get_info()
|
||||
);
|
||||
(
|
||||
Arc::new(adapter),
|
||||
Arc::new(device_wgpu),
|
||||
Arc::new(queue),
|
||||
Arc::new(features),
|
||||
)
|
||||
(Arc::new(adapter), Arc::new(device_wgpu), Arc::new(queue))
|
||||
}
|
||||
|
||||
fn create_client(
|
||||
adapter: Arc<wgpu::Adapter>,
|
||||
device_wgpu: Arc<wgpu::Device>,
|
||||
queue: Arc<wgpu::Queue>,
|
||||
features: Arc<FeatureSet>,
|
||||
options: RuntimeOptions,
|
||||
) -> ComputeClient<
|
||||
WgpuServer<DynamicMemoryManagement<WgpuStorage>>,
|
||||
|
@ -170,24 +152,6 @@ fn create_client(
|
|||
let channel = MutexComputeChannel::new(server);
|
||||
let tuner_device_id = tuner_device_id(adapter.get_info());
|
||||
|
||||
ComputeClient::new(
|
||||
channel,
|
||||
Arc::new(RwLock::new(Tuner::new("wgpu", &tuner_device_id))),
|
||||
features,
|
||||
)
|
||||
}
|
||||
|
||||
/// Select the wgpu device and queue based on the provided [device](WgpuDevice).
|
||||
pub async fn select_device<G: GraphicsApi>(
|
||||
device: &WgpuDevice,
|
||||
) -> (wgpu::Device, wgpu::Queue, wgpu::Adapter, FeatureSet) {
|
||||
#[cfg(target_family = "wasm")]
|
||||
let adapter = select_adapter::<G>(device).await;
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
let adapter = select_adapter::<G>(device);
|
||||
|
||||
let limits = adapter.limits();
|
||||
let features = adapter.features();
|
||||
let mut features_cube = FeatureSet::default();
|
||||
|
||||
|
@ -195,11 +159,29 @@ pub async fn select_device<G: GraphicsApi>(
|
|||
features_cube.register(Feature::Subcube);
|
||||
}
|
||||
|
||||
ComputeClient::new(
|
||||
channel,
|
||||
Arc::new(RwLock::new(Tuner::new("wgpu", &tuner_device_id))),
|
||||
Arc::new(features_cube),
|
||||
)
|
||||
}
|
||||
|
||||
/// Select the wgpu device and queue based on the provided [device](WgpuDevice).
|
||||
pub async fn select_device<G: GraphicsApi>(
|
||||
device: &WgpuDevice,
|
||||
) -> (wgpu::Device, wgpu::Queue, wgpu::Adapter) {
|
||||
#[cfg(target_family = "wasm")]
|
||||
let adapter = select_adapter::<G>(device).await;
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
let adapter = select_adapter::<G>(device);
|
||||
let limits = adapter.limits();
|
||||
|
||||
let (device, queue) = adapter
|
||||
.request_device(
|
||||
&DeviceDescriptor {
|
||||
label: None,
|
||||
required_features: features,
|
||||
required_features: adapter.features(),
|
||||
required_limits: limits,
|
||||
},
|
||||
None,
|
||||
|
@ -214,7 +196,7 @@ pub async fn select_device<G: GraphicsApi>(
|
|||
})
|
||||
.unwrap();
|
||||
|
||||
(device, queue, adapter, features_cube)
|
||||
(device, queue, adapter)
|
||||
}
|
||||
|
||||
fn tuner_device_id(info: AdapterInfo) -> String {
|
||||
|
|
Loading…
Reference in New Issue