mirror of https://github.com/tracel-ai/burn.git
Refactor execute_dynamic with Execution struct (#1550)
This commit is contained in:
parent
efc3b2d243
commit
edcd92f13d
|
@ -220,7 +220,7 @@ pub(crate) fn run_backend_comparison_benchmarks(
|
|||
let status = run_cargo("bench", &args).unwrap();
|
||||
if !status.success() {
|
||||
println!(
|
||||
"Benchmark {} didn't ran successfully on the backend {}",
|
||||
"Benchmark {} didn't run successfully on the backend {}",
|
||||
bench_str, backend_str
|
||||
);
|
||||
continue;
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
# Burn JIT Backend
|
||||
|
||||
Generic backend that can be compiled just-in-time (JIT) to any shader language target
|
||||
In progress: At the moment, only WGSL compilation is supported, and some kernels still rely on pure WGSL
|
||||
|
|
|
@ -34,27 +34,8 @@ pub fn execute_static<R, K, E>(
|
|||
R: Runtime,
|
||||
E: JitElement,
|
||||
{
|
||||
execute_static_::<R, K, E, E, E>(inputs, outputs, scalar_elems, None, None, launch, client)
|
||||
}
|
||||
|
||||
fn execute_static_<R, K, E1, E2, E3>(
|
||||
inputs: &[EagerHandle<R>],
|
||||
outputs: &[EagerHandle<R>],
|
||||
scalars_1: Option<&[E1]>,
|
||||
scalars_2: Option<&[E2]>,
|
||||
scalars_3: Option<&[E3]>,
|
||||
launch: WorkgroupLaunch,
|
||||
client: ComputeClient<R::Server, R::Channel>,
|
||||
) where
|
||||
K: StaticKernelSource + 'static,
|
||||
R: Runtime,
|
||||
E1: JitElement,
|
||||
E2: JitElement,
|
||||
E3: JitElement,
|
||||
{
|
||||
let settings = execute_settings(
|
||||
inputs, outputs, scalars_1, scalars_2, scalars_3, launch, &client,
|
||||
);
|
||||
let settings =
|
||||
execute_settings::<R, E, E, E>(inputs, outputs, scalar_elems, None, None, launch, &client);
|
||||
let mut handles = settings.handles_tensors;
|
||||
let workgroup = settings.workgroup;
|
||||
|
||||
|
@ -64,6 +45,7 @@ fn execute_static_<R, K, E1, E2, E3>(
|
|||
}
|
||||
|
||||
let kernel = Box::new(StaticKernel::<K>::new(workgroup));
|
||||
|
||||
client.execute(kernel, &handles);
|
||||
}
|
||||
|
||||
|
@ -128,7 +110,7 @@ where
|
|||
/// Execute a dynamic kernel.
|
||||
#[allow(unused)]
|
||||
pub fn execute(self, launch: WorkgroupLaunch) {
|
||||
execute_dynamic_::<R, K, f32, f32, f32>(
|
||||
execute_dynamic::<R, K, f32, f32, f32>(
|
||||
self.inputs,
|
||||
self.outputs,
|
||||
None,
|
||||
|
@ -163,7 +145,7 @@ where
|
|||
/// Execute a dynamic kernel.
|
||||
#[allow(unused)]
|
||||
pub fn execute(self, launch: WorkgroupLaunch) {
|
||||
execute_dynamic_::<R, K, E, f32, f32>(
|
||||
execute_dynamic::<R, K, E, f32, f32>(
|
||||
self.inputs,
|
||||
self.outputs,
|
||||
Some(self.scalars.0),
|
||||
|
@ -203,7 +185,7 @@ where
|
|||
K: DynamicKernelSource + 'static,
|
||||
R: Runtime,
|
||||
{
|
||||
execute_dynamic_::<R, K, E1, E2, f32>(
|
||||
execute_dynamic::<R, K, E1, E2, f32>(
|
||||
self.inputs,
|
||||
self.outputs,
|
||||
Some(self.scalars.0),
|
||||
|
@ -227,7 +209,7 @@ where
|
|||
/// Execute a dynamic kernel.
|
||||
#[allow(unused)]
|
||||
pub fn execute(self, launch: WorkgroupLaunch) {
|
||||
execute_dynamic_::<R, K, E1, E2, E3>(
|
||||
execute_dynamic::<R, K, E1, E2, E3>(
|
||||
self.inputs,
|
||||
self.outputs,
|
||||
Some(self.scalars.0),
|
||||
|
@ -240,33 +222,8 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// Execute a dynamic kernel.
|
||||
pub fn execute_dynamic<R, K, E>(
|
||||
inputs: &[EagerHandle<R>],
|
||||
outputs: &[EagerHandle<R>],
|
||||
scalar_elems: Option<&[E]>,
|
||||
kernel: K,
|
||||
launch: WorkgroupLaunch,
|
||||
client: ComputeClient<R::Server, R::Channel>,
|
||||
) where
|
||||
K: DynamicKernelSource + 'static,
|
||||
R: Runtime,
|
||||
E: JitElement,
|
||||
{
|
||||
execute_dynamic_::<R, K, E, E, E>(
|
||||
inputs,
|
||||
outputs,
|
||||
scalar_elems,
|
||||
None,
|
||||
None,
|
||||
kernel,
|
||||
launch,
|
||||
client,
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn execute_dynamic_<R, K, E1, E2, E3>(
|
||||
fn execute_dynamic<R, K, E1, E2, E3>(
|
||||
inputs: &[EagerHandle<R>],
|
||||
outputs: &[EagerHandle<R>],
|
||||
scalars_1: Option<&[E1]>,
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::{any::TypeId, marker::PhantomData};
|
|||
|
||||
use crate::{
|
||||
codegen::{
|
||||
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
|
||||
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
|
||||
OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
gpu::{gpu, Scope, Variable, Visibility},
|
||||
|
@ -21,7 +21,7 @@ pub fn cast<R: Runtime, EI: JitElement, EO: JitElement, const D: usize>(
|
|||
return JitTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle);
|
||||
}
|
||||
|
||||
let kernel = CastEagerKernel::new();
|
||||
let kernel = CastEagerKernel::<R, EI, EO>::new();
|
||||
let num_elems = tensor.shape.num_elements();
|
||||
let buffer = tensor.client.empty(num_elems * core::mem::size_of::<EO>());
|
||||
let output = JitTensor::new(
|
||||
|
@ -31,22 +31,18 @@ pub fn cast<R: Runtime, EI: JitElement, EO: JitElement, const D: usize>(
|
|||
buffer,
|
||||
);
|
||||
|
||||
execute_dynamic::<R, CastEagerKernel<R, EI, EO>, u32>(
|
||||
&[EagerHandle::new(
|
||||
Execution::start(kernel, tensor.client)
|
||||
.inputs(&[EagerHandle::<R>::new(
|
||||
&tensor.handle,
|
||||
&tensor.strides,
|
||||
&tensor.shape.dims,
|
||||
)],
|
||||
&[EagerHandle::new(
|
||||
)])
|
||||
.outputs(&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
None,
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
tensor.client,
|
||||
);
|
||||
)])
|
||||
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||
|
||||
output
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::marker::PhantomData;
|
|||
|
||||
use crate::{
|
||||
codegen::{
|
||||
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
|
||||
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
|
||||
OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
gpu::{gpu, Elem, Item, Scope, Variable, Visibility},
|
||||
|
@ -20,7 +20,7 @@ use crate::{
|
|||
pub fn bool_cast<R: Runtime, EO: JitElement, const D: usize>(
|
||||
tensor: JitTensor<R, u32, D>,
|
||||
) -> JitTensor<R, EO, D> {
|
||||
let kernel = BoolCastEagerKernel::new();
|
||||
let kernel = BoolCastEagerKernel::<R, EO>::new();
|
||||
let num_elems = tensor.shape.num_elements();
|
||||
let buffer = tensor.client.empty(num_elems * core::mem::size_of::<EO>());
|
||||
let output = JitTensor::new(
|
||||
|
@ -30,22 +30,18 @@ pub fn bool_cast<R: Runtime, EO: JitElement, const D: usize>(
|
|||
buffer,
|
||||
);
|
||||
|
||||
execute_dynamic::<R, BoolCastEagerKernel<R, EO>, u32>(
|
||||
&[EagerHandle::new(
|
||||
Execution::start(kernel, tensor.client)
|
||||
.inputs(&[EagerHandle::<R>::new(
|
||||
&tensor.handle,
|
||||
&tensor.strides,
|
||||
&tensor.shape.dims,
|
||||
)],
|
||||
&[EagerHandle::new(
|
||||
)])
|
||||
.outputs(&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
None,
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
tensor.client,
|
||||
);
|
||||
)])
|
||||
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||
|
||||
output
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::marker::PhantomData;
|
|||
|
||||
use crate::{
|
||||
codegen::{
|
||||
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
|
||||
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
|
||||
OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
gpu::{gpu, Elem, IndexOffsetGlobalWithLayout, Scope, Variable, Visibility},
|
||||
|
@ -18,9 +18,9 @@ pub(crate) struct IntoContiguousShader {
|
|||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub(crate) struct IntoContiguousEagerKernel<R: Runtime, EO: JitElement> {
|
||||
pub(crate) struct IntoContiguousEagerKernel<R: Runtime, E: JitElement> {
|
||||
_runtime: PhantomData<R>,
|
||||
_elem_out: PhantomData<EO>,
|
||||
_elem_out: PhantomData<E>,
|
||||
}
|
||||
|
||||
/// Make a jit tensor contiguous.
|
||||
|
@ -31,7 +31,7 @@ pub fn into_contiguous<R: Runtime, E: JitElement, const D: usize>(
|
|||
return tensor;
|
||||
}
|
||||
|
||||
let kernel = IntoContiguousEagerKernel::new();
|
||||
let kernel = IntoContiguousEagerKernel::<R, E>::new();
|
||||
let num_elems = tensor.shape.num_elements();
|
||||
let buffer = tensor.client.empty(num_elems * core::mem::size_of::<E>());
|
||||
let output = JitTensor::new(
|
||||
|
@ -41,22 +41,18 @@ pub fn into_contiguous<R: Runtime, E: JitElement, const D: usize>(
|
|||
buffer,
|
||||
);
|
||||
|
||||
execute_dynamic::<R, IntoContiguousEagerKernel<R, E>, u32>(
|
||||
&[EagerHandle::new(
|
||||
Execution::start(kernel, tensor.client)
|
||||
.inputs(&[EagerHandle::<R>::new(
|
||||
&tensor.handle,
|
||||
&tensor.strides,
|
||||
&tensor.shape.dims,
|
||||
)],
|
||||
&[EagerHandle::new(
|
||||
)])
|
||||
.outputs(&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
None,
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
tensor.client,
|
||||
);
|
||||
)])
|
||||
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||
|
||||
output
|
||||
}
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
use burn_tensor::{
|
||||
ops::{conv::calculate_conv_output_size, ConvOptions},
|
||||
ElementConversion, Shape,
|
||||
Shape,
|
||||
};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
dialect::gpu::{gpu, Elem, Scope, Variable, Visibility},
|
||||
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
|
||||
Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, Execution,
|
||||
InputInfo, OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
element::JitElement,
|
||||
|
@ -17,7 +17,7 @@ use crate::{
|
|||
reshape,
|
||||
},
|
||||
tensor::JitTensor,
|
||||
Runtime, RuntimeInt,
|
||||
Runtime,
|
||||
};
|
||||
|
||||
#[derive(new)]
|
||||
|
@ -335,32 +335,29 @@ pub(crate) fn conv2d<R: Runtime, E: JitElement>(
|
|||
}
|
||||
};
|
||||
|
||||
let kernel = Conv2dEagerKernel::new();
|
||||
let kernel = Conv2dEagerKernel::<R, E>::new();
|
||||
|
||||
execute_dynamic::<R, Conv2dEagerKernel<R, E>, RuntimeInt<R>>(
|
||||
&[
|
||||
EagerHandle::new(&input.handle, &input.strides, &input.shape.dims),
|
||||
Execution::start(kernel, input.client)
|
||||
.inputs(&[
|
||||
EagerHandle::<R>::new(&input.handle, &input.strides, &input.shape.dims),
|
||||
EagerHandle::new(&weight.handle, &weight.strides, &weight.shape.dims),
|
||||
EagerHandle::new(&bias.handle, &bias.strides, &bias.shape.dims),
|
||||
],
|
||||
&[EagerHandle::new(
|
||||
])
|
||||
.outputs(&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
Some(&[
|
||||
(options.stride[0] as u32).elem(),
|
||||
(options.stride[1] as u32).elem(),
|
||||
(options.dilation[0] as u32).elem(),
|
||||
(options.dilation[1] as u32).elem(),
|
||||
(options.padding[0] as u32).elem(),
|
||||
(options.padding[1] as u32).elem(),
|
||||
(options.groups as u32).elem(),
|
||||
]),
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
input.client,
|
||||
);
|
||||
)])
|
||||
.with_scalars(&[
|
||||
options.stride[0] as u32,
|
||||
options.stride[1] as u32,
|
||||
options.dilation[0] as u32,
|
||||
options.dilation[1] as u32,
|
||||
options.padding[0] as u32,
|
||||
options.padding[1] as u32,
|
||||
options.groups as u32,
|
||||
])
|
||||
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||
|
||||
output
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ use crate::{
|
|||
tensor::JitTensor,
|
||||
Compiler, Runtime,
|
||||
};
|
||||
use burn_tensor::{ops::ConvTransposeOptions, Element, ElementConversion, Shape};
|
||||
use burn_tensor::{ops::ConvTransposeOptions, Element, Shape};
|
||||
|
||||
#[derive(new)]
|
||||
struct Conv2dTransposeEagerKernel<R, E> {
|
||||
|
@ -399,13 +399,13 @@ pub(crate) fn conv_transpose2d<R: Runtime, E: JitElement + Element>(
|
|||
&output.shape.dims,
|
||||
)])
|
||||
.with_scalars(&[
|
||||
(options.stride[0] as u32).elem::<u32>(),
|
||||
(options.stride[1] as u32).elem(),
|
||||
(options.dilation[0] as u32).elem(),
|
||||
(options.dilation[1] as u32).elem(),
|
||||
(options.padding[0] as u32).elem(),
|
||||
(options.padding[1] as u32).elem(),
|
||||
(options.groups as u32).elem(),
|
||||
options.stride[0] as u32,
|
||||
options.stride[1] as u32,
|
||||
options.dilation[0] as u32,
|
||||
options.dilation[1] as u32,
|
||||
options.padding[0] as u32,
|
||||
options.padding[1] as u32,
|
||||
options.groups as u32,
|
||||
])
|
||||
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
use crate::{
|
||||
codegen::{
|
||||
dialect::gpu::{gpu, Elem, Scope, Variable, Visibility},
|
||||
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
|
||||
Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, Execution,
|
||||
InputInfo, OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
element::JitElement,
|
||||
kernel::{DynamicKernelSource, SourceTemplate},
|
||||
ops::numeric::empty_device,
|
||||
tensor::JitTensor,
|
||||
Runtime, RuntimeInt,
|
||||
Runtime,
|
||||
};
|
||||
use burn_tensor::ElementConversion;
|
||||
use std::marker::PhantomData;
|
||||
|
@ -128,30 +128,27 @@ pub(crate) fn flip_on_output<R: Runtime, E: JitElement, const D: usize>(
|
|||
output: JitTensor<R, E, D>,
|
||||
indices: &[usize],
|
||||
) -> JitTensor<R, E, D> {
|
||||
let mut scalars = Vec::with_capacity(D);
|
||||
let mut scalars: Vec<u32> = Vec::with_capacity(D);
|
||||
|
||||
for i in 0..D {
|
||||
scalars.push((indices.contains(&i) as u32).elem());
|
||||
}
|
||||
|
||||
let kernel = FlipEagerKernel::new(D);
|
||||
let kernel = FlipEagerKernel::<R, E>::new(D);
|
||||
|
||||
execute_dynamic::<R, FlipEagerKernel<R, E>, RuntimeInt<R>>(
|
||||
&[EagerHandle::new(
|
||||
Execution::start(kernel, tensor.client)
|
||||
.inputs(&[EagerHandle::<R>::new(
|
||||
&tensor.handle,
|
||||
&tensor.strides,
|
||||
&tensor.shape.dims,
|
||||
)],
|
||||
&[EagerHandle::new(
|
||||
)])
|
||||
.outputs(&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
Some(&scalars),
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
tensor.client,
|
||||
);
|
||||
)])
|
||||
.with_scalars(&scalars)
|
||||
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||
|
||||
output
|
||||
}
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
use crate::codegen::dialect::gpu::{gpu, Elem, Scope, Variable};
|
||||
use crate::codegen::Execution;
|
||||
use crate::{
|
||||
codegen::{
|
||||
dialect::gpu, execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler,
|
||||
EagerHandle, InputInfo, OutputInfo, WorkgroupLaunch,
|
||||
dialect::gpu, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
|
||||
InputInfo, OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
element::JitElement,
|
||||
kernel::{self, DynamicKernelSource, SourceTemplate},
|
||||
|
@ -130,23 +131,19 @@ pub(crate) fn gather<R: Runtime, E: JitElement, I: JitElement, const D: usize>(
|
|||
) -> JitTensor<R, E, D> {
|
||||
let shape_output = indices.shape.clone();
|
||||
let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output);
|
||||
let kernel = GatherEagerKernel::new(dim);
|
||||
let kernel = GatherEagerKernel::<R, E>::new(dim);
|
||||
|
||||
execute_dynamic::<R, GatherEagerKernel<R, E>, E>(
|
||||
&[
|
||||
EagerHandle::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
|
||||
Execution::start(kernel, tensor.client)
|
||||
.inputs(&[
|
||||
EagerHandle::<R>::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
|
||||
EagerHandle::new(&indices.handle, &indices.strides, &indices.shape.dims),
|
||||
],
|
||||
&[EagerHandle::new(
|
||||
])
|
||||
.outputs(&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
None,
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
tensor.client,
|
||||
);
|
||||
)])
|
||||
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||
|
||||
output
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
codegen::{
|
||||
dialect::gpu::{gpu, Elem, Scope, Variable, Visibility},
|
||||
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
|
||||
Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, Execution,
|
||||
InputInfo, OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
element::JitElement,
|
||||
|
@ -126,23 +126,20 @@ pub(crate) fn repeat<R: Runtime, E: JitElement, const D1: usize>(
|
|||
handle,
|
||||
);
|
||||
|
||||
let kernel = RepeatEagerKernel::new(dim, D1);
|
||||
let kernel = RepeatEagerKernel::<R, E>::new(dim, D1);
|
||||
|
||||
execute_dynamic::<R, RepeatEagerKernel<R, E>, E>(
|
||||
&[EagerHandle::new(
|
||||
Execution::start(kernel, input.client)
|
||||
.inputs(&[EagerHandle::<R>::new(
|
||||
&input.handle,
|
||||
&input.strides,
|
||||
&input.shape.dims,
|
||||
)],
|
||||
&[EagerHandle::new(
|
||||
)])
|
||||
.outputs(&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
None,
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
input.client,
|
||||
);
|
||||
)])
|
||||
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||
|
||||
output
|
||||
}
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
use crate::codegen::dialect::gpu::{gpu, Branch, Elem, Scope, Variable};
|
||||
use crate::codegen::Execution;
|
||||
use crate::kernel::{elemwise_workgroup, WORKGROUP_DEFAULT};
|
||||
use crate::{
|
||||
codegen::{
|
||||
dialect::gpu, execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler,
|
||||
EagerHandle, InputInfo, WorkgroupLaunch,
|
||||
dialect::gpu, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
|
||||
InputInfo, WorkgroupLaunch,
|
||||
},
|
||||
element::JitElement,
|
||||
kernel::{self, DynamicKernelSource, SourceTemplate},
|
||||
|
@ -192,7 +193,7 @@ pub(crate) fn scatter<R: Runtime, E: JitElement, I: JitElement, const D: usize>(
|
|||
false => tensor.copy(),
|
||||
};
|
||||
|
||||
let kernel = ScatterEagerKernel::new(dim);
|
||||
let kernel = ScatterEagerKernel::<R, E>::new(dim);
|
||||
let mut strides = [0; D];
|
||||
let mut current = 1;
|
||||
let mut num_elems_per_workgroup = 1;
|
||||
|
@ -209,22 +210,19 @@ pub(crate) fn scatter<R: Runtime, E: JitElement, I: JitElement, const D: usize>(
|
|||
current *= val;
|
||||
num_elems_per_workgroup *= tensor.shape.dims[index];
|
||||
});
|
||||
|
||||
// Fake strides of the virtual output where the strides of dim is hardcoded to one.
|
||||
indices.strides = strides;
|
||||
|
||||
let workgroup = elemwise_workgroup(num_elems_per_workgroup, WORKGROUP_DEFAULT);
|
||||
execute_dynamic::<R, ScatterEagerKernel<R, E>, E>(
|
||||
&[
|
||||
EagerHandle::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
|
||||
|
||||
Execution::start(kernel, indices.client)
|
||||
.inputs(&[
|
||||
EagerHandle::<R>::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
|
||||
EagerHandle::new(&indices.handle, &indices.strides, &indices.shape.dims),
|
||||
EagerHandle::new(&value.handle, &value.strides, &value.shape.dims),
|
||||
],
|
||||
&[],
|
||||
None,
|
||||
kernel,
|
||||
WorkgroupLaunch::Custom(workgroup),
|
||||
indices.client,
|
||||
);
|
||||
])
|
||||
.execute(WorkgroupLaunch::Custom(workgroup));
|
||||
|
||||
tensor
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
codegen::{
|
||||
dialect::gpu::{gpu, Elem, Item, Scope, Variable, Visibility},
|
||||
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
|
||||
Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, Execution,
|
||||
InputInfo, OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
element::JitElement,
|
||||
|
@ -125,27 +125,23 @@ pub(crate) fn select<R: Runtime, E: JitElement, I: JitElement, const D: usize>(
|
|||
shape_output.dims[dim] = indices.shape.dims[0];
|
||||
|
||||
let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output);
|
||||
let kernel = SelectEagerKernel::new(dim);
|
||||
let kernel = SelectEagerKernel::<R, E>::new(dim);
|
||||
|
||||
execute_dynamic::<R, SelectEagerKernel<R, E>, E>(
|
||||
&[
|
||||
EagerHandle::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
|
||||
Execution::start(kernel, tensor.client)
|
||||
.inputs(&[
|
||||
EagerHandle::<R>::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
|
||||
// This is a current hacks because the info buffer that contains the strides and shapes is
|
||||
// hardcoded to only contains information about tensors of the same rank. However, since
|
||||
// we don't rely on the shape and stride of the indices tensors, it doesn't matter
|
||||
// which value we put, it just needs to be of the same rank.
|
||||
EagerHandle::new(&indices.handle, &[1; D], &[1; D]),
|
||||
],
|
||||
&[EagerHandle::new(
|
||||
])
|
||||
.outputs(&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
None,
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
tensor.client,
|
||||
);
|
||||
)])
|
||||
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||
|
||||
output
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
codegen::{
|
||||
dialect::gpu::{gpu, Branch, Elem, Item, Scope, Variable, Visibility},
|
||||
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
|
||||
Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, Execution,
|
||||
InputInfo, WorkgroupLaunch,
|
||||
},
|
||||
element::JitElement,
|
||||
|
@ -205,23 +205,18 @@ pub(crate) fn select_assign<R: Runtime, E: JitElement, I: JitElement, const D: u
|
|||
num_elems_per_workgroup *= tensor.shape.dims[index];
|
||||
});
|
||||
|
||||
let kernel = SelectAssignEagerKernel::new(dim);
|
||||
let kernel = SelectAssignEagerKernel::<R, E>::new(dim);
|
||||
let workgroup = elemwise_workgroup(num_elems_per_workgroup, WORKGROUP_DEFAULT);
|
||||
|
||||
execute_dynamic::<R, SelectAssignEagerKernel<R, E>, E>(
|
||||
&[
|
||||
EagerHandle::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
|
||||
Execution::start(kernel, indices.client)
|
||||
.inputs(&[
|
||||
EagerHandle::<R>::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
|
||||
EagerHandle::new(&value.handle, &value.strides, &value.shape.dims),
|
||||
// We use the custom strides here instead of the shape, since we don't use it in the
|
||||
// kernel, but we need to put the right number of dimensions (rank).
|
||||
EagerHandle::new(&indices.handle, &strides, &strides),
|
||||
],
|
||||
&[],
|
||||
None,
|
||||
kernel,
|
||||
WorkgroupLaunch::Custom(workgroup),
|
||||
indices.client,
|
||||
);
|
||||
])
|
||||
.execute(WorkgroupLaunch::Custom(workgroup));
|
||||
|
||||
tensor
|
||||
}
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
use crate::{
|
||||
codegen::{
|
||||
dialect::gpu::{gpu, Elem, Scope, Variable, Visibility},
|
||||
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
|
||||
Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, Execution,
|
||||
InputInfo, OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
element::JitElement,
|
||||
kernel::{DynamicKernelSource, SourceTemplate},
|
||||
ops::numeric::empty_device,
|
||||
tensor::JitTensor,
|
||||
Runtime, RuntimeInt,
|
||||
Runtime,
|
||||
};
|
||||
use burn_tensor::{ElementConversion, Shape};
|
||||
use std::{marker::PhantomData, ops::Range};
|
||||
|
@ -125,31 +125,28 @@ pub(crate) fn slice_on_output<R: Runtime, E: JitElement, const D1: usize, const
|
|||
output: JitTensor<R, E, D1>,
|
||||
indices: [Range<usize>; D2],
|
||||
) -> JitTensor<R, E, D1> {
|
||||
let mut scalars = Vec::with_capacity(D1);
|
||||
let mut scalars: Vec<i32> = Vec::with_capacity(D1);
|
||||
|
||||
for i in 0..D1 {
|
||||
let start = indices.get(i).map(|index| index.start).unwrap_or(0);
|
||||
scalars.push((start as i32).elem());
|
||||
}
|
||||
|
||||
let kernel = SliceEagerKernel::new(D1);
|
||||
let kernel = SliceEagerKernel::<R, E>::new(D1);
|
||||
|
||||
execute_dynamic::<R, SliceEagerKernel<R, E>, RuntimeInt<R>>(
|
||||
&[EagerHandle::new(
|
||||
Execution::start(kernel, tensor.client)
|
||||
.inputs(&[EagerHandle::<R>::new(
|
||||
&tensor.handle,
|
||||
&tensor.strides,
|
||||
&tensor.shape.dims,
|
||||
)],
|
||||
&[EagerHandle::new(
|
||||
)])
|
||||
.outputs(&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
Some(&scalars),
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
tensor.client,
|
||||
);
|
||||
)])
|
||||
.with_scalars(&scalars)
|
||||
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||
|
||||
output
|
||||
}
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
use crate::{
|
||||
codegen::{
|
||||
dialect::gpu::{gpu, Elem, Scope, Variable, Visibility},
|
||||
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
|
||||
Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, Execution,
|
||||
InputInfo, WorkgroupLaunch,
|
||||
},
|
||||
element::JitElement,
|
||||
kernel::{DynamicKernelSource, SourceTemplate},
|
||||
tensor::JitTensor,
|
||||
Runtime, RuntimeInt,
|
||||
Runtime,
|
||||
};
|
||||
use burn_tensor::ElementConversion;
|
||||
use std::{marker::PhantomData, ops::Range};
|
||||
|
@ -130,26 +130,22 @@ pub(crate) fn slice_assign<R: Runtime, E: JitElement, const D1: usize, const D2:
|
|||
true => tensor,
|
||||
false => tensor.copy(),
|
||||
};
|
||||
let mut scalars = Vec::with_capacity(D1);
|
||||
let mut scalars: Vec<i32> = Vec::with_capacity(D1);
|
||||
|
||||
for i in 0..D1 {
|
||||
let start = indices.get(i).map(|index| index.start).unwrap_or(0);
|
||||
scalars.push((start as i32).elem());
|
||||
}
|
||||
|
||||
let kernel = SliceAssignEagerKernel::new(D1);
|
||||
let kernel = SliceAssignEagerKernel::<R, E>::new(D1);
|
||||
|
||||
execute_dynamic::<R, SliceAssignEagerKernel<R, E>, RuntimeInt<R>>(
|
||||
&[
|
||||
EagerHandle::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
|
||||
Execution::start(kernel, value.client)
|
||||
.inputs(&[
|
||||
EagerHandle::<R>::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
|
||||
EagerHandle::new(&value.handle, &value.strides, &value.shape.dims),
|
||||
],
|
||||
&[],
|
||||
Some(&scalars),
|
||||
kernel,
|
||||
WorkgroupLaunch::Input { pos: 0 },
|
||||
value.client,
|
||||
);
|
||||
])
|
||||
.with_scalars(&scalars)
|
||||
.execute(WorkgroupLaunch::Input { pos: 0 });
|
||||
|
||||
tensor
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::marker::PhantomData;
|
|||
|
||||
use crate::{
|
||||
codegen::{
|
||||
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
|
||||
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
|
||||
OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
gpu::{gpu, Elem, Scope, Variable, Visibility},
|
||||
|
@ -406,24 +406,20 @@ pub(crate) fn interpolate_bicubic_launch<R: Runtime, E: JitElement>(
|
|||
input: JitTensor<R, E, 4>,
|
||||
output: JitTensor<R, E, 4>,
|
||||
) -> JitTensor<R, E, 4> {
|
||||
let kernel = InterpolateBicubicEagerKernel::new();
|
||||
let kernel = InterpolateBicubicEagerKernel::<R, E>::new();
|
||||
|
||||
execute_dynamic::<R, InterpolateBicubicEagerKernel<R, E>, u32>(
|
||||
&[EagerHandle::new(
|
||||
Execution::start(kernel, input.client)
|
||||
.inputs(&[EagerHandle::<R>::new(
|
||||
&input.handle,
|
||||
&input.strides,
|
||||
&input.shape.dims,
|
||||
)],
|
||||
&[EagerHandle::new(
|
||||
)])
|
||||
.outputs(&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
None,
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
input.client,
|
||||
);
|
||||
)])
|
||||
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||
|
||||
output
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::marker::PhantomData;
|
|||
|
||||
use crate::{
|
||||
codegen::{
|
||||
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
|
||||
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
|
||||
OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
gpu::{gpu, Elem, Scope, Variable, Visibility},
|
||||
|
@ -226,24 +226,20 @@ pub(crate) fn interpolate_bilinear_launch<R: Runtime, E: JitElement>(
|
|||
input: JitTensor<R, E, 4>,
|
||||
output: JitTensor<R, E, 4>,
|
||||
) -> JitTensor<R, E, 4> {
|
||||
let kernel = InterpolateBilinearEagerKernel::new();
|
||||
let kernel = InterpolateBilinearEagerKernel::<R, E>::new();
|
||||
|
||||
execute_dynamic::<R, InterpolateBilinearEagerKernel<R, E>, u32>(
|
||||
&[EagerHandle::new(
|
||||
Execution::start(kernel, input.client)
|
||||
.inputs(&[EagerHandle::<R>::new(
|
||||
&input.handle,
|
||||
&input.strides,
|
||||
&input.shape.dims,
|
||||
)],
|
||||
&[EagerHandle::new(
|
||||
)])
|
||||
.outputs(&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
None,
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
input.client,
|
||||
);
|
||||
)])
|
||||
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||
|
||||
output
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::marker::PhantomData;
|
|||
|
||||
use crate::{
|
||||
codegen::{
|
||||
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
|
||||
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
|
||||
OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
gpu::{gpu, Elem, Scope, Variable, Visibility},
|
||||
|
@ -162,24 +162,20 @@ pub(crate) fn interpolate_nearest_launch<R: Runtime, E: JitElement>(
|
|||
input: JitTensor<R, E, 4>,
|
||||
output: JitTensor<R, E, 4>,
|
||||
) -> JitTensor<R, E, 4> {
|
||||
let kernel = InterpolateNearestEagerKernel::new();
|
||||
let kernel = InterpolateNearestEagerKernel::<R, E>::new();
|
||||
|
||||
execute_dynamic::<R, InterpolateNearestEagerKernel<R, E>, u32>(
|
||||
&[EagerHandle::new(
|
||||
Execution::start(kernel, input.client)
|
||||
.inputs(&[EagerHandle::<R>::new(
|
||||
&input.handle,
|
||||
&input.strides,
|
||||
&input.shape.dims,
|
||||
)],
|
||||
&[EagerHandle::new(
|
||||
)])
|
||||
.outputs(&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
None,
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
input.client,
|
||||
);
|
||||
)])
|
||||
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||
|
||||
output
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::marker::PhantomData;
|
|||
|
||||
use crate::{
|
||||
codegen::{
|
||||
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
|
||||
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
|
||||
OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
gpu::{gpu, Elem, Scope, Variable, Visibility},
|
||||
|
@ -221,24 +221,20 @@ pub(crate) fn interpolate_nearest_backward_launch<R: Runtime, E: JitElement>(
|
|||
out_grad: JitTensor<R, E, 4>,
|
||||
output: JitTensor<R, E, 4>,
|
||||
) -> JitTensor<R, E, 4> {
|
||||
let kernel = InterpolateNearestBackwardEagerKernel::new();
|
||||
let kernel = InterpolateNearestBackwardEagerKernel::<R, E>::new();
|
||||
|
||||
execute_dynamic::<R, InterpolateNearestBackwardEagerKernel<R, E>, u32>(
|
||||
&[EagerHandle::new(
|
||||
Execution::start(kernel, out_grad.client)
|
||||
.inputs(&[EagerHandle::<R>::new(
|
||||
&out_grad.handle,
|
||||
&out_grad.strides,
|
||||
&out_grad.shape.dims,
|
||||
)],
|
||||
&[EagerHandle::new(
|
||||
)])
|
||||
.outputs(&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
None,
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
out_grad.client,
|
||||
);
|
||||
)])
|
||||
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||
|
||||
output
|
||||
}
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
use crate::codegen::dialect::gpu::{
|
||||
gpu, BinaryOperator, Branch, Elem, IndexOffsetGlobalWithLayout, Scope, Variable,
|
||||
};
|
||||
use crate::codegen::Execution;
|
||||
use crate::{
|
||||
codegen::{
|
||||
dialect::gpu, execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler,
|
||||
EagerHandle, InputInfo, OutputInfo, WorkgroupLaunch,
|
||||
dialect::gpu, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
|
||||
InputInfo, OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
element::JitElement,
|
||||
kernel::{into_contiguous, DynamicKernelSource, SourceTemplate, WORKGROUP_DEFAULT},
|
||||
|
@ -236,19 +237,15 @@ pub fn matmul_simple<R: Runtime, E: JitElement, const D: usize>(
|
|||
workgroup_size_y,
|
||||
);
|
||||
|
||||
let kernel = MatmulEagerKernel::new(workgroup_size_x, workgroup_size_y);
|
||||
let kernel = MatmulEagerKernel::<R>::new(workgroup_size_x, workgroup_size_y);
|
||||
|
||||
execute_dynamic::<R, MatmulEagerKernel<R>, E>(
|
||||
&[
|
||||
EagerHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
|
||||
Execution::start(kernel, rhs.client)
|
||||
.inputs(&[
|
||||
EagerHandle::<R>::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
|
||||
EagerHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
|
||||
],
|
||||
&[EagerHandle::new(&out.handle, &out.strides, &out.shape.dims)],
|
||||
None,
|
||||
kernel,
|
||||
WorkgroupLaunch::Custom(workgroup),
|
||||
rhs.client,
|
||||
);
|
||||
])
|
||||
.outputs(&[EagerHandle::new(&out.handle, &out.strides, &out.shape.dims)])
|
||||
.execute(WorkgroupLaunch::Custom(workgroup));
|
||||
|
||||
out
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::{
|
||||
codegen::{execute_dynamic, EagerHandle, WorkgroupLaunch},
|
||||
codegen::{EagerHandle, Execution, WorkgroupLaunch},
|
||||
element::JitElement,
|
||||
ops::numeric::empty_device,
|
||||
tensor::JitTensor,
|
||||
|
@ -18,24 +18,20 @@ pub(crate) fn adaptive_avg_pool2d<R: Runtime, E: JitElement>(
|
|||
let output_shape = Shape::new([batch_size, channels, output_size[0], output_size[1]]);
|
||||
let output = empty_device(input.client.clone(), input.device.clone(), output_shape);
|
||||
|
||||
let kernel = AdaptivePool2dEagerKernel::new();
|
||||
let kernel = AdaptivePool2dEagerKernel::<R, E>::new();
|
||||
|
||||
execute_dynamic::<R, AdaptivePool2dEagerKernel<R, E>, E>(
|
||||
&[EagerHandle::new(
|
||||
Execution::start(kernel, input.client)
|
||||
.inputs(&[EagerHandle::<R>::new(
|
||||
&input.handle,
|
||||
&input.strides,
|
||||
&input.shape.dims,
|
||||
)],
|
||||
&[EagerHandle::new(
|
||||
)])
|
||||
.outputs(&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
None,
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
input.client,
|
||||
);
|
||||
)])
|
||||
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||
|
||||
output
|
||||
}
|
||||
|
|
|
@ -2,14 +2,14 @@ use std::marker::PhantomData;
|
|||
|
||||
use crate::{
|
||||
codegen::{
|
||||
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo,
|
||||
Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
|
||||
OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
element::JitElement,
|
||||
gpu::{gpu, Elem, Scope, Variable, Visibility},
|
||||
kernel::{DynamicKernelSource, SourceTemplate},
|
||||
tensor::JitTensor,
|
||||
Compiler, Runtime, RuntimeInt,
|
||||
Compiler, Runtime,
|
||||
};
|
||||
|
||||
#[derive(new)]
|
||||
|
@ -254,24 +254,20 @@ pub(crate) fn adaptive_avg_pool2d_backward<R: Runtime, E: JitElement>(
|
|||
output_buffer,
|
||||
);
|
||||
|
||||
let kernel = AdaptiveAvgPool2dBackwardEagerKernel::new();
|
||||
let kernel = AdaptiveAvgPool2dBackwardEagerKernel::<R, E>::new();
|
||||
|
||||
execute_dynamic::<R, AdaptiveAvgPool2dBackwardEagerKernel<R, E>, RuntimeInt<R>>(
|
||||
&[EagerHandle::new(
|
||||
Execution::start(kernel, x.client)
|
||||
.inputs(&[EagerHandle::<R>::new(
|
||||
&out_grad.handle,
|
||||
&out_grad.strides,
|
||||
&out_grad.shape.dims,
|
||||
)],
|
||||
&[EagerHandle::new(
|
||||
)])
|
||||
.outputs(&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
None,
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
x.client,
|
||||
);
|
||||
)])
|
||||
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||
|
||||
output
|
||||
}
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
use crate::{
|
||||
codegen::{dialect::gpu::Variable, execute_dynamic, EagerHandle, WorkgroupLaunch},
|
||||
codegen::{dialect::gpu::Variable, EagerHandle, Execution, WorkgroupLaunch},
|
||||
element::JitElement,
|
||||
gpu::{gpu, Elem, Item, Scope},
|
||||
ops::numeric::empty_device,
|
||||
tensor::JitTensor,
|
||||
Runtime, RuntimeInt,
|
||||
Runtime,
|
||||
};
|
||||
use burn_tensor::{ops::conv::calculate_pool_output_size, ElementConversion, Shape};
|
||||
use burn_tensor::{ops::conv::calculate_pool_output_size, Shape};
|
||||
use std::fmt::Debug;
|
||||
|
||||
use super::{Pool2dEagerKernel, PoolStrategy};
|
||||
|
@ -99,27 +99,24 @@ pub(crate) fn avg_pool2d<R: Runtime, E: JitElement>(
|
|||
let output = empty_device(x.client.clone(), x.device.clone(), shape_out);
|
||||
|
||||
let pool_strategy = AvgPool::new(kernel_size, count_include_pad);
|
||||
let kernel = Pool2dEagerKernel::new(kernel_size, pool_strategy);
|
||||
let kernel = Pool2dEagerKernel::<AvgPool, R, E>::new(kernel_size, pool_strategy);
|
||||
|
||||
execute_dynamic::<R, Pool2dEagerKernel<AvgPool, R, E>, RuntimeInt<R>>(
|
||||
&[EagerHandle::new(&x.handle, &x.strides, &x.shape.dims)],
|
||||
&[EagerHandle::new(
|
||||
Execution::start(kernel, x.client)
|
||||
.inputs(&[EagerHandle::<R>::new(&x.handle, &x.strides, &x.shape.dims)])
|
||||
.outputs(&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
Some(&[
|
||||
(stride[0] as u32).elem(),
|
||||
(stride[1] as u32).elem(),
|
||||
(dilation as u32).elem(),
|
||||
(dilation as u32).elem(),
|
||||
(padding[0] as u32).elem(),
|
||||
(padding[1] as u32).elem(),
|
||||
]),
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
x.client,
|
||||
);
|
||||
)])
|
||||
.with_scalars(&[
|
||||
stride[0] as u32,
|
||||
stride[1] as u32,
|
||||
dilation as u32,
|
||||
dilation as u32,
|
||||
padding[0] as u32,
|
||||
padding[1] as u32,
|
||||
])
|
||||
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||
|
||||
output
|
||||
}
|
||||
|
|
|
@ -1,16 +1,14 @@
|
|||
use burn_tensor::ElementConversion;
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
dialect::gpu::{gpu, Elem, Scope, Variable, Visibility},
|
||||
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
|
||||
Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, Execution,
|
||||
InputInfo, OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
element::JitElement,
|
||||
kernel::{self, DynamicKernelSource, SourceTemplate},
|
||||
ops::numeric::empty_device,
|
||||
tensor::JitTensor,
|
||||
Runtime, RuntimeInt,
|
||||
Runtime,
|
||||
};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
|
@ -378,31 +376,28 @@ pub(crate) fn avg_pool2d_backward<R: Runtime, E: JitElement>(
|
|||
let dilation = 1;
|
||||
|
||||
let output = empty_device(x.client.clone(), x.device.clone(), x.shape.clone());
|
||||
let kernel = AvgPool2dBackwardEagerKernel::new(kernel_size, count_include_pad);
|
||||
let kernel = AvgPool2dBackwardEagerKernel::<R, E>::new(kernel_size, count_include_pad);
|
||||
|
||||
execute_dynamic::<R, AvgPool2dBackwardEagerKernel<R, E>, RuntimeInt<R>>(
|
||||
&[EagerHandle::new(
|
||||
Execution::start(kernel, x.client)
|
||||
.inputs(&[EagerHandle::<R>::new(
|
||||
&grad.handle,
|
||||
&grad.strides,
|
||||
&grad.shape.dims,
|
||||
)],
|
||||
&[EagerHandle::new(
|
||||
)])
|
||||
.outputs(&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
Some(&[
|
||||
(stride[0] as i32).elem(),
|
||||
(stride[1] as i32).elem(),
|
||||
dilation.elem(),
|
||||
dilation.elem(),
|
||||
(padding[0] as i32).elem(),
|
||||
(padding[1] as i32).elem(),
|
||||
]),
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
x.client,
|
||||
);
|
||||
)])
|
||||
.with_scalars(&[
|
||||
stride[0] as i32,
|
||||
stride[1] as i32,
|
||||
dilation,
|
||||
dilation,
|
||||
padding[0] as i32,
|
||||
padding[1] as i32,
|
||||
])
|
||||
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||
|
||||
output
|
||||
}
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
use std::{fmt::Debug, marker::PhantomData};
|
||||
|
||||
use crate::{
|
||||
codegen::{dialect::gpu::Variable, execute_dynamic, EagerHandle, WorkgroupLaunch},
|
||||
codegen::{dialect::gpu::Variable, EagerHandle, Execution, WorkgroupLaunch},
|
||||
element::JitElement,
|
||||
gpu::{gpu, Elem, Item, Scope},
|
||||
ops::numeric::empty_device,
|
||||
tensor::JitTensor,
|
||||
Runtime, RuntimeInt,
|
||||
Runtime,
|
||||
};
|
||||
use burn_tensor::{ops::conv::calculate_pool_output_size, ElementConversion, Shape};
|
||||
use burn_tensor::{ops::conv::calculate_pool_output_size, Shape};
|
||||
|
||||
use super::{Pool2dEagerKernel, PoolStrategy};
|
||||
|
||||
|
@ -137,27 +137,24 @@ pub(crate) fn max_pool2d<R: Runtime, E: JitElement>(
|
|||
let shape_out = Shape::new([batch_size, channels, size_0, size_1]);
|
||||
let output = empty_device(x.client.clone(), x.device.clone(), shape_out);
|
||||
|
||||
let kernel = Pool2dEagerKernel::new(kernel_size, MaxPool::default());
|
||||
let kernel = Pool2dEagerKernel::<MaxPool<E>, R, E>::new(kernel_size, MaxPool::default());
|
||||
|
||||
execute_dynamic::<R, Pool2dEagerKernel<MaxPool<E>, R, E>, RuntimeInt<R>>(
|
||||
&[EagerHandle::new(&x.handle, &x.strides, &x.shape.dims)],
|
||||
&[EagerHandle::new(
|
||||
Execution::start(kernel, x.client)
|
||||
.inputs(&[EagerHandle::<R>::new(&x.handle, &x.strides, &x.shape.dims)])
|
||||
.outputs(&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
Some(&[
|
||||
(stride[0] as u32).elem(),
|
||||
(stride[1] as u32).elem(),
|
||||
(dilation[0] as u32).elem(),
|
||||
(dilation[1] as u32).elem(),
|
||||
(padding[0] as u32).elem(),
|
||||
(padding[1] as u32).elem(),
|
||||
]),
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
x.client,
|
||||
);
|
||||
)])
|
||||
.with_scalars(&[
|
||||
stride[0] as u32,
|
||||
stride[1] as u32,
|
||||
dilation[0] as u32,
|
||||
dilation[1] as u32,
|
||||
padding[0] as u32,
|
||||
padding[1] as u32,
|
||||
])
|
||||
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||
|
||||
output
|
||||
}
|
||||
|
@ -190,26 +187,26 @@ pub(crate) fn max_pool2d_with_indices<R: Runtime, E: JitElement, I: JitElement>(
|
|||
let output = empty_device(x.client.clone(), x.device.clone(), shape_out.clone());
|
||||
let indices = empty_device(x.client.clone(), x.device.clone(), shape_out);
|
||||
|
||||
let kernel = Pool2dEagerKernel::new(kernel_size, MaxPoolWithIndices::default());
|
||||
let kernel = Pool2dEagerKernel::<MaxPoolWithIndices<E>, R, E>::new(
|
||||
kernel_size,
|
||||
MaxPoolWithIndices::default(),
|
||||
);
|
||||
|
||||
execute_dynamic::<R, Pool2dEagerKernel<MaxPoolWithIndices<E>, R, E>, I>(
|
||||
&[EagerHandle::new(&x.handle, &x.strides, &x.shape.dims)],
|
||||
&[
|
||||
Execution::start(kernel, x.client)
|
||||
.inputs(&[EagerHandle::<R>::new(&x.handle, &x.strides, &x.shape.dims)])
|
||||
.outputs(&[
|
||||
EagerHandle::new(&output.handle, &output.strides, &output.shape.dims),
|
||||
EagerHandle::new(&indices.handle, &indices.strides, &indices.shape.dims),
|
||||
],
|
||||
Some(&[
|
||||
(stride[0] as i32).elem(),
|
||||
(stride[1] as i32).elem(),
|
||||
(dilation[0] as i32).elem(),
|
||||
(dilation[1] as i32).elem(),
|
||||
(padding[0] as i32).elem(),
|
||||
(padding[1] as i32).elem(),
|
||||
]),
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
x.client,
|
||||
);
|
||||
])
|
||||
.with_scalars(&[
|
||||
stride[0] as i32,
|
||||
stride[1] as i32,
|
||||
dilation[0] as i32,
|
||||
dilation[1] as i32,
|
||||
padding[0] as i32,
|
||||
padding[1] as i32,
|
||||
])
|
||||
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||
|
||||
(output, indices)
|
||||
}
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
use burn_tensor::ElementConversion;
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
dialect::gpu::{gpu, Elem, Item, Scope, Variable, Visibility},
|
||||
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
|
||||
Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, Execution,
|
||||
InputInfo, OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
element::JitElement,
|
||||
|
@ -332,30 +330,27 @@ pub(crate) fn max_pool2d_with_indices_backward<R: Runtime, E: JitElement, I: Jit
|
|||
let indices = kernel::into_contiguous(indices);
|
||||
|
||||
let output = empty_device(x.client.clone(), x.device.clone(), x.shape.clone());
|
||||
let kernel = MaxPool2dWithIndicesBackwardEagerKernel::new(kernel_size);
|
||||
let kernel = MaxPool2dWithIndicesBackwardEagerKernel::<R, E>::new(kernel_size);
|
||||
|
||||
execute_dynamic::<R, MaxPool2dWithIndicesBackwardEagerKernel<R, E>, I>(
|
||||
&[
|
||||
EagerHandle::new(&indices.handle, &indices.strides, &indices.shape.dims),
|
||||
Execution::start(kernel, x.client)
|
||||
.inputs(&[
|
||||
EagerHandle::<R>::new(&indices.handle, &indices.strides, &indices.shape.dims),
|
||||
EagerHandle::new(&grad.handle, &grad.strides, &grad.shape.dims),
|
||||
],
|
||||
&[EagerHandle::new(
|
||||
])
|
||||
.outputs(&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
Some(&[
|
||||
(stride[0] as i32).elem(),
|
||||
(stride[1] as i32).elem(),
|
||||
(dilation[0] as i32).elem(),
|
||||
(dilation[1] as i32).elem(),
|
||||
(padding[0] as i32).elem(),
|
||||
(padding[1] as i32).elem(),
|
||||
]),
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
x.client,
|
||||
);
|
||||
)])
|
||||
.with_scalars(&[
|
||||
stride[0] as i32,
|
||||
stride[1] as i32,
|
||||
dilation[0] as i32,
|
||||
dilation[1] as i32,
|
||||
padding[0] as i32,
|
||||
padding[1] as i32,
|
||||
])
|
||||
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||
|
||||
output
|
||||
}
|
||||
|
|
|
@ -7,13 +7,13 @@ mod base;
|
|||
mod max_pool2d;
|
||||
mod max_pool2d_backward;
|
||||
mod pool2d_shader;
|
||||
pub(crate) use adaptive_pool2d_shader::*;
|
||||
pub(crate) use pool2d_shader::*;
|
||||
|
||||
pub(crate) use adaptive_avg_pool2d::*;
|
||||
pub(crate) use adaptive_avg_pool2d_backward::*;
|
||||
pub(crate) use adaptive_pool2d_shader::*;
|
||||
pub(crate) use avg_pool2d::*;
|
||||
pub(crate) use avg_pool2d_backward::*;
|
||||
pub(super) use base::*;
|
||||
pub(crate) use max_pool2d::*;
|
||||
pub(crate) use max_pool2d_backward::*;
|
||||
pub(crate) use pool2d_shader::*;
|
||||
|
|
|
@ -3,7 +3,7 @@ use std::marker::PhantomData;
|
|||
use crate::{
|
||||
codegen::{
|
||||
dialect::gpu::{gpu, Elem, Scope, Variable, Visibility},
|
||||
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
|
||||
Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, Execution,
|
||||
InputInfo, OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
element::JitElement,
|
||||
|
@ -152,24 +152,20 @@ pub fn reduce_dim_naive<
|
|||
output: JitTensor<R, EO, D>,
|
||||
dim: usize,
|
||||
) -> JitTensor<R, EO, D> {
|
||||
let kernel = NaiveReduceDimEagerKernel::new(dim);
|
||||
let kernel = NaiveReduceDimEagerKernel::<RD, R, EI, EO>::new(dim);
|
||||
|
||||
execute_dynamic::<R, NaiveReduceDimEagerKernel<RD, R, EI, EO>, EI>(
|
||||
&[EagerHandle::new(
|
||||
Execution::start(kernel, input.client)
|
||||
.inputs(&[EagerHandle::<R>::new(
|
||||
&input.handle,
|
||||
&input.strides,
|
||||
&input.shape.dims,
|
||||
)],
|
||||
&[EagerHandle::new(
|
||||
)])
|
||||
.outputs(&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
None,
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
input.client,
|
||||
);
|
||||
)])
|
||||
.execute(WorkgroupLaunch::Output { pos: 0 });
|
||||
|
||||
output
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ use crate::{
|
|||
dialect::gpu::{
|
||||
gpu, Branch, Elem, Scope, Synchronization, Variable, Visibility, WorkgroupSize,
|
||||
},
|
||||
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
|
||||
Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle, Execution,
|
||||
InputInfo, OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
compute::WorkGroup,
|
||||
|
@ -259,7 +259,7 @@ pub fn reduce_dim_shared<
|
|||
let divisible_shape =
|
||||
n_invocation_per_workgroup as u32 * n_input_values_per_thread == reduce_group_size as u32;
|
||||
|
||||
let kernel = SharedReduceDimEagerKernel::new(
|
||||
let kernel = SharedReduceDimEagerKernel::<RD, R, EI, EO>::new(
|
||||
dim,
|
||||
WORKGROUP_DEFAULT,
|
||||
WORKGROUP_DEFAULT,
|
||||
|
@ -267,22 +267,18 @@ pub fn reduce_dim_shared<
|
|||
divisible_shape,
|
||||
);
|
||||
|
||||
execute_dynamic::<R, SharedReduceDimEagerKernel<RD, R, EI, EO>, EI>(
|
||||
&[EagerHandle::new(
|
||||
Execution::start(kernel, input.client)
|
||||
.inputs(&[EagerHandle::<R>::new(
|
||||
&input.handle,
|
||||
&input.strides,
|
||||
&input.shape.dims,
|
||||
)],
|
||||
&[EagerHandle::new(
|
||||
)])
|
||||
.outputs(&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
None,
|
||||
kernel,
|
||||
WorkgroupLaunch::Custom(grid),
|
||||
input.client,
|
||||
);
|
||||
)])
|
||||
.execute(WorkgroupLaunch::Custom(grid));
|
||||
|
||||
output
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue