Refactor binary op (#2085)

Nathaniel Simard 2024-07-31 16:18:21 -04:00 committed by GitHub
parent 88656d24ad
commit f673721d27
6 changed files with 325 additions and 305 deletions
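At a glance: the code-generating binary! macro (which built kernels through cubecl's IR scope API) is removed. Element-wise binary kernels are now written once as #[cube(launch)] functions parameterized by a small BinaryOp trait, and the public op functions become one-line dispatches into generic launch_binop / launch_scalar_binop launchers. Condensed from the new binary.rs and numeric.rs further down (imports and some bounds trimmed, so this sketch is not standalone-compilable):

#[cube]
pub(crate) trait BinaryOp<C: Numeric>: 'static + Send + Sync {
    /// Execute a binary operation.
    fn execute(lhs: C, rhs: C) -> C;
}

pub(crate) struct AddOp;

#[cube]
impl<N: Numeric> BinaryOp<N> for AddOp {
    fn execute(lhs: N, rhs: N) -> N {
        lhs + rhs
    }
}

// The public op is now a thin wrapper around the shared launcher.
pub fn add<R: JitRuntime, E: JitElement, const D: usize>(
    lhs: JitTensor<R, E, D>,
    rhs: JitTensor<R, E, D>,
) -> JitTensor<R, E, D> {
    launch_binop::<D, R, E, AddOp>(lhs, rhs)
}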

Cargo.lock

@@ -1303,7 +1303,7 @@ dependencies = [
[[package]]
name = "cubecl"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=59a2dc228b24ed1e381ccd00998f0c8745a92dfd#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
source = "git+https://github.com/tracel-ai/cubecl?rev=a20ac61043c5540d47259e135c0823af3dd58fe8#a20ac61043c5540d47259e135c0823af3dd58fe8"
dependencies = [
"cubecl-core",
"cubecl-cuda",
@@ -1314,7 +1314,7 @@ dependencies = [
[[package]]
name = "cubecl-common"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=59a2dc228b24ed1e381ccd00998f0c8745a92dfd#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
source = "git+https://github.com/tracel-ai/cubecl?rev=a20ac61043c5540d47259e135c0823af3dd58fe8#a20ac61043c5540d47259e135c0823af3dd58fe8"
dependencies = [
"derive-new",
"getrandom",
@@ -1328,7 +1328,7 @@ dependencies = [
[[package]]
name = "cubecl-core"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=59a2dc228b24ed1e381ccd00998f0c8745a92dfd#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
source = "git+https://github.com/tracel-ai/cubecl?rev=a20ac61043c5540d47259e135c0823af3dd58fe8#a20ac61043c5540d47259e135c0823af3dd58fe8"
dependencies = [
"bytemuck",
"cubecl-macros",
@@ -1343,7 +1343,7 @@ dependencies = [
[[package]]
name = "cubecl-cuda"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=59a2dc228b24ed1e381ccd00998f0c8745a92dfd#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
source = "git+https://github.com/tracel-ai/cubecl?rev=a20ac61043c5540d47259e135c0823af3dd58fe8#a20ac61043c5540d47259e135c0823af3dd58fe8"
dependencies = [
"bytemuck",
"cubecl-common",
@@ -1358,7 +1358,7 @@ dependencies = [
[[package]]
name = "cubecl-linalg"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=59a2dc228b24ed1e381ccd00998f0c8745a92dfd#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
source = "git+https://github.com/tracel-ai/cubecl?rev=a20ac61043c5540d47259e135c0823af3dd58fe8#a20ac61043c5540d47259e135c0823af3dd58fe8"
dependencies = [
"bytemuck",
"cubecl-core",
@@ -1369,7 +1369,7 @@ dependencies = [
[[package]]
name = "cubecl-macros"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=59a2dc228b24ed1e381ccd00998f0c8745a92dfd#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
source = "git+https://github.com/tracel-ai/cubecl?rev=a20ac61043c5540d47259e135c0823af3dd58fe8#a20ac61043c5540d47259e135c0823af3dd58fe8"
dependencies = [
"derive-new",
"proc-macro2",
@@ -1380,7 +1380,7 @@ dependencies = [
[[package]]
name = "cubecl-runtime"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=59a2dc228b24ed1e381ccd00998f0c8745a92dfd#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
source = "git+https://github.com/tracel-ai/cubecl?rev=a20ac61043c5540d47259e135c0823af3dd58fe8#a20ac61043c5540d47259e135c0823af3dd58fe8"
dependencies = [
"async-channel",
"cubecl-common",
@@ -1399,7 +1399,7 @@ dependencies = [
[[package]]
name = "cubecl-wgpu"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=59a2dc228b24ed1e381ccd00998f0c8745a92dfd#59a2dc228b24ed1e381ccd00998f0c8745a92dfd"
source = "git+https://github.com/tracel-ai/cubecl?rev=a20ac61043c5540d47259e135c0823af3dd58fe8#a20ac61043c5540d47259e135c0823af3dd58fe8"
dependencies = [
"async-channel",
"bytemuck",
@@ -5714,9 +5714,9 @@ dependencies = [
[[package]]
name = "toml"
version = "0.8.18"
version = "0.8.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73b98404c41291d0a0fba7148837d26858b42e57f7abe5a4865ff39dc35d1d8c"
checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e"
dependencies = [
"serde",
"serde_spanned",

Cargo.toml

@@ -143,8 +143,8 @@ sysinfo = "0.30.13"
systemstat = "0.2.3"
### For the main burn branch. ###
-cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "59a2dc228b24ed1e381ccd00998f0c8745a92dfd" }
-cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "59a2dc228b24ed1e381ccd00998f0c8745a92dfd" }
+cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a20ac61043c5540d47259e135c0823af3dd58fe8" }
+cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a20ac61043c5540d47259e135c0823af3dd58fe8" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl" }
# cubecl-common = { path = "../cubecl/crates/cubecl-common" }

binary.rs

@@ -1,204 +1,306 @@
use super::Kernel;
use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
use burn_tensor::Shape;
use cubecl::{frontend::TensorHandleRef, CubeCountSettings, Execution};
use cubecl::{
calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*,
tensor_vectorization_factor, Runtime,
};
/// Creates a binary kernel.
#[macro_export]
macro_rules! binary {
(
operation: $ops:expr,
runtime: $runtime:ty,
input: $lhs:expr; $rhs:expr,
elem: $elem:ty
) => {{
binary!(operation: $ops, compiler: <$runtime as JitRuntime>::Compiler, elem_in: $elem, elem_out: $elem);
$crate::kernel::binary::<
Ops<<$runtime as Runtime>::Compiler, $elem, $elem>,
OpsInplaceLhs<<$runtime as Runtime>::Compiler, $elem, $elem>,
OpsInplaceRhs<<$runtime as Runtime>::Compiler, $elem, $elem>,
$runtime,
$elem,
D
>($lhs, $rhs, true, Ops::new(), OpsInplaceLhs::new(), OpsInplaceRhs::new())
}};
(
operation: $ops:expr,
compiler: $compiler:ty,
elem_in: $elem_in:ty,
elem_out: $elem_out:ty
) => {
#[derive(new)]
pub struct Ops<C, I, O> {
_c: core::marker::PhantomData<C>,
_i: core::marker::PhantomData<I>,
_o: core::marker::PhantomData<O>,
}
#[derive(new)]
pub struct OpsInplaceLhs<C, I, O> {
_c: core::marker::PhantomData<C>,
_i: core::marker::PhantomData<I>,
_o: core::marker::PhantomData<O>,
}
#[derive(new)]
pub struct OpsInplaceRhs<C, I, O> {
_c: core::marker::PhantomData<C>,
_i: core::marker::PhantomData<I>,
_o: core::marker::PhantomData<O>,
}
#[allow(clippy::redundant_closure_call)]
fn compile<I, O>(
settings: cubecl::KernelSettings,
) -> cubecl::ir::KernelDefinition
where
I: $crate::element::JitElement,
O: $crate::element::JitElement
{
let mut scope = cubecl::ir::Scope::root();
let position = cubecl::ir::Variable::AbsolutePos;
let op = $ops(&mut scope, I::cube_elem(), position);
scope.register(op);
let local = scope.last_local_index().unwrap().index().unwrap();
let lhs = cubecl::InputInfo::Array {
item: cubecl::ir::Item::new(I::cube_elem()),
visibility: cubecl::ir::Visibility::Read,
};
let rhs = cubecl::InputInfo::Array {
item: cubecl::ir::Item::new(I::cube_elem()),
visibility: cubecl::ir::Visibility::Read,
};
let out = cubecl::OutputInfo::ArrayWrite {
item: cubecl::ir::Item::new(O::cube_elem()),
local,
position,
};
let info = cubecl::prelude::KernelExpansion {
inputs: vec![lhs, rhs],
outputs: vec![out],
scope,
};
cubecl::prelude::KernelIntegrator::new(info).integrate(settings)
}
#[allow(clippy::redundant_closure_call)]
impl<C, I, O> $crate::kernel::Kernel for Ops<C, I, O>
where
C: cubecl::Compiler,
I: $crate::element::JitElement,
O: $crate::element::JitElement
{
fn define(&self) -> cubecl::ir::KernelDefinition {
let settings = cubecl::KernelSettings::default();
compile::<I, O>(settings)
}
}
#[allow(clippy::redundant_closure_call)]
impl<C, I, O> $crate::kernel::Kernel
for OpsInplaceLhs<C, I, O>
where
C: cubecl::Compiler,
I: $crate::element::JitElement,
O: $crate::element::JitElement
{
fn define(&self) -> cubecl::ir::KernelDefinition {
let mapping = cubecl::InplaceMapping {
pos_input: 0,
pos_output: 0,
};
let settings = cubecl::KernelSettings::default()
.inplace(vec![mapping]);
compile::<I, O>(settings)
}
}
#[allow(clippy::redundant_closure_call)]
impl<C, I, O> $crate::kernel::Kernel
for OpsInplaceRhs<C, I, O>
where
C: cubecl::Compiler,
I: $crate::element::JitElement,
O: $crate::element::JitElement
{
fn define(&self) -> cubecl::ir::KernelDefinition {
let mapping = cubecl::InplaceMapping {
pos_input: 1,
pos_output: 0,
};
let settings = cubecl::KernelSettings::default()
.inplace(vec![mapping]);
compile::<I, O>(settings)
}
}
};
#[cube]
pub(crate) trait BinaryOp<C: Numeric>: 'static + Send + Sync {
/// Execute a binary operation.
fn execute(lhs: C, rhs: C) -> C;
}
/// Launch an binary operation.
pub fn binary<Kernel, KernelInplaceLhs, KernelInplaceRhs, R: JitRuntime, E, const D: usize>(
pub(crate) struct AddOp;
pub(crate) struct SubOp;
pub(crate) struct MulOp;
pub(crate) struct DivOp;
pub(crate) struct RemainderOp;
pub(crate) struct PowOp;
#[cube]
impl<N: Numeric> BinaryOp<N> for AddOp {
fn execute(lhs: N, rhs: N) -> N {
lhs + rhs
}
}
#[cube]
impl<N: Numeric> BinaryOp<N> for SubOp {
fn execute(lhs: N, rhs: N) -> N {
lhs - rhs
}
}
#[cube]
impl<N: Numeric> BinaryOp<N> for MulOp {
fn execute(lhs: N, rhs: N) -> N {
lhs * rhs
}
}
#[cube]
impl<N: Numeric> BinaryOp<N> for DivOp {
fn execute(lhs: N, rhs: N) -> N {
lhs / rhs
}
}
#[cube]
impl<N: Numeric> BinaryOp<N> for RemainderOp {
fn execute(lhs: N, rhs: N) -> N {
N::rem(lhs, rhs)
}
}
#[cube]
impl<N: Float> BinaryOp<N> for PowOp {
fn execute(lhs: N, rhs: N) -> N {
N::powf(lhs, rhs)
}
}
#[cube(launch)]
pub(crate) fn kernel_scalar_binop<C: Numeric, O: BinaryOp<C>>(
input: &Tensor<C>,
scalar: C,
output: &mut Tensor<C>,
) {
let offset_output = ABSOLUTE_POS;
if offset_output >= output.len() {
return;
}
output[ABSOLUTE_POS] = O::execute(input[ABSOLUTE_POS], scalar);
}
#[cube(launch)]
pub(crate) fn kernel_binop<C: Numeric, O: BinaryOp<C>>(
lhs: &Tensor<C>,
rhs: &Tensor<C>,
out: &mut Tensor<C>,
rank: Comptime<Option<UInt>>,
to_contiguous_lhs: Comptime<bool>,
to_contiguous_rhs: Comptime<bool>,
) {
let offset_out = ABSOLUTE_POS;
let mut offset_lhs = ABSOLUTE_POS;
let mut offset_rhs = ABSOLUTE_POS;
if offset_out >= out.len() {
return;
}
if Comptime::get(to_contiguous_lhs) {
offset_lhs = index_offset_with_layout::<C, C>(
lhs,
out,
offset_out,
UInt::new(0),
Comptime::unwrap_or_else(rank, || out.rank()),
Comptime::is_some(rank),
);
}
if Comptime::get(to_contiguous_rhs) {
offset_rhs = index_offset_with_layout::<C, C>(
rhs,
out,
offset_out,
UInt::new(0),
Comptime::unwrap_or_else(rank, || out.rank()),
Comptime::is_some(rank),
);
}
out[offset_out] = O::execute(lhs[offset_lhs], rhs[offset_rhs]);
}
pub(crate) fn launch_binop<
const D: usize,
R: JitRuntime,
E: JitElement,
O: BinaryOp<E::Primitive>,
>(
lhs: JitTensor<R, E, D>,
rhs: JitTensor<R, E, D>,
inplace_enabled: bool,
kernel: Kernel,
kernel_inplace_lhs: KernelInplaceLhs,
kernel_inplace_rhs: KernelInplaceRhs,
) -> JitTensor<R, E, D>
where
Kernel: crate::kernel::Kernel,
KernelInplaceLhs: crate::kernel::Kernel,
KernelInplaceRhs: crate::kernel::Kernel,
E: JitElement,
{
if inplace_enabled && lhs.can_mut_broadcast(&rhs) {
Execution::start(kernel_inplace_lhs, rhs.client)
.inputs(&[
TensorHandleRef::<R>::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
TensorHandleRef::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
])
.execute(CubeCountSettings::Input { pos: 0 });
) -> JitTensor<R, E, D> {
let vectorization_factor_lhs =
tensor_vectorization_factor(&[4, 2], &lhs.shape.dims, &lhs.strides, D - 1);
let vectorization_factor_rhs =
tensor_vectorization_factor(&[4, 2], &rhs.shape.dims, &rhs.strides, D - 1);
let vectorization_factor = u8::min(vectorization_factor_lhs, vectorization_factor_rhs);
let mut shape_out = [0; D];
lhs.shape
.dims
.iter()
.zip(rhs.shape.dims.iter())
.enumerate()
.for_each(|(index, (dim_lhs, dim_rhs))| {
shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
});
let shape_out = Shape::new(shape_out);
let client = lhs.client.clone();
let num_elems = shape_out.num_elements();
let cube_dim = CubeDim::default();
let cube_count =
calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);
if lhs.can_mut_broadcast(&rhs) {
kernel_binop::launch::<E::Primitive, O, R>(
&client,
cube_count,
cube_dim,
TensorArg::vectorized(
vectorization_factor,
&lhs.handle,
&lhs.strides,
&lhs.shape.dims,
),
TensorArg::vectorized(
vectorization_factor,
&rhs.handle,
&rhs.strides,
&rhs.shape.dims,
),
TensorArg::alias(0),
None,
false,
rhs.strides != lhs.strides || rhs.shape != lhs.shape,
);
lhs
} else if inplace_enabled && rhs.can_mut_broadcast(&lhs) {
Execution::start(kernel_inplace_rhs, lhs.client)
.inputs(&[
TensorHandleRef::<R>::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
TensorHandleRef::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
])
.execute(CubeCountSettings::Input { pos: 1 });
} else if rhs.can_mut_broadcast(&lhs) {
kernel_binop::launch::<E::Primitive, O, R>(
&client,
cube_count,
cube_dim,
TensorArg::vectorized(
vectorization_factor,
&lhs.handle,
&lhs.strides,
&lhs.shape.dims,
),
TensorArg::vectorized(
vectorization_factor,
&rhs.handle,
&rhs.strides,
&rhs.shape.dims,
),
TensorArg::alias(1),
None,
rhs.strides != lhs.strides || rhs.shape != lhs.shape,
false,
);
rhs
} else {
let mut shape_out = [0; D];
lhs.shape
.dims
.iter()
.zip(rhs.shape.dims.iter())
.enumerate()
.for_each(|(index, (dim_lhs, dim_rhs))| {
shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
});
let shape_out = Shape::new(shape_out);
let num_elems = shape_out.num_elements();
let buffer = lhs.client.empty(num_elems * core::mem::size_of::<E>());
let out = JitTensor::new_contiguous(lhs.client.clone(), lhs.device, shape_out, buffer);
let output = JitTensor::new_contiguous(lhs.client.clone(), lhs.device, shape_out, buffer);
let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape;
let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != output.shape;
Execution::start(kernel, lhs.client)
.inputs(&[
TensorHandleRef::<R>::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
TensorHandleRef::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
])
.outputs(&[TensorHandleRef::new(
&out.handle,
&out.strides,
&out.shape.dims,
)])
.execute(CubeCountSettings::Output { pos: 0 });
kernel_binop::launch::<E::Primitive, O, R>(
&client,
cube_count,
cube_dim,
TensorArg::vectorized(
vectorization_factor,
&lhs.handle,
&lhs.strides,
&lhs.shape.dims,
),
TensorArg::vectorized(
vectorization_factor,
&rhs.handle,
&rhs.strides,
&rhs.shape.dims,
),
TensorArg::vectorized(
vectorization_factor,
&output.handle,
&output.strides,
&output.shape.dims,
),
None,
to_contiguous_lhs,
to_contiguous_rhs,
);
out
output
}
}
pub(crate) fn launch_scalar_binop<
const D: usize,
R: JitRuntime,
E: JitElement,
O: BinaryOp<E::Primitive>,
>(
tensor: JitTensor<R, E, D>,
scalar: E,
) -> JitTensor<R, E, D> {
// Vectorization is only enabled when the last dimension is contiguous.
let vectorization_factor =
tensor_vectorization_factor(&[4, 2], &tensor.shape.dims, &tensor.strides, D - 1);
let client = tensor.client.clone();
let num_elems = tensor.shape.num_elements();
let cube_dim = CubeDim::default();
let cube_count =
calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim);
if tensor.can_mut() {
kernel_scalar_binop::launch::<E::Primitive, O, R>(
&client,
cube_count,
cube_dim,
TensorArg::vectorized(
vectorization_factor,
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
),
ScalarArg::new(scalar),
TensorArg::alias(0),
);
tensor
} else {
let buffer = tensor.client.empty(num_elems * core::mem::size_of::<E>());
let output = JitTensor::new(
tensor.client.clone(),
buffer,
tensor.shape.clone(),
tensor.device,
tensor.strides,
);
kernel_scalar_binop::launch::<E::Primitive, O, R>(
&client,
cube_count,
CubeDim::default(),
TensorArg::vectorized(
vectorization_factor,
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
),
ScalarArg::new(scalar),
TensorArg::vectorized(
vectorization_factor,
&output.handle,
&output.strides,
&output.shape.dims,
),
);
output
}
}
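A side effect of the trait-based design worth noting: adding another element-wise operation no longer requires a macro invocation or hand-written IR; a zero-sized marker type with a #[cube] BinaryOp impl is enough, and it reuses the same vectorized, broadcast-aware launcher. SquaredDiffOp below is purely illustrative (not part of this commit) and only uses operators the ops above already rely on:

// Hypothetical example, not in this commit: element-wise (lhs - rhs)^2.
pub(crate) struct SquaredDiffOp;

#[cube]
impl<N: Numeric> BinaryOp<N> for SquaredDiffOp {
    fn execute(lhs: N, rhs: N) -> N {
        let diff = lhs - rhs;
        diff * diff
    }
}

// Launched exactly like the built-in ops:
pub fn squared_diff<R: JitRuntime, E: JitElement, const D: usize>(
    lhs: JitTensor<R, E, D>,
    rhs: JitTensor<R, E, D>,
) -> JitTensor<R, E, D> {
    launch_binop::<D, R, E, SquaredDiffOp>(lhs, rhs)
}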

comparison.rs

@@ -122,7 +122,7 @@ pub(crate) fn launch_cmp<
let vectorization_factor_lhs =
tensor_vectorization_factor(&[4, 2], &lhs.shape.dims, &lhs.strides, D - 1);
let vectorization_factor_rhs =
-tensor_vectorization_factor(&[4, 2], &lhs.shape.dims, &lhs.strides, D - 1);
+tensor_vectorization_factor(&[4, 2], &rhs.shape.dims, &rhs.strides, D - 1);
let vectorization_factor = u8::min(vectorization_factor_lhs, vectorization_factor_rhs);
@@ -163,9 +163,9 @@ pub(crate) fn launch_cmp<
&rhs.shape.dims,
),
TensorArg::alias(0),
-Some(UInt::new(D as u32)),
+None,
false,
-!rhs.is_contiguous(),
+rhs.strides != lhs.strides || rhs.shape != lhs.shape,
);
JitTensor::new(lhs.client, lhs.handle, lhs.shape, lhs.device, lhs.strides)
@@ -187,17 +187,17 @@ pub(crate) fn launch_cmp<
&rhs.shape.dims,
),
TensorArg::alias(1),
-Some(UInt::new(D as u32)),
-!lhs.is_contiguous(),
+None,
+rhs.strides != lhs.strides || rhs.shape != lhs.shape,
false,
);
JitTensor::new(rhs.client, rhs.handle, rhs.shape, rhs.device, rhs.strides)
} else {
let buffer = lhs.client.empty(num_elems * core::mem::size_of::<E>());
-let to_contiguous_lhs = !lhs.is_contiguous();
-let to_contiguous_rhs = !rhs.is_contiguous();
let output = JitTensor::new_contiguous(lhs.client.clone(), lhs.device, shape_out, buffer);
+let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape;
+let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != output.shape;
kernel_cmp::launch::<E::Primitive, O, R>(
&client,
@@ -221,7 +221,7 @@ pub(crate) fn launch_cmp<
&output.strides,
&output.shape.dims,
),
-Some(UInt::new(D as u32)),
+None,
to_contiguous_lhs,
to_contiguous_rhs,
);
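The comparison launcher receives the same adjustments as launch_binop: the comptime rank hint is now passed as None, whether an input must go through index_offset_with_layout is decided by comparing its strides and shape against the other operand (or the output buffer) rather than by calling is_contiguous(), and the rhs vectorization factor is now computed from rhs instead of lhs. Restated as a standalone predicate (illustrative only; the diff inlines the comparison):

// Illustrative restatement of the layout check used above: an input can be
// indexed linearly only when its strides and shape already match the target's.
fn needs_layout_conversion(
    in_strides: &[usize],
    in_shape: &[usize],
    target_strides: &[usize],
    target_shape: &[usize],
) -> bool {
    in_strides != target_strides || in_shape != target_shape
}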

kernel/mod.rs

@@ -7,7 +7,7 @@ mod index;
mod mask;
mod unary;
-pub use binary::*;
+pub(crate) use binary::*;
pub use cast::*;
pub use contiguous::*;
pub use mask::*;

numeric.rs

@@ -1,9 +1,10 @@
use crate::kernel::{launch_unary, unary_op, UnaryOp};
-use crate::{binary, JitRuntime};
+use crate::kernel::{
+launch_binop, launch_scalar_binop, AddOp, DivOp, MulOp, PowOp, RemainderOp, SubOp,
+};
use crate::{element::JitElement, tensor::JitTensor};
+use crate::{FloatElement, JitRuntime};
use burn_tensor::{ElementConversion, Shape};
use cubecl::client::ComputeClient;
-use cubecl::ir::{BinaryOperator, Elem, Operator, Scope, Variable};
use cubecl::{calculate_cube_count_elemwise, prelude::*};
use cubecl::{tensor_vectorization_factor, Runtime};
@@ -106,151 +107,68 @@ pub fn add<R: JitRuntime, E: JitElement, const D: usize>(
lhs: JitTensor<R, E, D>,
rhs: JitTensor<R, E, D>,
) -> JitTensor<R, E, D> {
-binary!(
-operation: |scope: &mut Scope, elem: Elem, position: Variable| Operator::Add(BinaryOperator {
-lhs: scope.read_array(0, elem, position),
-rhs: scope.read_array(1, elem, position),
-out: scope.create_local(elem),
-}),
-runtime: R,
-input: lhs; rhs,
-elem: E
-)
+launch_binop::<D, R, E, AddOp>(lhs, rhs)
}
pub fn add_scalar<R: JitRuntime, E: JitElement, const D: usize>(
lhs: JitTensor<R, E, D>,
rhs: E,
) -> JitTensor<R, E, D> {
-unary_op!(numeric(lhs, rhs) => |context, lhs, rhs| {
-#[cube]
-fn execute<C: Numeric>(lhs: C, rhs: C) -> C {
-lhs + rhs
-}
-execute::__expand::<C>(context, lhs, rhs)
-})
+launch_scalar_binop::<D, R, E, AddOp>(lhs, rhs)
}
pub fn sub<R: JitRuntime, E: JitElement, const D: usize>(
lhs: JitTensor<R, E, D>,
rhs: JitTensor<R, E, D>,
) -> JitTensor<R, E, D> {
-binary!(
-operation: |scope: &mut Scope, elem: Elem, position: Variable| Operator::Sub(BinaryOperator {
-lhs: scope.read_array(0, elem, position),
-rhs: scope.read_array(1, elem, position),
-out: scope.create_local(elem),
-}),
-runtime: R,
-input: lhs; rhs,
-elem: E
-)
+launch_binop::<D, R, E, SubOp>(lhs, rhs)
}
pub fn sub_scalar<R: JitRuntime, E: JitElement, const D: usize>(
lhs: JitTensor<R, E, D>,
rhs: E,
) -> JitTensor<R, E, D> {
-unary_op!(numeric(lhs, rhs) => |context, lhs, rhs| {
-#[cube]
-fn execute<C: Numeric>(lhs: C, rhs: C) -> C {
-lhs - rhs
-}
-execute::__expand::<C>(context, lhs, rhs)
-})
+launch_scalar_binop::<D, R, E, SubOp>(lhs, rhs)
}
pub fn mul<R: JitRuntime, E: JitElement, const D: usize>(
lhs: JitTensor<R, E, D>,
rhs: JitTensor<R, E, D>,
) -> JitTensor<R, E, D> {
-binary!(
-operation: |scope: &mut Scope, elem: Elem, position: Variable| Operator::Mul(BinaryOperator {
-lhs: scope.read_array(0, elem, position),
-rhs: scope.read_array(1, elem, position),
-out: scope.create_local(elem),
-}),
-runtime: R,
-input: lhs; rhs,
-elem: E
-)
+launch_binop::<D, R, E, MulOp>(lhs, rhs)
}
pub fn mul_scalar<R: JitRuntime, E: JitElement, const D: usize>(
lhs: JitTensor<R, E, D>,
rhs: E,
) -> JitTensor<R, E, D> {
-unary_op!(numeric(lhs, rhs) => |context, lhs, rhs| {
-#[cube]
-fn execute<C: Numeric>(lhs: C, rhs: C) -> C {
-lhs * rhs
-}
-execute::__expand::<C>(context, lhs, rhs)
-})
+launch_scalar_binop::<D, R, E, MulOp>(lhs, rhs)
}
pub fn div<R: JitRuntime, E: JitElement, const D: usize>(
lhs: JitTensor<R, E, D>,
rhs: JitTensor<R, E, D>,
) -> JitTensor<R, E, D> {
-binary!(
-operation: |scope: &mut Scope, elem: Elem, position: Variable| Operator::Div(BinaryOperator {
-lhs: scope.read_array(0, elem, position),
-rhs: scope.read_array(1, elem, position),
-out: scope.create_local(elem),
-}),
-runtime: R,
-input: lhs; rhs,
-elem: E
-)
+launch_binop::<D, R, E, DivOp>(lhs, rhs)
}
pub fn div_scalar<R: JitRuntime, E: JitElement, const D: usize>(
lhs: JitTensor<R, E, D>,
rhs: E,
) -> JitTensor<R, E, D> {
-unary_op!(numeric(lhs, rhs) => |context, lhs, rhs| {
-#[cube]
-fn execute<C: Numeric>(lhs: C, rhs: C) -> C {
-lhs / rhs
-}
-execute::__expand::<C>(context, lhs, rhs)
-})
+launch_scalar_binop::<D, R, E, DivOp>(lhs, rhs)
}
pub fn remainder_scalar<R: JitRuntime, E: JitElement, const D: usize>(
lhs: JitTensor<R, E, D>,
rhs: E,
) -> JitTensor<R, E, D> {
-let shape = lhs.shape.clone();
-let device = lhs.device.clone();
-let rhs_tensor = full::<R, E, D>(shape, &device, rhs);
-binary!(
-operation: |scope: &mut Scope, elem: Elem, position: Variable| Operator::Remainder(BinaryOperator {
-lhs: scope.read_array(0, elem, position),
-rhs: scope.read_array(1, elem, position),
-out: scope.create_local(elem),
-}),
-runtime: R,
-input: lhs; rhs_tensor,
-elem: E
-)
+launch_scalar_binop::<D, R, E, RemainderOp>(lhs, rhs)
}
-pub fn pow<R: JitRuntime, E: JitElement, const D: usize>(
+pub fn pow<R: JitRuntime, E: FloatElement, const D: usize>(
lhs: JitTensor<R, E, D>,
rhs: JitTensor<R, E, D>,
) -> JitTensor<R, E, D> {
-binary!(
-operation: |scope: &mut Scope, elem: Elem, position: Variable| Operator::Powf(BinaryOperator {
-lhs: scope.read_array(0, elem, position),
-rhs: scope.read_array(1, elem, position),
-out: scope.create_local(elem),
-}),
-runtime: R,
-input: lhs; rhs,
-elem: E
-)
+launch_binop::<D, R, E, PowOp>(lhs, rhs)
}