Refactor execute_dynamic with Execution struct (#1550)

This commit is contained in:
Louis Fortier-Dubois 2024-03-28 17:27:48 -04:00 committed by GitHub
parent efc3b2d243
commit edcd92f13d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
30 changed files with 304 additions and 441 deletions
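In short, this change replaces the positional execute_dynamic(inputs, outputs, scalars, kernel, launch, client) entry point with a builder-style Execution API at every call site. Condensed from the hunks below (abbreviated and not compilable on its own; SomeEagerKernel, input, output, and client stand in for the per-kernel types and handles of each call site):

// Before: positional arguments, with unused scalar slots passed as None
execute_dynamic::<R, SomeEagerKernel<R, E>, u32>(
    &[EagerHandle::new(&input.handle, &input.strides, &input.shape.dims)],
    &[EagerHandle::new(&output.handle, &output.strides, &output.shape.dims)],
    None,
    kernel,
    WorkgroupLaunch::Output { pos: 0 },
    client,
);

// After: a builder chain in which only the stages that are actually needed appear
Execution::start(kernel, client)
    .inputs(&[EagerHandle::<R>::new(&input.handle, &input.strides, &input.shape.dims)])
    .outputs(&[EagerHandle::new(&output.handle, &output.strides, &output.shape.dims)])
    .execute(WorkgroupLaunch::Output { pos: 0 });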

View File

@ -220,7 +220,7 @@ pub(crate) fn run_backend_comparison_benchmarks(
let status = run_cargo("bench", &args).unwrap();
if !status.success() {
println!(
"Benchmark {} didn't ran successfully on the backend {}",
"Benchmark {} didn't run successfully on the backend {}",
bench_str, backend_str
);
continue;

View File

@ -1,4 +1,3 @@
# Burn JIT Backend
Generic backend that can be compiled just-in-time (JIT) to any shader language target
In progress: At the moment, only WGSL compilation is supported, and some kernels still rely on pure WGSL

View File

@ -34,27 +34,8 @@ pub fn execute_static<R, K, E>(
R: Runtime,
E: JitElement,
{
execute_static_::<R, K, E, E, E>(inputs, outputs, scalar_elems, None, None, launch, client)
}
fn execute_static_<R, K, E1, E2, E3>(
inputs: &[EagerHandle<R>],
outputs: &[EagerHandle<R>],
scalars_1: Option<&[E1]>,
scalars_2: Option<&[E2]>,
scalars_3: Option<&[E3]>,
launch: WorkgroupLaunch,
client: ComputeClient<R::Server, R::Channel>,
) where
K: StaticKernelSource + 'static,
R: Runtime,
E1: JitElement,
E2: JitElement,
E3: JitElement,
{
let settings = execute_settings(
inputs, outputs, scalars_1, scalars_2, scalars_3, launch, &client,
);
let settings =
execute_settings::<R, E, E, E>(inputs, outputs, scalar_elems, None, None, launch, &client);
let mut handles = settings.handles_tensors;
let workgroup = settings.workgroup;
@ -64,6 +45,7 @@ fn execute_static_<R, K, E1, E2, E3>(
}
let kernel = Box::new(StaticKernel::<K>::new(workgroup));
client.execute(kernel, &handles);
}
@ -128,7 +110,7 @@ where
/// Execute a dynamic kernel.
#[allow(unused)]
pub fn execute(self, launch: WorkgroupLaunch) {
execute_dynamic_::<R, K, f32, f32, f32>(
execute_dynamic::<R, K, f32, f32, f32>(
self.inputs,
self.outputs,
None,
@ -163,7 +145,7 @@ where
/// Execute a dynamic kernel.
#[allow(unused)]
pub fn execute(self, launch: WorkgroupLaunch) {
execute_dynamic_::<R, K, E, f32, f32>(
execute_dynamic::<R, K, E, f32, f32>(
self.inputs,
self.outputs,
Some(self.scalars.0),
@ -203,7 +185,7 @@ where
K: DynamicKernelSource + 'static,
R: Runtime,
{
execute_dynamic_::<R, K, E1, E2, f32>(
execute_dynamic::<R, K, E1, E2, f32>(
self.inputs,
self.outputs,
Some(self.scalars.0),
@ -227,7 +209,7 @@ where
/// Execute a dynamic kernel.
#[allow(unused)]
pub fn execute(self, launch: WorkgroupLaunch) {
execute_dynamic_::<R, K, E1, E2, E3>(
execute_dynamic::<R, K, E1, E2, E3>(
self.inputs,
self.outputs,
Some(self.scalars.0),
@ -240,33 +222,8 @@ where
}
}
/// Execute a dynamic kernel.
pub fn execute_dynamic<R, K, E>(
inputs: &[EagerHandle<R>],
outputs: &[EagerHandle<R>],
scalar_elems: Option<&[E]>,
kernel: K,
launch: WorkgroupLaunch,
client: ComputeClient<R::Server, R::Channel>,
) where
K: DynamicKernelSource + 'static,
R: Runtime,
E: JitElement,
{
execute_dynamic_::<R, K, E, E, E>(
inputs,
outputs,
scalar_elems,
None,
None,
kernel,
launch,
client,
)
}
#[allow(clippy::too_many_arguments)]
fn execute_dynamic_<R, K, E1, E2, E3>(
fn execute_dynamic<R, K, E1, E2, E3>(
inputs: &[EagerHandle<R>],
outputs: &[EagerHandle<R>],
scalars_1: Option<&[E1]>,

View File

@ -2,7 +2,7 @@ use std::{any::TypeId, marker::PhantomData};
use crate::{
codegen::{
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
OutputInfo, WorkgroupLaunch,
},
gpu::{gpu, Scope, Variable, Visibility},
@ -21,7 +21,7 @@ pub fn cast<R: Runtime, EI: JitElement, EO: JitElement, const D: usize>(
return JitTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle);
}
let kernel = CastEagerKernel::new();
let kernel = CastEagerKernel::<R, EI, EO>::new();
let num_elems = tensor.shape.num_elements();
let buffer = tensor.client.empty(num_elems * core::mem::size_of::<EO>());
let output = JitTensor::new(
@ -31,22 +31,18 @@ pub fn cast<R: Runtime, EI: JitElement, EO: JitElement, const D: usize>(
buffer,
);
execute_dynamic::<R, CastEagerKernel<R, EI, EO>, u32>(
&[EagerHandle::new(
Execution::start(kernel, tensor.client)
.inputs(&[EagerHandle::<R>::new(
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
)],
&[EagerHandle::new(
)])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
None,
kernel,
WorkgroupLaunch::Output { pos: 0 },
tensor.client,
);
)])
.execute(WorkgroupLaunch::Output { pos: 0 });
output
}
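
The converted call sites all follow the same builder pattern: start with a kernel and a compute client, attach input and output handles, optionally attach scalars, then launch. The Execution type itself lives in the codegen module and its definition is not part of the hunks shown here, so the following is only a standalone, self-contained illustration of the design idea; KernelExec, Handle, and Launch are made-up stand-ins, not Burn's actual API.

// Standalone sketch of a consuming builder, runnable on its own.
struct Handle(&'static str);

#[derive(Debug)]
enum Launch {
    Output { pos: usize },
}

struct KernelExec<'a> {
    kernel: &'static str,
    inputs: &'a [Handle],
    outputs: &'a [Handle],
    scalars: Vec<u32>,
}

impl<'a> KernelExec<'a> {
    fn start(kernel: &'static str) -> Self {
        Self { kernel, inputs: &[], outputs: &[], scalars: Vec::new() }
    }

    fn inputs(mut self, inputs: &'a [Handle]) -> Self {
        self.inputs = inputs;
        self
    }

    fn outputs(mut self, outputs: &'a [Handle]) -> Self {
        self.outputs = outputs;
        self
    }

    fn with_scalars(mut self, scalars: &[u32]) -> Self {
        self.scalars = scalars.to_vec();
        self
    }

    fn execute(self, launch: Launch) {
        // The real API would build execution settings from the handles and scalars,
        // then dispatch the compiled kernel on the compute client.
        let names: Vec<&str> = self
            .inputs
            .iter()
            .chain(self.outputs.iter())
            .map(|handle| handle.0)
            .collect();
        println!(
            "launching {} over {:?} with {} scalar(s), {:?}",
            self.kernel,
            names,
            self.scalars.len(),
            launch
        );
    }
}

fn main() {
    let inputs = [Handle("tensor")];
    let outputs = [Handle("output")];
    KernelExec::start("cast")
        .inputs(&inputs)
        .outputs(&outputs)
        .with_scalars(&[2, 2, 1, 1])
        .execute(Launch::Output { pos: 0 });
}

Because every builder method takes self by value and returns Self, a call site reads as a single expression, and stages that are not needed (such as with_scalars for kernels without scalars) are simply omitted instead of being passed as None.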

View File

@ -2,7 +2,7 @@ use std::marker::PhantomData;
use crate::{
codegen::{
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
OutputInfo, WorkgroupLaunch,
},
gpu::{gpu, Elem, Item, Scope, Variable, Visibility},
@ -20,7 +20,7 @@ use crate::{
pub fn bool_cast<R: Runtime, EO: JitElement, const D: usize>(
tensor: JitTensor<R, u32, D>,
) -> JitTensor<R, EO, D> {
let kernel = BoolCastEagerKernel::new();
let kernel = BoolCastEagerKernel::<R, EO>::new();
let num_elems = tensor.shape.num_elements();
let buffer = tensor.client.empty(num_elems * core::mem::size_of::<EO>());
let output = JitTensor::new(
@ -30,22 +30,18 @@ pub fn bool_cast<R: Runtime, EO: JitElement, const D: usize>(
buffer,
);
execute_dynamic::<R, BoolCastEagerKernel<R, EO>, u32>(
&[EagerHandle::new(
Execution::start(kernel, tensor.client)
.inputs(&[EagerHandle::<R>::new(
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
)],
&[EagerHandle::new(
)])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
None,
kernel,
WorkgroupLaunch::Output { pos: 0 },
tensor.client,
);
)])
.execute(WorkgroupLaunch::Output { pos: 0 });
output
}

View File

@ -2,7 +2,7 @@ use std::marker::PhantomData;
use crate::{
codegen::{
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
OutputInfo, WorkgroupLaunch,
},
gpu::{gpu, Elem, IndexOffsetGlobalWithLayout, Scope, Variable, Visibility},
@ -18,9 +18,9 @@ pub(crate) struct IntoContiguousShader {
}
#[derive(new)]
pub(crate) struct IntoContiguousEagerKernel<R: Runtime, EO: JitElement> {
pub(crate) struct IntoContiguousEagerKernel<R: Runtime, E: JitElement> {
_runtime: PhantomData<R>,
_elem_out: PhantomData<EO>,
_elem_out: PhantomData<E>,
}
/// Make a jit tensor contiguous.
@ -31,7 +31,7 @@ pub fn into_contiguous<R: Runtime, E: JitElement, const D: usize>(
return tensor;
}
let kernel = IntoContiguousEagerKernel::new();
let kernel = IntoContiguousEagerKernel::<R, E>::new();
let num_elems = tensor.shape.num_elements();
let buffer = tensor.client.empty(num_elems * core::mem::size_of::<E>());
let output = JitTensor::new(
@ -41,22 +41,18 @@ pub fn into_contiguous<R: Runtime, E: JitElement, const D: usize>(
buffer,
);
execute_dynamic::<R, IntoContiguousEagerKernel<R, E>, u32>(
&[EagerHandle::new(
Execution::start(kernel, tensor.client)
.inputs(&[EagerHandle::<R>::new(
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
)],
&[EagerHandle::new(
)])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
None,
kernel,
WorkgroupLaunch::Output { pos: 0 },
tensor.client,
);
)])
.execute(WorkgroupLaunch::Output { pos: 0 });
output
}

View File

@ -1,13 +1,13 @@
use burn_tensor::{
ops::{conv::calculate_conv_output_size, ConvOptions},
ElementConversion, Shape,
Shape,
};
use std::marker::PhantomData;
use crate::{
codegen::{
dialect::gpu::{gpu, Elem, Scope, Variable, Visibility},
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, Execution,
InputInfo, OutputInfo, WorkgroupLaunch,
},
element::JitElement,
@ -17,7 +17,7 @@ use crate::{
reshape,
},
tensor::JitTensor,
Runtime, RuntimeInt,
Runtime,
};
#[derive(new)]
@ -335,32 +335,29 @@ pub(crate) fn conv2d<R: Runtime, E: JitElement>(
}
};
let kernel = Conv2dEagerKernel::new();
let kernel = Conv2dEagerKernel::<R, E>::new();
execute_dynamic::<R, Conv2dEagerKernel<R, E>, RuntimeInt<R>>(
&[
EagerHandle::new(&input.handle, &input.strides, &input.shape.dims),
Execution::start(kernel, input.client)
.inputs(&[
EagerHandle::<R>::new(&input.handle, &input.strides, &input.shape.dims),
EagerHandle::new(&weight.handle, &weight.strides, &weight.shape.dims),
EagerHandle::new(&bias.handle, &bias.strides, &bias.shape.dims),
],
&[EagerHandle::new(
])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
Some(&[
(options.stride[0] as u32).elem(),
(options.stride[1] as u32).elem(),
(options.dilation[0] as u32).elem(),
(options.dilation[1] as u32).elem(),
(options.padding[0] as u32).elem(),
(options.padding[1] as u32).elem(),
(options.groups as u32).elem(),
]),
kernel,
WorkgroupLaunch::Output { pos: 0 },
input.client,
);
)])
.with_scalars(&[
options.stride[0] as u32,
options.stride[1] as u32,
options.dilation[0] as u32,
options.dilation[1] as u32,
options.padding[0] as u32,
options.padding[1] as u32,
options.groups as u32,
])
.execute(WorkgroupLaunch::Output { pos: 0 });
output
}
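
A side effect visible in this and the following hunks: scalar arguments no longer go through ElementConversion::elem into RuntimeInt<R>; with_scalars takes the plain u32 (or i32) slice directly, which is why the ElementConversion and RuntimeInt imports disappear above. Excerpted from this conv2d call site:

// Before: each scalar converted to the runtime's integer element
Some(&[
    (options.stride[0] as u32).elem(),
    // ...
]),

// After: plain integers passed to the builder
.with_scalars(&[
    options.stride[0] as u32,
    // ...
])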

View File

@ -15,7 +15,7 @@ use crate::{
tensor::JitTensor,
Compiler, Runtime,
};
use burn_tensor::{ops::ConvTransposeOptions, Element, ElementConversion, Shape};
use burn_tensor::{ops::ConvTransposeOptions, Element, Shape};
#[derive(new)]
struct Conv2dTransposeEagerKernel<R, E> {
@ -399,13 +399,13 @@ pub(crate) fn conv_transpose2d<R: Runtime, E: JitElement + Element>(
&output.shape.dims,
)])
.with_scalars(&[
(options.stride[0] as u32).elem::<u32>(),
(options.stride[1] as u32).elem(),
(options.dilation[0] as u32).elem(),
(options.dilation[1] as u32).elem(),
(options.padding[0] as u32).elem(),
(options.padding[1] as u32).elem(),
(options.groups as u32).elem(),
options.stride[0] as u32,
options.stride[1] as u32,
options.dilation[0] as u32,
options.dilation[1] as u32,
options.padding[0] as u32,
options.padding[1] as u32,
options.groups as u32,
])
.execute(WorkgroupLaunch::Output { pos: 0 });

View File

@ -1,14 +1,14 @@
use crate::{
codegen::{
dialect::gpu::{gpu, Elem, Scope, Variable, Visibility},
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, Execution,
InputInfo, OutputInfo, WorkgroupLaunch,
},
element::JitElement,
kernel::{DynamicKernelSource, SourceTemplate},
ops::numeric::empty_device,
tensor::JitTensor,
Runtime, RuntimeInt,
Runtime,
};
use burn_tensor::ElementConversion;
use std::marker::PhantomData;
@ -128,30 +128,27 @@ pub(crate) fn flip_on_output<R: Runtime, E: JitElement, const D: usize>(
output: JitTensor<R, E, D>,
indices: &[usize],
) -> JitTensor<R, E, D> {
let mut scalars = Vec::with_capacity(D);
let mut scalars: Vec<u32> = Vec::with_capacity(D);
for i in 0..D {
scalars.push((indices.contains(&i) as u32).elem());
}
let kernel = FlipEagerKernel::new(D);
let kernel = FlipEagerKernel::<R, E>::new(D);
execute_dynamic::<R, FlipEagerKernel<R, E>, RuntimeInt<R>>(
&[EagerHandle::new(
Execution::start(kernel, tensor.client)
.inputs(&[EagerHandle::<R>::new(
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
)],
&[EagerHandle::new(
)])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
Some(&scalars),
kernel,
WorkgroupLaunch::Output { pos: 0 },
tensor.client,
);
)])
.with_scalars(&scalars)
.execute(WorkgroupLaunch::Output { pos: 0 });
output
}

View File

@ -1,8 +1,9 @@
use crate::codegen::dialect::gpu::{gpu, Elem, Scope, Variable};
use crate::codegen::Execution;
use crate::{
codegen::{
dialect::gpu, execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler,
EagerHandle, InputInfo, OutputInfo, WorkgroupLaunch,
dialect::gpu, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
InputInfo, OutputInfo, WorkgroupLaunch,
},
element::JitElement,
kernel::{self, DynamicKernelSource, SourceTemplate},
@ -130,23 +131,19 @@ pub(crate) fn gather<R: Runtime, E: JitElement, I: JitElement, const D: usize>(
) -> JitTensor<R, E, D> {
let shape_output = indices.shape.clone();
let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output);
let kernel = GatherEagerKernel::new(dim);
let kernel = GatherEagerKernel::<R, E>::new(dim);
execute_dynamic::<R, GatherEagerKernel<R, E>, E>(
&[
EagerHandle::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
Execution::start(kernel, tensor.client)
.inputs(&[
EagerHandle::<R>::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
EagerHandle::new(&indices.handle, &indices.strides, &indices.shape.dims),
],
&[EagerHandle::new(
])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
None,
kernel,
WorkgroupLaunch::Output { pos: 0 },
tensor.client,
);
)])
.execute(WorkgroupLaunch::Output { pos: 0 });
output
}

View File

@ -1,7 +1,7 @@
use crate::{
codegen::{
dialect::gpu::{gpu, Elem, Scope, Variable, Visibility},
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, Execution,
InputInfo, OutputInfo, WorkgroupLaunch,
},
element::JitElement,
@ -126,23 +126,20 @@ pub(crate) fn repeat<R: Runtime, E: JitElement, const D1: usize>(
handle,
);
let kernel = RepeatEagerKernel::new(dim, D1);
let kernel = RepeatEagerKernel::<R, E>::new(dim, D1);
execute_dynamic::<R, RepeatEagerKernel<R, E>, E>(
&[EagerHandle::new(
Execution::start(kernel, input.client)
.inputs(&[EagerHandle::<R>::new(
&input.handle,
&input.strides,
&input.shape.dims,
)],
&[EagerHandle::new(
)])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
None,
kernel,
WorkgroupLaunch::Output { pos: 0 },
input.client,
);
)])
.execute(WorkgroupLaunch::Output { pos: 0 });
output
}

View File

@ -1,9 +1,10 @@
use crate::codegen::dialect::gpu::{gpu, Branch, Elem, Scope, Variable};
use crate::codegen::Execution;
use crate::kernel::{elemwise_workgroup, WORKGROUP_DEFAULT};
use crate::{
codegen::{
dialect::gpu, execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler,
EagerHandle, InputInfo, WorkgroupLaunch,
dialect::gpu, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
InputInfo, WorkgroupLaunch,
},
element::JitElement,
kernel::{self, DynamicKernelSource, SourceTemplate},
@ -192,7 +193,7 @@ pub(crate) fn scatter<R: Runtime, E: JitElement, I: JitElement, const D: usize>(
false => tensor.copy(),
};
let kernel = ScatterEagerKernel::new(dim);
let kernel = ScatterEagerKernel::<R, E>::new(dim);
let mut strides = [0; D];
let mut current = 1;
let mut num_elems_per_workgroup = 1;
@ -209,22 +210,19 @@ pub(crate) fn scatter<R: Runtime, E: JitElement, I: JitElement, const D: usize>(
current *= val;
num_elems_per_workgroup *= tensor.shape.dims[index];
});
// Fake strides of the virtual output where the stride of dim is hardcoded to one.
indices.strides = strides;
let workgroup = elemwise_workgroup(num_elems_per_workgroup, WORKGROUP_DEFAULT);
execute_dynamic::<R, ScatterEagerKernel<R, E>, E>(
&[
EagerHandle::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
Execution::start(kernel, indices.client)
.inputs(&[
EagerHandle::<R>::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
EagerHandle::new(&indices.handle, &indices.strides, &indices.shape.dims),
EagerHandle::new(&value.handle, &value.strides, &value.shape.dims),
],
&[],
None,
kernel,
WorkgroupLaunch::Custom(workgroup),
indices.client,
);
])
.execute(WorkgroupLaunch::Custom(workgroup));
tensor
}

View File

@ -1,7 +1,7 @@
use crate::{
codegen::{
dialect::gpu::{gpu, Elem, Item, Scope, Variable, Visibility},
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, Execution,
InputInfo, OutputInfo, WorkgroupLaunch,
},
element::JitElement,
@ -125,27 +125,23 @@ pub(crate) fn select<R: Runtime, E: JitElement, I: JitElement, const D: usize>(
shape_output.dims[dim] = indices.shape.dims[0];
let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output);
let kernel = SelectEagerKernel::new(dim);
let kernel = SelectEagerKernel::<R, E>::new(dim);
execute_dynamic::<R, SelectEagerKernel<R, E>, E>(
&[
EagerHandle::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
Execution::start(kernel, tensor.client)
.inputs(&[
EagerHandle::<R>::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
// This is a current hack because the info buffer that contains the strides and shapes is
// hardcoded to only contain information about tensors of the same rank. However, since
// we don't rely on the shape and stride of the indices tensors, it doesn't matter
// which value we put, it just needs to be of the same rank.
EagerHandle::new(&indices.handle, &[1; D], &[1; D]),
],
&[EagerHandle::new(
])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
None,
kernel,
WorkgroupLaunch::Output { pos: 0 },
tensor.client,
);
)])
.execute(WorkgroupLaunch::Output { pos: 0 });
output
}

View File

@ -1,7 +1,7 @@
use crate::{
codegen::{
dialect::gpu::{gpu, Branch, Elem, Item, Scope, Variable, Visibility},
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, Execution,
InputInfo, WorkgroupLaunch,
},
element::JitElement,
@ -205,23 +205,18 @@ pub(crate) fn select_assign<R: Runtime, E: JitElement, I: JitElement, const D: u
num_elems_per_workgroup *= tensor.shape.dims[index];
});
let kernel = SelectAssignEagerKernel::new(dim);
let kernel = SelectAssignEagerKernel::<R, E>::new(dim);
let workgroup = elemwise_workgroup(num_elems_per_workgroup, WORKGROUP_DEFAULT);
execute_dynamic::<R, SelectAssignEagerKernel<R, E>, E>(
&[
EagerHandle::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
Execution::start(kernel, indices.client)
.inputs(&[
EagerHandle::<R>::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
EagerHandle::new(&value.handle, &value.strides, &value.shape.dims),
// We use the custom strides here instead of the shape, since we don't use it in the
// kernel, but we need to put the right number of dimensions (rank).
EagerHandle::new(&indices.handle, &strides, &strides),
],
&[],
None,
kernel,
WorkgroupLaunch::Custom(workgroup),
indices.client,
);
])
.execute(WorkgroupLaunch::Custom(workgroup));
tensor
}

View File

@ -1,14 +1,14 @@
use crate::{
codegen::{
dialect::gpu::{gpu, Elem, Scope, Variable, Visibility},
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, Execution,
InputInfo, OutputInfo, WorkgroupLaunch,
},
element::JitElement,
kernel::{DynamicKernelSource, SourceTemplate},
ops::numeric::empty_device,
tensor::JitTensor,
Runtime, RuntimeInt,
Runtime,
};
use burn_tensor::{ElementConversion, Shape};
use std::{marker::PhantomData, ops::Range};
@ -125,31 +125,28 @@ pub(crate) fn slice_on_output<R: Runtime, E: JitElement, const D1: usize, const
output: JitTensor<R, E, D1>,
indices: [Range<usize>; D2],
) -> JitTensor<R, E, D1> {
let mut scalars = Vec::with_capacity(D1);
let mut scalars: Vec<i32> = Vec::with_capacity(D1);
for i in 0..D1 {
let start = indices.get(i).map(|index| index.start).unwrap_or(0);
scalars.push((start as i32).elem());
}
let kernel = SliceEagerKernel::new(D1);
let kernel = SliceEagerKernel::<R, E>::new(D1);
execute_dynamic::<R, SliceEagerKernel<R, E>, RuntimeInt<R>>(
&[EagerHandle::new(
Execution::start(kernel, tensor.client)
.inputs(&[EagerHandle::<R>::new(
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
)],
&[EagerHandle::new(
)])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
Some(&scalars),
kernel,
WorkgroupLaunch::Output { pos: 0 },
tensor.client,
);
)])
.with_scalars(&scalars)
.execute(WorkgroupLaunch::Output { pos: 0 });
output
}

View File

@ -1,13 +1,13 @@
use crate::{
codegen::{
dialect::gpu::{gpu, Elem, Scope, Variable, Visibility},
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, Execution,
InputInfo, WorkgroupLaunch,
},
element::JitElement,
kernel::{DynamicKernelSource, SourceTemplate},
tensor::JitTensor,
Runtime, RuntimeInt,
Runtime,
};
use burn_tensor::ElementConversion;
use std::{marker::PhantomData, ops::Range};
@ -130,26 +130,22 @@ pub(crate) fn slice_assign<R: Runtime, E: JitElement, const D1: usize, const D2:
true => tensor,
false => tensor.copy(),
};
let mut scalars = Vec::with_capacity(D1);
let mut scalars: Vec<i32> = Vec::with_capacity(D1);
for i in 0..D1 {
let start = indices.get(i).map(|index| index.start).unwrap_or(0);
scalars.push((start as i32).elem());
}
let kernel = SliceAssignEagerKernel::new(D1);
let kernel = SliceAssignEagerKernel::<R, E>::new(D1);
execute_dynamic::<R, SliceAssignEagerKernel<R, E>, RuntimeInt<R>>(
&[
EagerHandle::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
Execution::start(kernel, value.client)
.inputs(&[
EagerHandle::<R>::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
EagerHandle::new(&value.handle, &value.strides, &value.shape.dims),
],
&[],
Some(&scalars),
kernel,
WorkgroupLaunch::Input { pos: 0 },
value.client,
);
])
.with_scalars(&scalars)
.execute(WorkgroupLaunch::Input { pos: 0 });
tensor
}

View File

@ -2,7 +2,7 @@ use std::marker::PhantomData;
use crate::{
codegen::{
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
OutputInfo, WorkgroupLaunch,
},
gpu::{gpu, Elem, Scope, Variable, Visibility},
@ -406,24 +406,20 @@ pub(crate) fn interpolate_bicubic_launch<R: Runtime, E: JitElement>(
input: JitTensor<R, E, 4>,
output: JitTensor<R, E, 4>,
) -> JitTensor<R, E, 4> {
let kernel = InterpolateBicubicEagerKernel::new();
let kernel = InterpolateBicubicEagerKernel::<R, E>::new();
execute_dynamic::<R, InterpolateBicubicEagerKernel<R, E>, u32>(
&[EagerHandle::new(
Execution::start(kernel, input.client)
.inputs(&[EagerHandle::<R>::new(
&input.handle,
&input.strides,
&input.shape.dims,
)],
&[EagerHandle::new(
)])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
None,
kernel,
WorkgroupLaunch::Output { pos: 0 },
input.client,
);
)])
.execute(WorkgroupLaunch::Output { pos: 0 });
output
}

View File

@ -2,7 +2,7 @@ use std::marker::PhantomData;
use crate::{
codegen::{
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
OutputInfo, WorkgroupLaunch,
},
gpu::{gpu, Elem, Scope, Variable, Visibility},
@ -226,24 +226,20 @@ pub(crate) fn interpolate_bilinear_launch<R: Runtime, E: JitElement>(
input: JitTensor<R, E, 4>,
output: JitTensor<R, E, 4>,
) -> JitTensor<R, E, 4> {
let kernel = InterpolateBilinearEagerKernel::new();
let kernel = InterpolateBilinearEagerKernel::<R, E>::new();
execute_dynamic::<R, InterpolateBilinearEagerKernel<R, E>, u32>(
&[EagerHandle::new(
Execution::start(kernel, input.client)
.inputs(&[EagerHandle::<R>::new(
&input.handle,
&input.strides,
&input.shape.dims,
)],
&[EagerHandle::new(
)])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
None,
kernel,
WorkgroupLaunch::Output { pos: 0 },
input.client,
);
)])
.execute(WorkgroupLaunch::Output { pos: 0 });
output
}

View File

@ -2,7 +2,7 @@ use std::marker::PhantomData;
use crate::{
codegen::{
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
OutputInfo, WorkgroupLaunch,
},
gpu::{gpu, Elem, Scope, Variable, Visibility},
@ -162,24 +162,20 @@ pub(crate) fn interpolate_nearest_launch<R: Runtime, E: JitElement>(
input: JitTensor<R, E, 4>,
output: JitTensor<R, E, 4>,
) -> JitTensor<R, E, 4> {
let kernel = InterpolateNearestEagerKernel::new();
let kernel = InterpolateNearestEagerKernel::<R, E>::new();
execute_dynamic::<R, InterpolateNearestEagerKernel<R, E>, u32>(
&[EagerHandle::new(
Execution::start(kernel, input.client)
.inputs(&[EagerHandle::<R>::new(
&input.handle,
&input.strides,
&input.shape.dims,
)],
&[EagerHandle::new(
)])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
None,
kernel,
WorkgroupLaunch::Output { pos: 0 },
input.client,
);
)])
.execute(WorkgroupLaunch::Output { pos: 0 });
output
}

View File

@ -2,7 +2,7 @@ use std::marker::PhantomData;
use crate::{
codegen::{
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
OutputInfo, WorkgroupLaunch,
},
gpu::{gpu, Elem, Scope, Variable, Visibility},
@ -221,24 +221,20 @@ pub(crate) fn interpolate_nearest_backward_launch<R: Runtime, E: JitElement>(
out_grad: JitTensor<R, E, 4>,
output: JitTensor<R, E, 4>,
) -> JitTensor<R, E, 4> {
let kernel = InterpolateNearestBackwardEagerKernel::new();
let kernel = InterpolateNearestBackwardEagerKernel::<R, E>::new();
execute_dynamic::<R, InterpolateNearestBackwardEagerKernel<R, E>, u32>(
&[EagerHandle::new(
Execution::start(kernel, out_grad.client)
.inputs(&[EagerHandle::<R>::new(
&out_grad.handle,
&out_grad.strides,
&out_grad.shape.dims,
)],
&[EagerHandle::new(
)])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
None,
kernel,
WorkgroupLaunch::Output { pos: 0 },
out_grad.client,
);
)])
.execute(WorkgroupLaunch::Output { pos: 0 });
output
}

View File

@ -1,10 +1,11 @@
use crate::codegen::dialect::gpu::{
gpu, BinaryOperator, Branch, Elem, IndexOffsetGlobalWithLayout, Scope, Variable,
};
use crate::codegen::Execution;
use crate::{
codegen::{
dialect::gpu, execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler,
EagerHandle, InputInfo, OutputInfo, WorkgroupLaunch,
dialect::gpu, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
InputInfo, OutputInfo, WorkgroupLaunch,
},
element::JitElement,
kernel::{into_contiguous, DynamicKernelSource, SourceTemplate, WORKGROUP_DEFAULT},
@ -236,19 +237,15 @@ pub fn matmul_simple<R: Runtime, E: JitElement, const D: usize>(
workgroup_size_y,
);
let kernel = MatmulEagerKernel::new(workgroup_size_x, workgroup_size_y);
let kernel = MatmulEagerKernel::<R>::new(workgroup_size_x, workgroup_size_y);
execute_dynamic::<R, MatmulEagerKernel<R>, E>(
&[
EagerHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
Execution::start(kernel, rhs.client)
.inputs(&[
EagerHandle::<R>::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
EagerHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
],
&[EagerHandle::new(&out.handle, &out.strides, &out.shape.dims)],
None,
kernel,
WorkgroupLaunch::Custom(workgroup),
rhs.client,
);
])
.outputs(&[EagerHandle::new(&out.handle, &out.strides, &out.shape.dims)])
.execute(WorkgroupLaunch::Custom(workgroup));
out
}

View File

@ -1,5 +1,5 @@
use crate::{
codegen::{execute_dynamic, EagerHandle, WorkgroupLaunch},
codegen::{EagerHandle, Execution, WorkgroupLaunch},
element::JitElement,
ops::numeric::empty_device,
tensor::JitTensor,
@ -18,24 +18,20 @@ pub(crate) fn adaptive_avg_pool2d<R: Runtime, E: JitElement>(
let output_shape = Shape::new([batch_size, channels, output_size[0], output_size[1]]);
let output = empty_device(input.client.clone(), input.device.clone(), output_shape);
let kernel = AdaptivePool2dEagerKernel::new();
let kernel = AdaptivePool2dEagerKernel::<R, E>::new();
execute_dynamic::<R, AdaptivePool2dEagerKernel<R, E>, E>(
&[EagerHandle::new(
Execution::start(kernel, input.client)
.inputs(&[EagerHandle::<R>::new(
&input.handle,
&input.strides,
&input.shape.dims,
)],
&[EagerHandle::new(
)])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
None,
kernel,
WorkgroupLaunch::Output { pos: 0 },
input.client,
);
)])
.execute(WorkgroupLaunch::Output { pos: 0 });
output
}

View File

@ -2,14 +2,14 @@ use std::marker::PhantomData;
use crate::{
codegen::{
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
OutputInfo, WorkgroupLaunch,
},
element::JitElement,
gpu::{gpu, Elem, Scope, Variable, Visibility},
kernel::{DynamicKernelSource, SourceTemplate},
tensor::JitTensor,
Compiler, Runtime, RuntimeInt,
Compiler, Runtime,
};
#[derive(new)]
@ -254,24 +254,20 @@ pub(crate) fn adaptive_avg_pool2d_backward<R: Runtime, E: JitElement>(
output_buffer,
);
let kernel = AdaptiveAvgPool2dBackwardEagerKernel::new();
let kernel = AdaptiveAvgPool2dBackwardEagerKernel::<R, E>::new();
execute_dynamic::<R, AdaptiveAvgPool2dBackwardEagerKernel<R, E>, RuntimeInt<R>>(
&[EagerHandle::new(
Execution::start(kernel, x.client)
.inputs(&[EagerHandle::<R>::new(
&out_grad.handle,
&out_grad.strides,
&out_grad.shape.dims,
)],
&[EagerHandle::new(
)])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
None,
kernel,
WorkgroupLaunch::Output { pos: 0 },
x.client,
);
)])
.execute(WorkgroupLaunch::Output { pos: 0 });
output
}

View File

@ -1,12 +1,12 @@
use crate::{
codegen::{dialect::gpu::Variable, execute_dynamic, EagerHandle, WorkgroupLaunch},
codegen::{dialect::gpu::Variable, EagerHandle, Execution, WorkgroupLaunch},
element::JitElement,
gpu::{gpu, Elem, Item, Scope},
ops::numeric::empty_device,
tensor::JitTensor,
Runtime, RuntimeInt,
Runtime,
};
use burn_tensor::{ops::conv::calculate_pool_output_size, ElementConversion, Shape};
use burn_tensor::{ops::conv::calculate_pool_output_size, Shape};
use std::fmt::Debug;
use super::{Pool2dEagerKernel, PoolStrategy};
@ -99,27 +99,24 @@ pub(crate) fn avg_pool2d<R: Runtime, E: JitElement>(
let output = empty_device(x.client.clone(), x.device.clone(), shape_out);
let pool_strategy = AvgPool::new(kernel_size, count_include_pad);
let kernel = Pool2dEagerKernel::new(kernel_size, pool_strategy);
let kernel = Pool2dEagerKernel::<AvgPool, R, E>::new(kernel_size, pool_strategy);
execute_dynamic::<R, Pool2dEagerKernel<AvgPool, R, E>, RuntimeInt<R>>(
&[EagerHandle::new(&x.handle, &x.strides, &x.shape.dims)],
&[EagerHandle::new(
Execution::start(kernel, x.client)
.inputs(&[EagerHandle::<R>::new(&x.handle, &x.strides, &x.shape.dims)])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
Some(&[
(stride[0] as u32).elem(),
(stride[1] as u32).elem(),
(dilation as u32).elem(),
(dilation as u32).elem(),
(padding[0] as u32).elem(),
(padding[1] as u32).elem(),
]),
kernel,
WorkgroupLaunch::Output { pos: 0 },
x.client,
);
)])
.with_scalars(&[
stride[0] as u32,
stride[1] as u32,
dilation as u32,
dilation as u32,
padding[0] as u32,
padding[1] as u32,
])
.execute(WorkgroupLaunch::Output { pos: 0 });
output
}

View File

@ -1,16 +1,14 @@
use burn_tensor::ElementConversion;
use crate::{
codegen::{
dialect::gpu::{gpu, Elem, Scope, Variable, Visibility},
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, Execution,
InputInfo, OutputInfo, WorkgroupLaunch,
},
element::JitElement,
kernel::{self, DynamicKernelSource, SourceTemplate},
ops::numeric::empty_device,
tensor::JitTensor,
Runtime, RuntimeInt,
Runtime,
};
use std::marker::PhantomData;
@ -378,31 +376,28 @@ pub(crate) fn avg_pool2d_backward<R: Runtime, E: JitElement>(
let dilation = 1;
let output = empty_device(x.client.clone(), x.device.clone(), x.shape.clone());
let kernel = AvgPool2dBackwardEagerKernel::new(kernel_size, count_include_pad);
let kernel = AvgPool2dBackwardEagerKernel::<R, E>::new(kernel_size, count_include_pad);
execute_dynamic::<R, AvgPool2dBackwardEagerKernel<R, E>, RuntimeInt<R>>(
&[EagerHandle::new(
Execution::start(kernel, x.client)
.inputs(&[EagerHandle::<R>::new(
&grad.handle,
&grad.strides,
&grad.shape.dims,
)],
&[EagerHandle::new(
)])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
Some(&[
(stride[0] as i32).elem(),
(stride[1] as i32).elem(),
dilation.elem(),
dilation.elem(),
(padding[0] as i32).elem(),
(padding[1] as i32).elem(),
]),
kernel,
WorkgroupLaunch::Output { pos: 0 },
x.client,
);
)])
.with_scalars(&[
stride[0] as i32,
stride[1] as i32,
dilation,
dilation,
padding[0] as i32,
padding[1] as i32,
])
.execute(WorkgroupLaunch::Output { pos: 0 });
output
}

View File

@ -1,14 +1,14 @@
use std::{fmt::Debug, marker::PhantomData};
use crate::{
codegen::{dialect::gpu::Variable, execute_dynamic, EagerHandle, WorkgroupLaunch},
codegen::{dialect::gpu::Variable, EagerHandle, Execution, WorkgroupLaunch},
element::JitElement,
gpu::{gpu, Elem, Item, Scope},
ops::numeric::empty_device,
tensor::JitTensor,
Runtime, RuntimeInt,
Runtime,
};
use burn_tensor::{ops::conv::calculate_pool_output_size, ElementConversion, Shape};
use burn_tensor::{ops::conv::calculate_pool_output_size, Shape};
use super::{Pool2dEagerKernel, PoolStrategy};
@ -137,27 +137,24 @@ pub(crate) fn max_pool2d<R: Runtime, E: JitElement>(
let shape_out = Shape::new([batch_size, channels, size_0, size_1]);
let output = empty_device(x.client.clone(), x.device.clone(), shape_out);
let kernel = Pool2dEagerKernel::new(kernel_size, MaxPool::default());
let kernel = Pool2dEagerKernel::<MaxPool<E>, R, E>::new(kernel_size, MaxPool::default());
execute_dynamic::<R, Pool2dEagerKernel<MaxPool<E>, R, E>, RuntimeInt<R>>(
&[EagerHandle::new(&x.handle, &x.strides, &x.shape.dims)],
&[EagerHandle::new(
Execution::start(kernel, x.client)
.inputs(&[EagerHandle::<R>::new(&x.handle, &x.strides, &x.shape.dims)])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
Some(&[
(stride[0] as u32).elem(),
(stride[1] as u32).elem(),
(dilation[0] as u32).elem(),
(dilation[1] as u32).elem(),
(padding[0] as u32).elem(),
(padding[1] as u32).elem(),
]),
kernel,
WorkgroupLaunch::Output { pos: 0 },
x.client,
);
)])
.with_scalars(&[
stride[0] as u32,
stride[1] as u32,
dilation[0] as u32,
dilation[1] as u32,
padding[0] as u32,
padding[1] as u32,
])
.execute(WorkgroupLaunch::Output { pos: 0 });
output
}
@ -190,26 +187,26 @@ pub(crate) fn max_pool2d_with_indices<R: Runtime, E: JitElement, I: JitElement>(
let output = empty_device(x.client.clone(), x.device.clone(), shape_out.clone());
let indices = empty_device(x.client.clone(), x.device.clone(), shape_out);
let kernel = Pool2dEagerKernel::new(kernel_size, MaxPoolWithIndices::default());
let kernel = Pool2dEagerKernel::<MaxPoolWithIndices<E>, R, E>::new(
kernel_size,
MaxPoolWithIndices::default(),
);
execute_dynamic::<R, Pool2dEagerKernel<MaxPoolWithIndices<E>, R, E>, I>(
&[EagerHandle::new(&x.handle, &x.strides, &x.shape.dims)],
&[
Execution::start(kernel, x.client)
.inputs(&[EagerHandle::<R>::new(&x.handle, &x.strides, &x.shape.dims)])
.outputs(&[
EagerHandle::new(&output.handle, &output.strides, &output.shape.dims),
EagerHandle::new(&indices.handle, &indices.strides, &indices.shape.dims),
],
Some(&[
(stride[0] as i32).elem(),
(stride[1] as i32).elem(),
(dilation[0] as i32).elem(),
(dilation[1] as i32).elem(),
(padding[0] as i32).elem(),
(padding[1] as i32).elem(),
]),
kernel,
WorkgroupLaunch::Output { pos: 0 },
x.client,
);
])
.with_scalars(&[
stride[0] as i32,
stride[1] as i32,
dilation[0] as i32,
dilation[1] as i32,
padding[0] as i32,
padding[1] as i32,
])
.execute(WorkgroupLaunch::Output { pos: 0 });
(output, indices)
}

View File

@ -1,9 +1,7 @@
use burn_tensor::ElementConversion;
use crate::{
codegen::{
dialect::gpu::{gpu, Elem, Item, Scope, Variable, Visibility},
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, Execution,
InputInfo, OutputInfo, WorkgroupLaunch,
},
element::JitElement,
@ -332,30 +330,27 @@ pub(crate) fn max_pool2d_with_indices_backward<R: Runtime, E: JitElement, I: Jit
let indices = kernel::into_contiguous(indices);
let output = empty_device(x.client.clone(), x.device.clone(), x.shape.clone());
let kernel = MaxPool2dWithIndicesBackwardEagerKernel::new(kernel_size);
let kernel = MaxPool2dWithIndicesBackwardEagerKernel::<R, E>::new(kernel_size);
execute_dynamic::<R, MaxPool2dWithIndicesBackwardEagerKernel<R, E>, I>(
&[
EagerHandle::new(&indices.handle, &indices.strides, &indices.shape.dims),
Execution::start(kernel, x.client)
.inputs(&[
EagerHandle::<R>::new(&indices.handle, &indices.strides, &indices.shape.dims),
EagerHandle::new(&grad.handle, &grad.strides, &grad.shape.dims),
],
&[EagerHandle::new(
])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
Some(&[
(stride[0] as i32).elem(),
(stride[1] as i32).elem(),
(dilation[0] as i32).elem(),
(dilation[1] as i32).elem(),
(padding[0] as i32).elem(),
(padding[1] as i32).elem(),
]),
kernel,
WorkgroupLaunch::Output { pos: 0 },
x.client,
);
)])
.with_scalars(&[
stride[0] as i32,
stride[1] as i32,
dilation[0] as i32,
dilation[1] as i32,
padding[0] as i32,
padding[1] as i32,
])
.execute(WorkgroupLaunch::Output { pos: 0 });
output
}

View File

@ -7,13 +7,13 @@ mod base;
mod max_pool2d;
mod max_pool2d_backward;
mod pool2d_shader;
pub(crate) use adaptive_pool2d_shader::*;
pub(crate) use pool2d_shader::*;
pub(crate) use adaptive_avg_pool2d::*;
pub(crate) use adaptive_avg_pool2d_backward::*;
pub(crate) use adaptive_pool2d_shader::*;
pub(crate) use avg_pool2d::*;
pub(crate) use avg_pool2d_backward::*;
pub(super) use base::*;
pub(crate) use max_pool2d::*;
pub(crate) use max_pool2d_backward::*;
pub(crate) use pool2d_shader::*;

View File

@ -3,7 +3,7 @@ use std::marker::PhantomData;
use crate::{
codegen::{
dialect::gpu::{gpu, Elem, Scope, Variable, Visibility},
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, Execution,
InputInfo, OutputInfo, WorkgroupLaunch,
},
element::JitElement,
@ -152,24 +152,20 @@ pub fn reduce_dim_naive<
output: JitTensor<R, EO, D>,
dim: usize,
) -> JitTensor<R, EO, D> {
let kernel = NaiveReduceDimEagerKernel::new(dim);
let kernel = NaiveReduceDimEagerKernel::<RD, R, EI, EO>::new(dim);
execute_dynamic::<R, NaiveReduceDimEagerKernel<RD, R, EI, EO>, EI>(
&[EagerHandle::new(
Execution::start(kernel, input.client)
.inputs(&[EagerHandle::<R>::new(
&input.handle,
&input.strides,
&input.shape.dims,
)],
&[EagerHandle::new(
)])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
None,
kernel,
WorkgroupLaunch::Output { pos: 0 },
input.client,
);
)])
.execute(WorkgroupLaunch::Output { pos: 0 });
output
}

View File

@ -5,7 +5,7 @@ use crate::{
dialect::gpu::{
gpu, Branch, Elem, Scope, Synchronization, Variable, Visibility, WorkgroupSize,
},
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, Execution,
InputInfo, OutputInfo, WorkgroupLaunch,
},
compute::WorkGroup,
@ -259,7 +259,7 @@ pub fn reduce_dim_shared<
let divisible_shape =
n_invocation_per_workgroup as u32 * n_input_values_per_thread == reduce_group_size as u32;
let kernel = SharedReduceDimEagerKernel::new(
let kernel = SharedReduceDimEagerKernel::<RD, R, EI, EO>::new(
dim,
WORKGROUP_DEFAULT,
WORKGROUP_DEFAULT,
@ -267,22 +267,18 @@ pub fn reduce_dim_shared<
divisible_shape,
);
execute_dynamic::<R, SharedReduceDimEagerKernel<RD, R, EI, EO>, EI>(
&[EagerHandle::new(
Execution::start(kernel, input.client)
.inputs(&[EagerHandle::<R>::new(
&input.handle,
&input.strides,
&input.shape.dims,
)],
&[EagerHandle::new(
)])
.outputs(&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
None,
kernel,
WorkgroupLaunch::Custom(grid),
input.client,
);
)])
.execute(WorkgroupLaunch::Custom(grid));
output
}