mirror of https://github.com/tracel-ai/burn.git
Refactor binary op (#2085)
This commit is contained in:
parent 88656d24ad
commit f673721d27
Cargo.lock
@@ -1303,7 +1303,7 @@ dependencies = [
 [[package]]
 name = "cubecl"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl?rev=59a2dc228b24ed1e381ccd00998f0c8745a92dfd#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
+source = "git+https://github.com/tracel-ai/cubecl?rev=a20ac61043c5540d47259e135c0823af3dd58fe8#a20ac61043c5540d47259e135c0823af3dd58fe8"
 dependencies = [
  "cubecl-core",
  "cubecl-cuda",
@@ -1314,7 +1314,7 @@ dependencies = [
 [[package]]
 name = "cubecl-common"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl?rev=59a2dc228b24ed1e381ccd00998f0c8745a92dfd#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
+source = "git+https://github.com/tracel-ai/cubecl?rev=a20ac61043c5540d47259e135c0823af3dd58fe8#a20ac61043c5540d47259e135c0823af3dd58fe8"
 dependencies = [
  "derive-new",
  "getrandom",
@@ -1328,7 +1328,7 @@ dependencies = [
 [[package]]
 name = "cubecl-core"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl?rev=59a2dc228b24ed1e381ccd00998f0c8745a92dfd#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
+source = "git+https://github.com/tracel-ai/cubecl?rev=a20ac61043c5540d47259e135c0823af3dd58fe8#a20ac61043c5540d47259e135c0823af3dd58fe8"
 dependencies = [
  "bytemuck",
  "cubecl-macros",
@@ -1343,7 +1343,7 @@ dependencies = [
 [[package]]
 name = "cubecl-cuda"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl?rev=59a2dc228b24ed1e381ccd00998f0c8745a92dfd#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
+source = "git+https://github.com/tracel-ai/cubecl?rev=a20ac61043c5540d47259e135c0823af3dd58fe8#a20ac61043c5540d47259e135c0823af3dd58fe8"
 dependencies = [
  "bytemuck",
  "cubecl-common",
@@ -1358,7 +1358,7 @@ dependencies = [
 [[package]]
 name = "cubecl-linalg"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl?rev=59a2dc228b24ed1e381ccd00998f0c8745a92dfd#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
+source = "git+https://github.com/tracel-ai/cubecl?rev=a20ac61043c5540d47259e135c0823af3dd58fe8#a20ac61043c5540d47259e135c0823af3dd58fe8"
 dependencies = [
  "bytemuck",
  "cubecl-core",
@@ -1369,7 +1369,7 @@ dependencies = [
 [[package]]
 name = "cubecl-macros"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl?rev=59a2dc228b24ed1e381ccd00998f0c8745a92dfd#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
+source = "git+https://github.com/tracel-ai/cubecl?rev=a20ac61043c5540d47259e135c0823af3dd58fe8#a20ac61043c5540d47259e135c0823af3dd58fe8"
 dependencies = [
  "derive-new",
  "proc-macro2",
@@ -1380,7 +1380,7 @@ dependencies = [
 [[package]]
 name = "cubecl-runtime"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl?rev=59a2dc228b24ed1e381ccd00998f0c8745a92dfd#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
+source = "git+https://github.com/tracel-ai/cubecl?rev=a20ac61043c5540d47259e135c0823af3dd58fe8#a20ac61043c5540d47259e135c0823af3dd58fe8"
 dependencies = [
  "async-channel",
  "cubecl-common",
@@ -1399,7 +1399,7 @@ dependencies = [
 [[package]]
 name = "cubecl-wgpu"
 version = "0.1.1"
-source = "git+https://github.com/tracel-ai/cubecl?rev=59a2dc228b24ed1e381ccd00998f0c8745a92dfd#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
+source = "git+https://github.com/tracel-ai/cubecl?rev=a20ac61043c5540d47259e135c0823af3dd58fe8#a20ac61043c5540d47259e135c0823af3dd58fe8"
 dependencies = [
  "async-channel",
  "bytemuck",
@@ -5714,9 +5714,9 @@ dependencies = [
 
 [[package]]
 name = "toml"
-version = "0.8.18"
+version = "0.8.19"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "73b98404c41291d0a0fba7148837d26858b42e57f7abe5a4865ff39dc35d1d8c"
+checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e"
 dependencies = [
  "serde",
  "serde_spanned",
Cargo.toml
@@ -143,8 +143,8 @@ sysinfo = "0.30.13"
 systemstat = "0.2.3"
 
 ### For the main burn branch. ###
-cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "59a2dc228b24ed1e381ccd00998f0c8745a92dfd" }
-cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "59a2dc228b24ed1e381ccd00998f0c8745a92dfd" }
+cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a20ac61043c5540d47259e135c0823af3dd58fe8" }
+cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a20ac61043c5540d47259e135c0823af3dd58fe8" }
 ### For local development. ###
 # cubecl = { path = "../cubecl/crates/cubecl" }
 # cubecl-common = { path = "../cubecl/crates/cubecl-common" }
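Note: the cubecl pin advances from 59a2dc2 to a20ac61 in lockstep with the kernel changes below; presumably the `#[cube]`-on-trait-and-impl support that the rewritten binary kernels lean on landed between those two revisions.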
crates/burn-jit/src/kernel/binary.rs
@@ -1,204 +1,306 @@
-use super::Kernel;
 use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
 use burn_tensor::Shape;
-use cubecl::{frontend::TensorHandleRef, CubeCountSettings, Execution};
+use cubecl::{
+    calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*,
+    tensor_vectorization_factor, Runtime,
+};
 
-/// Creates a binary kernel.
-#[macro_export]
-macro_rules! binary {
-    (
-        operation: $ops:expr,
-        runtime: $runtime:ty,
-        input: $lhs:expr; $rhs:expr,
-        elem: $elem:ty
-    ) => {{
-        binary!(operation: $ops, compiler: <$runtime as JitRuntime>::Compiler, elem_in: $elem, elem_out: $elem);
-
-        $crate::kernel::binary::<
-            Ops<<$runtime as Runtime>::Compiler, $elem, $elem>,
-            OpsInplaceLhs<<$runtime as Runtime>::Compiler, $elem, $elem>,
-            OpsInplaceRhs<<$runtime as Runtime>::Compiler, $elem, $elem>,
-            $runtime,
-            $elem,
-            D
-        >($lhs, $rhs, true, Ops::new(), OpsInplaceLhs::new(), OpsInplaceRhs::new())
-    }};
-
-    (
-        operation: $ops:expr,
-        compiler: $compiler:ty,
-        elem_in: $elem_in:ty,
-        elem_out: $elem_out:ty
-    ) => {
-        #[derive(new)]
-        pub struct Ops<C, I, O> {
-            _c: core::marker::PhantomData<C>,
-            _i: core::marker::PhantomData<I>,
-            _o: core::marker::PhantomData<O>,
-        }
-        #[derive(new)]
-        pub struct OpsInplaceLhs<C, I, O> {
-            _c: core::marker::PhantomData<C>,
-            _i: core::marker::PhantomData<I>,
-            _o: core::marker::PhantomData<O>,
-        }
-        #[derive(new)]
-        pub struct OpsInplaceRhs<C, I, O> {
-            _c: core::marker::PhantomData<C>,
-            _i: core::marker::PhantomData<I>,
-            _o: core::marker::PhantomData<O>,
-        }
-
-        #[allow(clippy::redundant_closure_call)]
-        fn compile<I, O>(
-            settings: cubecl::KernelSettings,
-        ) -> cubecl::ir::KernelDefinition
-        where
-            I: $crate::element::JitElement,
-            O: $crate::element::JitElement
-        {
-            let mut scope = cubecl::ir::Scope::root();
-            let position = cubecl::ir::Variable::AbsolutePos;
-
-            let op = $ops(&mut scope, I::cube_elem(), position);
-            scope.register(op);
-
-            let local = scope.last_local_index().unwrap().index().unwrap();
-
-            let lhs = cubecl::InputInfo::Array {
-                item: cubecl::ir::Item::new(I::cube_elem()),
-                visibility: cubecl::ir::Visibility::Read,
-            };
-            let rhs = cubecl::InputInfo::Array {
-                item: cubecl::ir::Item::new(I::cube_elem()),
-                visibility: cubecl::ir::Visibility::Read,
-            };
-            let out = cubecl::OutputInfo::ArrayWrite {
-                item: cubecl::ir::Item::new(O::cube_elem()),
-                local,
-                position,
-            };
-            let info = cubecl::prelude::KernelExpansion {
-                inputs: vec![lhs, rhs],
-                outputs: vec![out],
-                scope,
-            };
-            cubecl::prelude::KernelIntegrator::new(info).integrate(settings)
-        }
-
-        #[allow(clippy::redundant_closure_call)]
-        impl<C, I, O> $crate::kernel::Kernel for Ops<C, I, O>
-        where
-            C: cubecl::Compiler,
-            I: $crate::element::JitElement,
-            O: $crate::element::JitElement
-        {
-            fn define(&self) -> cubecl::ir::KernelDefinition {
-                let settings = cubecl::KernelSettings::default();
-                compile::<I, O>(settings)
-            }
-        }
-
-        #[allow(clippy::redundant_closure_call)]
-        impl<C, I, O> $crate::kernel::Kernel
-            for OpsInplaceLhs<C, I, O>
-        where
-            C: cubecl::Compiler,
-            I: $crate::element::JitElement,
-            O: $crate::element::JitElement
-        {
-            fn define(&self) -> cubecl::ir::KernelDefinition {
-                let mapping = cubecl::InplaceMapping {
-                    pos_input: 0,
-                    pos_output: 0,
-                };
-                let settings = cubecl::KernelSettings::default()
-                    .inplace(vec![mapping]);
-                compile::<I, O>(settings)
-            }
-        }
-
-        #[allow(clippy::redundant_closure_call)]
-        impl<C, I, O> $crate::kernel::Kernel
-            for OpsInplaceRhs<C, I, O>
-        where
-            C: cubecl::Compiler,
-            I: $crate::element::JitElement,
-            O: $crate::element::JitElement
-        {
-            fn define(&self) -> cubecl::ir::KernelDefinition {
-                let mapping = cubecl::InplaceMapping {
-                    pos_input: 1,
-                    pos_output: 0,
-                };
-                let settings = cubecl::KernelSettings::default()
-                    .inplace(vec![mapping]);
-                compile::<I, O>(settings)
-            }
-        }
-    };
+#[cube]
+pub(crate) trait BinaryOp<C: Numeric>: 'static + Send + Sync {
+    /// Execute a binary operation.
+    fn execute(lhs: C, rhs: C) -> C;
 }
 
-/// Launch an binary operation.
-pub fn binary<Kernel, KernelInplaceLhs, KernelInplaceRhs, R: JitRuntime, E, const D: usize>(
+pub(crate) struct AddOp;
+pub(crate) struct SubOp;
+pub(crate) struct MulOp;
+pub(crate) struct DivOp;
+pub(crate) struct RemainderOp;
+pub(crate) struct PowOp;
+
+#[cube]
+impl<N: Numeric> BinaryOp<N> for AddOp {
+    fn execute(lhs: N, rhs: N) -> N {
+        lhs + rhs
+    }
+}
+
+#[cube]
+impl<N: Numeric> BinaryOp<N> for SubOp {
+    fn execute(lhs: N, rhs: N) -> N {
+        lhs - rhs
+    }
+}
+
+#[cube]
+impl<N: Numeric> BinaryOp<N> for MulOp {
+    fn execute(lhs: N, rhs: N) -> N {
+        lhs * rhs
+    }
+}
+
+#[cube]
+impl<N: Numeric> BinaryOp<N> for DivOp {
+    fn execute(lhs: N, rhs: N) -> N {
+        lhs / rhs
+    }
+}
+
+#[cube]
+impl<N: Numeric> BinaryOp<N> for RemainderOp {
+    fn execute(lhs: N, rhs: N) -> N {
+        N::rem(lhs, rhs)
+    }
+}
+
+#[cube]
+impl<N: Float> BinaryOp<N> for PowOp {
+    fn execute(lhs: N, rhs: N) -> N {
+        N::powf(lhs, rhs)
+    }
+}
+
+#[cube(launch)]
+pub(crate) fn kernel_scalar_binop<C: Numeric, O: BinaryOp<C>>(
+    input: &Tensor<C>,
+    scalar: C,
+    output: &mut Tensor<C>,
+) {
+    let offset_output = ABSOLUTE_POS;
+
+    if offset_output >= output.len() {
+        return;
+    }
+
+    output[ABSOLUTE_POS] = O::execute(input[ABSOLUTE_POS], scalar);
+}
+
+#[cube(launch)]
+pub(crate) fn kernel_binop<C: Numeric, O: BinaryOp<C>>(
+    lhs: &Tensor<C>,
+    rhs: &Tensor<C>,
+    out: &mut Tensor<C>,
+    rank: Comptime<Option<UInt>>,
+    to_contiguous_lhs: Comptime<bool>,
+    to_contiguous_rhs: Comptime<bool>,
+) {
+    let offset_out = ABSOLUTE_POS;
+    let mut offset_lhs = ABSOLUTE_POS;
+    let mut offset_rhs = ABSOLUTE_POS;
+
+    if offset_out >= out.len() {
+        return;
+    }
+
+    if Comptime::get(to_contiguous_lhs) {
+        offset_lhs = index_offset_with_layout::<C, C>(
+            lhs,
+            out,
+            offset_out,
+            UInt::new(0),
+            Comptime::unwrap_or_else(rank, || out.rank()),
+            Comptime::is_some(rank),
+        );
+    }
+
+    if Comptime::get(to_contiguous_rhs) {
+        offset_rhs = index_offset_with_layout::<C, C>(
+            rhs,
+            out,
+            offset_out,
+            UInt::new(0),
+            Comptime::unwrap_or_else(rank, || out.rank()),
+            Comptime::is_some(rank),
+        );
+    }
+
+    out[offset_out] = O::execute(lhs[offset_lhs], rhs[offset_rhs]);
+}
+
+pub(crate) fn launch_binop<
+    const D: usize,
+    R: JitRuntime,
+    E: JitElement,
+    O: BinaryOp<E::Primitive>,
+>(
     lhs: JitTensor<R, E, D>,
     rhs: JitTensor<R, E, D>,
-    inplace_enabled: bool,
-    kernel: Kernel,
-    kernel_inplace_lhs: KernelInplaceLhs,
-    kernel_inplace_rhs: KernelInplaceRhs,
-) -> JitTensor<R, E, D>
-where
-    Kernel: crate::kernel::Kernel,
-    KernelInplaceLhs: crate::kernel::Kernel,
-    KernelInplaceRhs: crate::kernel::Kernel,
-    E: JitElement,
-{
-    if inplace_enabled && lhs.can_mut_broadcast(&rhs) {
-        Execution::start(kernel_inplace_lhs, rhs.client)
-            .inputs(&[
-                TensorHandleRef::<R>::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
-                TensorHandleRef::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
-            ])
-            .execute(CubeCountSettings::Input { pos: 0 });
+) -> JitTensor<R, E, D> {
+    let vectorization_factor_lhs =
+        tensor_vectorization_factor(&[4, 2], &lhs.shape.dims, &lhs.strides, D - 1);
+    let vectorization_factor_rhs =
+        tensor_vectorization_factor(&[4, 2], &rhs.shape.dims, &rhs.strides, D - 1);
+
+    let vectorization_factor = u8::min(vectorization_factor_lhs, vectorization_factor_rhs);
+
+    let mut shape_out = [0; D];
+    lhs.shape
+        .dims
+        .iter()
+        .zip(rhs.shape.dims.iter())
+        .enumerate()
+        .for_each(|(index, (dim_lhs, dim_rhs))| {
+            shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
+        });
+
+    let shape_out = Shape::new(shape_out);
+    let client = lhs.client.clone();
+    let num_elems = shape_out.num_elements();
+
+    let cube_dim = CubeDim::default();
+    let cube_count =
+        calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);
+
+    if lhs.can_mut_broadcast(&rhs) {
+        kernel_binop::launch::<E::Primitive, O, R>(
+            &client,
+            cube_count,
+            cube_dim,
+            TensorArg::vectorized(
+                vectorization_factor,
+                &lhs.handle,
+                &lhs.strides,
+                &lhs.shape.dims,
+            ),
+            TensorArg::vectorized(
+                vectorization_factor,
+                &rhs.handle,
+                &rhs.strides,
+                &rhs.shape.dims,
+            ),
+            TensorArg::alias(0),
+            None,
+            false,
+            rhs.strides != lhs.strides || rhs.shape != lhs.shape,
+        );
 
         lhs
-    } else if inplace_enabled && rhs.can_mut_broadcast(&lhs) {
-        Execution::start(kernel_inplace_rhs, lhs.client)
-            .inputs(&[
-                TensorHandleRef::<R>::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
-                TensorHandleRef::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
-            ])
-            .execute(CubeCountSettings::Input { pos: 1 });
+    } else if rhs.can_mut_broadcast(&lhs) {
+        kernel_binop::launch::<E::Primitive, O, R>(
+            &client,
+            cube_count,
+            cube_dim,
+            TensorArg::vectorized(
+                vectorization_factor,
+                &lhs.handle,
+                &lhs.strides,
+                &lhs.shape.dims,
+            ),
+            TensorArg::vectorized(
+                vectorization_factor,
+                &rhs.handle,
+                &rhs.strides,
+                &rhs.shape.dims,
+            ),
+            TensorArg::alias(1),
+            None,
+            rhs.strides != lhs.strides || rhs.shape != lhs.shape,
+            false,
+        );
 
         rhs
     } else {
-        let mut shape_out = [0; D];
-        lhs.shape
-            .dims
-            .iter()
-            .zip(rhs.shape.dims.iter())
-            .enumerate()
-            .for_each(|(index, (dim_lhs, dim_rhs))| {
-                shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
-            });
-
-        let shape_out = Shape::new(shape_out);
-        let num_elems = shape_out.num_elements();
         let buffer = lhs.client.empty(num_elems * core::mem::size_of::<E>());
-        let out = JitTensor::new_contiguous(lhs.client.clone(), lhs.device, shape_out, buffer);
+        let output = JitTensor::new_contiguous(lhs.client.clone(), lhs.device, shape_out, buffer);
+        let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape;
+        let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != output.shape;
 
-        Execution::start(kernel, lhs.client)
-            .inputs(&[
-                TensorHandleRef::<R>::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
-                TensorHandleRef::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
-            ])
-            .outputs(&[TensorHandleRef::new(
-                &out.handle,
-                &out.strides,
-                &out.shape.dims,
-            )])
-            .execute(CubeCountSettings::Output { pos: 0 });
+        kernel_binop::launch::<E::Primitive, O, R>(
+            &client,
+            cube_count,
+            cube_dim,
+            TensorArg::vectorized(
+                vectorization_factor,
+                &lhs.handle,
+                &lhs.strides,
+                &lhs.shape.dims,
+            ),
+            TensorArg::vectorized(
+                vectorization_factor,
+                &rhs.handle,
+                &rhs.strides,
+                &rhs.shape.dims,
+            ),
+            TensorArg::vectorized(
+                vectorization_factor,
+                &output.handle,
+                &output.strides,
+                &output.shape.dims,
+            ),
+            None,
+            to_contiguous_lhs,
+            to_contiguous_rhs,
+        );
 
-        out
+        output
     }
 }
+
+pub(crate) fn launch_scalar_binop<
+    const D: usize,
+    R: JitRuntime,
+    E: JitElement,
+    O: BinaryOp<E::Primitive>,
+>(
+    tensor: JitTensor<R, E, D>,
+    scalar: E,
+) -> JitTensor<R, E, D> {
+    // Vectorization is only enabled when the last dimension is contiguous.
+    let vectorization_factor =
+        tensor_vectorization_factor(&[4, 2], &tensor.shape.dims, &tensor.strides, D - 1);
+    let client = tensor.client.clone();
+    let num_elems = tensor.shape.num_elements();
+
+    let cube_dim = CubeDim::default();
+    let cube_count =
+        calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);
+
+    if tensor.can_mut() {
+        kernel_scalar_binop::launch::<E::Primitive, O, R>(
+            &client,
+            cube_count,
+            cube_dim,
+            TensorArg::vectorized(
+                vectorization_factor,
+                &tensor.handle,
+                &tensor.strides,
+                &tensor.shape.dims,
+            ),
+            ScalarArg::new(scalar),
+            TensorArg::alias(0),
+        );
+
+        tensor
+    } else {
+        let buffer = tensor.client.empty(num_elems * core::mem::size_of::<E>());
+        let output = JitTensor::new(
+            tensor.client.clone(),
+            buffer,
+            tensor.shape.clone(),
+            tensor.device,
+            tensor.strides,
+        );
+
+        kernel_scalar_binop::launch::<E::Primitive, O, R>(
+            &client,
+            cube_count,
+            CubeDim::default(),
+            TensorArg::vectorized(
+                vectorization_factor,
+                &tensor.handle,
+                &tensor.strides,
+                &tensor.shape.dims,
+            ),
+            ScalarArg::new(scalar),
+            TensorArg::vectorized(
+                vectorization_factor,
+                &output.handle,
+                &output.strides,
+                &output.shape.dims,
+            ),
+        );
+
+        output
+    }
+}
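The shape of the refactor, in brief: the old `binary!` macro generated three kernel structs per operation (out-of-place, inplace-lhs, inplace-rhs) by splicing an IR-building closure into boilerplate, while the new code has one generic `#[cube]` kernel per launch style, parameterized by a zero-sized `BinaryOp` type, with the inplace variants handled at launch time via `TensorArg::alias`. A minimal, dependency-free sketch of that dispatch pattern (plain Rust on slices, no cubecl or GPU involved; the names mirror the diff):

    // Sketch only: in the real code BinaryOp is a #[cube] trait and the driver
    // launches a GPU kernel; this strips it down to plain Rust slices to show
    // the static-dispatch shape.
    trait BinaryOp<C>: 'static + Send + Sync {
        fn execute(lhs: C, rhs: C) -> C;
    }

    // Zero-sized marker types, one per operation, as in the diff.
    struct AddOp;
    struct MulOp;

    impl BinaryOp<f32> for AddOp {
        fn execute(lhs: f32, rhs: f32) -> f32 {
            lhs + rhs
        }
    }

    impl BinaryOp<f32> for MulOp {
        fn execute(lhs: f32, rhs: f32) -> f32 {
            lhs * rhs
        }
    }

    // Stand-in for launch_binop: one generic driver, monomorphized per op,
    // replaces the three structs the binary! macro used to stamp out.
    fn launch_binop<O: BinaryOp<f32>>(lhs: &[f32], rhs: &[f32]) -> Vec<f32> {
        lhs.iter().zip(rhs).map(|(&l, &r)| O::execute(l, r)).collect()
    }

    fn main() {
        let (a, b) = (vec![1.0, 2.0], vec![3.0, 4.0]);
        assert_eq!(launch_binop::<AddOp>(&a, &b), vec![4.0, 6.0]);
        assert_eq!(launch_binop::<MulOp>(&a, &b), vec![3.0, 8.0]);
    }

Because the operation is a type parameter rather than a value, each op still compiles to its own specialized kernel; only the boilerplate is shared.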
crates/burn-jit/src/kernel/comparison.rs
@@ -122,7 +122,7 @@ pub(crate) fn launch_cmp<
     let vectorization_factor_lhs =
         tensor_vectorization_factor(&[4, 2], &lhs.shape.dims, &lhs.strides, D - 1);
     let vectorization_factor_rhs =
-        tensor_vectorization_factor(&[4, 2], &lhs.shape.dims, &lhs.strides, D - 1);
+        tensor_vectorization_factor(&[4, 2], &rhs.shape.dims, &rhs.strides, D - 1);
 
     let vectorization_factor = u8::min(vectorization_factor_lhs, vectorization_factor_rhs);
 
@@ -163,9 +163,9 @@ pub(crate) fn launch_cmp<
                 &rhs.shape.dims,
             ),
             TensorArg::alias(0),
-            Some(UInt::new(D as u32)),
+            None,
             false,
-            !rhs.is_contiguous(),
+            rhs.strides != lhs.strides || rhs.shape != lhs.shape,
         );
 
         JitTensor::new(lhs.client, lhs.handle, lhs.shape, lhs.device, lhs.strides)
@@ -187,17 +187,17 @@ pub(crate) fn launch_cmp<
                 &rhs.shape.dims,
             ),
             TensorArg::alias(1),
-            Some(UInt::new(D as u32)),
-            !lhs.is_contiguous(),
+            None,
+            rhs.strides != lhs.strides || rhs.shape != lhs.shape,
             false,
         );
 
         JitTensor::new(rhs.client, rhs.handle, rhs.shape, rhs.device, rhs.strides)
     } else {
         let buffer = lhs.client.empty(num_elems * core::mem::size_of::<E>());
-        let to_contiguous_lhs = !lhs.is_contiguous();
-        let to_contiguous_rhs = !rhs.is_contiguous();
         let output = JitTensor::new_contiguous(lhs.client.clone(), lhs.device, shape_out, buffer);
+        let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape;
+        let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != output.shape;
 
         kernel_cmp::launch::<E::Primitive, O, R>(
             &client,
@@ -221,7 +221,7 @@ pub(crate) fn launch_cmp<
                 &output.strides,
                 &output.shape.dims,
             ),
-            Some(UInt::new(D as u32)),
+            None,
             to_contiguous_lhs,
             to_contiguous_rhs,
        );
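Two behavioral details ride along in these `launch_cmp` hunks: `vectorization_factor_rhs` had been computed from `lhs`'s shape and strides, an apparent copy-paste slip that this commit fixes, and the `Some(UInt::new(D as u32))` rank hint plus `!is_contiguous()` checks give way to direct layout comparisons. An input only needs the `index_offset_with_layout` remapping in the kernel when its strides or shape actually differ from the tensor the result aliases, or from the freshly allocated contiguous output.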
crates/burn-jit/src/kernel/mod.rs
@@ -7,7 +7,7 @@ mod index;
 mod mask;
 mod unary;
 
-pub use binary::*;
+pub(crate) use binary::*;
 pub use cast::*;
 pub use contiguous::*;
 pub use mask::*;
crates/burn-jit/src/ops/numeric.rs
@@ -1,9 +1,10 @@
-use crate::kernel::{launch_unary, unary_op, UnaryOp};
-use crate::{binary, JitRuntime};
+use crate::kernel::{
+    launch_binop, launch_scalar_binop, AddOp, DivOp, MulOp, PowOp, RemainderOp, SubOp,
+};
 use crate::{element::JitElement, tensor::JitTensor};
+use crate::{FloatElement, JitRuntime};
 use burn_tensor::{ElementConversion, Shape};
 use cubecl::client::ComputeClient;
-use cubecl::ir::{BinaryOperator, Elem, Operator, Scope, Variable};
 use cubecl::{calculate_cube_count_elemwise, prelude::*};
 use cubecl::{tensor_vectorization_factor, Runtime};
 
@@ -106,151 +107,68 @@ pub fn add<R: JitRuntime, E: JitElement, const D: usize>(
     lhs: JitTensor<R, E, D>,
     rhs: JitTensor<R, E, D>,
 ) -> JitTensor<R, E, D> {
-    binary!(
-        operation: |scope: &mut Scope, elem: Elem, position: Variable| Operator::Add(BinaryOperator {
-            lhs: scope.read_array(0, elem, position),
-            rhs: scope.read_array(1, elem, position),
-            out: scope.create_local(elem),
-        }),
-        runtime: R,
-        input: lhs; rhs,
-        elem: E
-    )
+    launch_binop::<D, R, E, AddOp>(lhs, rhs)
 }
 
 pub fn add_scalar<R: JitRuntime, E: JitElement, const D: usize>(
     lhs: JitTensor<R, E, D>,
     rhs: E,
 ) -> JitTensor<R, E, D> {
-    unary_op!(numeric(lhs, rhs) => |context, lhs, rhs| {
-        #[cube]
-        fn execute<C: Numeric>(lhs: C, rhs: C) -> C {
-            lhs + rhs
-        }
-        execute::__expand::<C>(context, lhs, rhs)
-    })
+    launch_scalar_binop::<D, R, E, AddOp>(lhs, rhs)
 }
 
 pub fn sub<R: JitRuntime, E: JitElement, const D: usize>(
     lhs: JitTensor<R, E, D>,
     rhs: JitTensor<R, E, D>,
 ) -> JitTensor<R, E, D> {
-    binary!(
-        operation: |scope: &mut Scope, elem: Elem, position: Variable| Operator::Sub(BinaryOperator {
-            lhs: scope.read_array(0, elem, position),
-            rhs: scope.read_array(1, elem, position),
-            out: scope.create_local(elem),
-        }),
-        runtime: R,
-        input: lhs; rhs,
-        elem: E
-    )
+    launch_binop::<D, R, E, SubOp>(lhs, rhs)
 }
 
 pub fn sub_scalar<R: JitRuntime, E: JitElement, const D: usize>(
     lhs: JitTensor<R, E, D>,
     rhs: E,
 ) -> JitTensor<R, E, D> {
-    unary_op!(numeric(lhs, rhs) => |context, lhs, rhs| {
-        #[cube]
-        fn execute<C: Numeric>(lhs: C, rhs: C) -> C {
-            lhs - rhs
-        }
-        execute::__expand::<C>(context, lhs, rhs)
-    })
+    launch_scalar_binop::<D, R, E, SubOp>(lhs, rhs)
 }
 
 pub fn mul<R: JitRuntime, E: JitElement, const D: usize>(
     lhs: JitTensor<R, E, D>,
     rhs: JitTensor<R, E, D>,
 ) -> JitTensor<R, E, D> {
-    binary!(
-        operation: |scope: &mut Scope, elem: Elem, position: Variable| Operator::Mul(BinaryOperator {
-            lhs: scope.read_array(0, elem, position),
-            rhs: scope.read_array(1, elem, position),
-            out: scope.create_local(elem),
-        }),
-        runtime: R,
-        input: lhs; rhs,
-        elem: E
-    )
+    launch_binop::<D, R, E, MulOp>(lhs, rhs)
 }
 
 pub fn mul_scalar<R: JitRuntime, E: JitElement, const D: usize>(
     lhs: JitTensor<R, E, D>,
     rhs: E,
 ) -> JitTensor<R, E, D> {
-    unary_op!(numeric(lhs, rhs) => |context, lhs, rhs| {
-        #[cube]
-        fn execute<C: Numeric>(lhs: C, rhs: C) -> C {
-            lhs * rhs
-        }
-        execute::__expand::<C>(context, lhs, rhs)
-    })
+    launch_scalar_binop::<D, R, E, MulOp>(lhs, rhs)
 }
 
 pub fn div<R: JitRuntime, E: JitElement, const D: usize>(
     lhs: JitTensor<R, E, D>,
     rhs: JitTensor<R, E, D>,
 ) -> JitTensor<R, E, D> {
-    binary!(
-        operation: |scope: &mut Scope, elem: Elem, position: Variable| Operator::Div(BinaryOperator {
-            lhs: scope.read_array(0, elem, position),
-            rhs: scope.read_array(1, elem, position),
-            out: scope.create_local(elem),
-        }),
-        runtime: R,
-        input: lhs; rhs,
-        elem: E
-    )
+    launch_binop::<D, R, E, DivOp>(lhs, rhs)
 }
 
 pub fn div_scalar<R: JitRuntime, E: JitElement, const D: usize>(
     lhs: JitTensor<R, E, D>,
     rhs: E,
 ) -> JitTensor<R, E, D> {
-    unary_op!(numeric(lhs, rhs) => |context, lhs, rhs| {
-        #[cube]
-        fn execute<C: Numeric>(lhs: C, rhs: C) -> C {
-            lhs / rhs
-        }
-        execute::__expand::<C>(context, lhs, rhs)
-    })
+    launch_scalar_binop::<D, R, E, DivOp>(lhs, rhs)
 }
 
 pub fn remainder_scalar<R: JitRuntime, E: JitElement, const D: usize>(
     lhs: JitTensor<R, E, D>,
     rhs: E,
 ) -> JitTensor<R, E, D> {
-    let shape = lhs.shape.clone();
-    let device = lhs.device.clone();
-
-    let rhs_tensor = full::<R, E, D>(shape, &device, rhs);
-
-    binary!(
-        operation: |scope: &mut Scope, elem: Elem, position: Variable| Operator::Remainder(BinaryOperator {
-            lhs: scope.read_array(0, elem, position),
-            rhs: scope.read_array(1, elem, position),
-            out: scope.create_local(elem),
-        }),
-        runtime: R,
-        input: lhs; rhs_tensor,
-        elem: E
-    )
+    launch_scalar_binop::<D, R, E, RemainderOp>(lhs, rhs)
 }
 
-pub fn pow<R: JitRuntime, E: JitElement, const D: usize>(
+pub fn pow<R: JitRuntime, E: FloatElement, const D: usize>(
     lhs: JitTensor<R, E, D>,
     rhs: JitTensor<R, E, D>,
 ) -> JitTensor<R, E, D> {
-    binary!(
-        operation: |scope: &mut Scope, elem: Elem, position: Variable| Operator::Powf(BinaryOperator {
-            lhs: scope.read_array(0, elem, position),
-            rhs: scope.read_array(1, elem, position),
-            out: scope.create_local(elem),
-        }),
-        runtime: R,
-        input: lhs; rhs,
-        elem: E
-    )
+    launch_binop::<D, R, E, PowOp>(lhs, rhs)
 }
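With the `binary!` macro gone, each public op above is a one-line dispatch to `launch_binop` or `launch_scalar_binop` (and `remainder_scalar` no longer materializes a full tensor for the scalar). A new element-wise op would follow the same three steps; here is a sketch (hypothetical: `SquaredDiffOp`, `squared_diff`, and `squared_diff_scalar` are invented names for illustration, while `BinaryOp`, `launch_binop`, and `launch_scalar_binop` are the items from this diff):

    // Hypothetical new op in the post-refactor style: (lhs - rhs)^2,
    // built only from operators the diff already uses for Sub/Mul.
    pub(crate) struct SquaredDiffOp;

    #[cube]
    impl<N: Numeric> BinaryOp<N> for SquaredDiffOp {
        fn execute(lhs: N, rhs: N) -> N {
            let diff = lhs - rhs;
            diff * diff
        }
    }

    pub fn squared_diff<R: JitRuntime, E: JitElement, const D: usize>(
        lhs: JitTensor<R, E, D>,
        rhs: JitTensor<R, E, D>,
    ) -> JitTensor<R, E, D> {
        launch_binop::<D, R, E, SquaredDiffOp>(lhs, rhs)
    }

    pub fn squared_diff_scalar<R: JitRuntime, E: JitElement, const D: usize>(
        lhs: JitTensor<R, E, D>,
        rhs: E,
    ) -> JitTensor<R, E, D> {
        launch_scalar_binop::<D, R, E, SquaredDiffOp>(lhs, rhs)
    }

Broadcasting, vectorization, and in-place buffer reuse all come from the shared launch path rather than from per-op code.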