Select kernel from CPA to CubeCL (#2168)

---------

Co-authored-by: louisfd <louisfd94@gmail.com>
mepatrick73 2024-08-27 15:17:58 -04:00 committed by GitHub
parent a600a7b54e
commit 795201dcfc
2 changed files with 35 additions and 121 deletions

@@ -9,7 +9,7 @@ mod slice_assign;
 pub use flip::*;
 pub use repeat_dim::*;
-pub use select::*;
+pub(crate) use select::*;
 pub(crate) use select_assign::*;
 pub use slice::*;
 pub use slice_assign::*;

@@ -1,118 +1,33 @@
 use crate::{
     element::JitElement, kernel::Kernel, ops::numeric::empty_device, tensor::JitTensor, JitRuntime,
 };
-use cubecl::{
-    cpa,
-    frontend::TensorHandleRef,
-    ir::{Elem, IntKind, Item, KernelDefinition, Scope, Variable, Visibility},
-    CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
-    OutputInfo,
-};
-use std::marker::PhantomData;
-
-#[derive(new)]
-struct SelectEagerKernel<R: JitRuntime, E: JitElement> {
-    dim: usize,
-    _runtime: PhantomData<R>,
-    _elem: PhantomData<E>,
-}
-
-pub struct SelectComputeShader {
-    input: Variable,
-    indices: Variable,
-    output: Variable,
-    dim: usize,
-}
-
-impl SelectComputeShader {
-    pub fn expand(self, scope: &mut Scope) {
-        let input = self.input;
-        let indices = self.indices;
-        let output = self.output;
-        let id = Variable::AbsolutePos;
-        let offset_input = scope.zero(Elem::UInt);
-
-        cpa!(
-            scope,
-            range(0u32, Variable::Rank).for_each(|i, scope| {
-                let stride_input = scope.create_local(Elem::UInt);
-                let stride_output = scope.create_local(Elem::UInt);
-                let shape_output = scope.create_local(Elem::UInt);
-
-                cpa!(scope, stride_input = stride(input, i));
-                cpa!(scope, stride_output = stride(output, i));
-                cpa!(scope, shape_output = shape(output, i));
-
-                let offset_local = scope.create_local(Elem::UInt);
-                cpa!(scope, offset_local = id / stride_output);
-                cpa!(scope, offset_local = offset_local % shape_output);
-
-                let dim_index = scope.create_local(Elem::Bool);
-                cpa!(scope, dim_index = i == self.dim);
-
-                cpa!(scope, if(dim_index).then(|scope| {
-                    cpa!(scope, offset_local = indices[offset_local]);
-                    cpa!(scope, offset_local = offset_local * stride_input);
-                }).else(|scope| {
-                    cpa!(scope, offset_local = offset_local * stride_input);
-                }));
-
-                cpa!(scope, offset_input += offset_local);
-            })
-        );
-
-        let value = scope.create_local(input.item());
-        cpa!(scope, value = input[offset_input]);
-        cpa!(scope, output[id] = value);
-    }
-}
-
-impl<R: JitRuntime, E: JitElement> Kernel for SelectEagerKernel<R, E> {
-    fn define(&self) -> KernelDefinition {
-        let mut scope = Scope::root();
-        let item = E::cube_elem().into();
-        let item_indices: Item = Elem::Int(IntKind::I32).into();
-
-        let input = Variable::GlobalInputArray { id: 0, item };
-        let indices = Variable::GlobalInputArray {
-            id: 1,
-            item: item_indices,
-        };
-        let output = Variable::GlobalOutputArray { id: 0, item };
-
-        scope.write_global_custom(output);
-
-        SelectComputeShader {
-            input,
-            indices,
-            output,
-            dim: self.dim,
-        }
-        .expand(&mut scope);
-
-        let input = InputInfo::Array {
-            item,
-            visibility: Visibility::Read,
-        };
-        let indices = InputInfo::Array {
-            item: item_indices,
-            visibility: Visibility::Read,
-        };
-        let output = OutputInfo::Array { item };
-
-        let info = KernelExpansion {
-            inputs: vec![input, indices],
-            outputs: vec![output],
-            scope,
-        };
-
-        let settings = KernelSettings::default();
-        KernelIntegrator::new(info).integrate(settings)
-    }
-
-    fn id(&self) -> cubecl::KernelId {
-        cubecl::KernelId::new::<Self>().info(self.dim)
-    }
-}
+use cubecl::prelude::*;
+use cubecl::{calculate_cube_count_elemwise, CubeDim};
+
+#[cube(launch_unchecked)]
+fn select_kernel<T: Numeric, I: Numeric>(
+    input: &Tensor<T>,
+    indices: &Tensor<I>,
+    output: &mut Tensor<T>,
+    dim: UInt,
+) {
+    if ABSOLUTE_POS >= output.len() {
+        return;
+    }
+
+    let mut offset_input = UInt::new(0);
+
+    for i in range(0u32, output.rank(), Comptime::new(false)) {
+        let mut offset_local = ABSOLUTE_POS / output.stride(i) % output.shape(i);
+
+        if i == dim {
+            offset_local = UInt::cast_from(indices[offset_local]);
+        }
+
+        offset_input += offset_local * input.stride(i);
+    }
+
+    output[ABSOLUTE_POS] = input[offset_input];
+}
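
For reviewers who want to sanity-check the new kernel's indexing, here is a minimal CPU sketch of the same gather-along-a-dimension math. It is illustrative only: select_ref and its stride helper are hypothetical names, not part of this commit, and it assumes contiguous row-major tensors.

// Hypothetical CPU reference for the kernel's offset arithmetic (not in this
// commit). Assumes contiguous row-major layout for input and output.
fn select_ref(input: &[f32], in_shape: &[usize], indices: &[usize], dim: usize) -> Vec<f32> {
    // Output shape matches the input, except along `dim`, which takes
    // `indices.len()` entries.
    let mut out_shape = in_shape.to_vec();
    out_shape[dim] = indices.len();

    // Row-major strides: stride(i) = product of shape(i+1..).
    fn strides(shape: &[usize]) -> Vec<usize> {
        let mut s = vec![1; shape.len()];
        for i in (0..shape.len().saturating_sub(1)).rev() {
            s[i] = s[i + 1] * shape[i + 1];
        }
        s
    }
    let (in_str, out_str) = (strides(in_shape), strides(&out_shape));

    let total: usize = out_shape.iter().product();
    let mut output = vec![0.0; total];
    for pos in 0..total {
        // Same loop as `select_kernel`: decompose the linear output position
        // into per-dimension coordinates, remap the coordinate along `dim`
        // through `indices`, and accumulate input strides.
        let mut offset_input = 0;
        for i in 0..out_shape.len() {
            let mut coord = pos / out_str[i] % out_shape[i];
            if i == dim {
                coord = indices[coord];
            }
            offset_input += coord * in_str[i];
        }
        output[pos] = input[offset_input];
    }
    output
}

For example, with in_shape = [2, 3], dim = 1, and indices = [2, 0], the output has shape [2, 2] and gathers columns 2 and 0 of each row.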
@@ -122,26 +37,25 @@ pub(crate) fn select<R: JitRuntime, E: JitElement, I: JitElement, const D: usize>(
 ) -> JitTensor<R, E, D> {
     let mut shape_output = tensor.shape.clone();
     shape_output.dims[dim] = indices.shape.dims[0];
+    let total_elem = shape_output.num_elements();
     let output = empty_device(tensor.client.clone(), tensor.device.clone(), shape_output);
 
-    let kernel = SelectEagerKernel::<R, E>::new(dim);
-    let num_elems = indices.shape.dims[0];
-
-    let mut shapes = [1; D];
-    let mut strides = [num_elems; D];
-    shapes[D - 1] = num_elems;
-    strides[D - 1] = 1;
-
-    Execution::start(kernel, tensor.client.clone())
-        .inputs(&[
-            tensor.as_handle_ref(),
-            // This is a current hack because the info buffer that contains the strides and shapes
-            // is hardcoded to only contain information about tensors of the same rank. However,
-            // since we don't rely on the shape and stride of the indices tensor, it doesn't matter
-            // which value we put, it just needs to be of the same rank.
-            unsafe { TensorHandleRef::from_raw_parts(&indices.handle, &strides, &shapes) },
-        ])
-        .outputs(&[output.as_handle_ref()])
-        .execute(CubeCountSettings::Output { pos: 0 });
+    let dummy_array = [1; D];
+    let cube_dim = CubeDim::default();
+    let cube_count = calculate_cube_count_elemwise(total_elem, cube_dim);
+
+    unsafe {
+        select_kernel::launch_unchecked::<E::Primitive, I::Primitive, R>(
+            &tensor.client,
+            cube_count,
+            cube_dim,
+            tensor.as_tensor_arg(1),
+            // Ignore shape and stride
+            TensorArg::from_raw_parts(&indices.handle, &dummy_array, &dummy_array, 1),
+            output.as_tensor_arg(1),
+            ScalarArg::new(dim as u32),
+        )
+    };
+
     output
 }
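
The launch now sizes the grid explicitly instead of relying on the old CubeCountSettings::Output { pos: 0 } heuristic. Below is a sketch of the ceiling division that calculate_cube_count_elemwise performs, under the assumption of a flat 256-invocation cube; cubes_needed is a made-up name, not the CubeCL API.

// Illustrative sketch only: CubeCL's `calculate_cube_count_elemwise` performs
// this kind of ceiling division internally; `cubes_needed` is a made-up helper.
fn cubes_needed(total_elems: usize, invocations_per_cube: usize) -> usize {
    // Round up so every output element is covered by at least one invocation;
    // the kernel's `ABSOLUTE_POS >= output.len()` guard discards the overshoot.
    (total_elems + invocations_per_cube - 1) / invocations_per_cube
}

fn main() {
    // e.g. 1_000 output elements with 256 invocations per cube -> 4 cubes.
    assert_eq!(cubes_needed(1_000, 256), 4);
}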