Refactor/cube/vectorization (#1781)

This commit is contained in:
Louis Fortier-Dubois 2024-05-19 13:20:55 -04:00 committed by GitHub
parent 499ff0dd26
commit 76fe0ed881
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
49 changed files with 433 additions and 277 deletions

View File

@ -239,6 +239,15 @@ impl CodeAnalysisBuilder {
}
syn::Expr::Break(_) => {}
syn::Expr::Paren(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
syn::Expr::Array(expr) => {
for element in expr.elems.iter() {
match element {
syn::Expr::Lit(_) => {}
_ => todo!("Analysis: only array of literals is supported"),
}
}
}
syn::Expr::Reference(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
_ => todo!("Analysis: unsupported expr {expr:?}"),
}
}

View File

@ -6,7 +6,10 @@ use super::{
branch::{codegen_break, codegen_for_loop, codegen_if, codegen_loop, codegen_while_loop},
function::{codegen_call, codegen_closure, codegen_expr_method_call},
operation::codegen_binary,
variable::{codegen_assign, codegen_index, codegen_lit, codegen_local, codegen_path_rhs},
variable::{
codegen_array_lit, codegen_assign, codegen_index, codegen_lit, codegen_local,
codegen_path_rhs,
},
};
/// Codegen for a statement (generally one line)
@ -59,6 +62,15 @@ pub(crate) fn codegen_expr_block(
codegen_block(&block.block, loop_level, variable_analyses)
}
/// Codegen for a reference expression (`&expr`)
///
/// The referenced expression is expanded with [`codegen_expr`] and the result
/// is prefixed with a single `&`. Only `&` and the inner expression are
/// emitted; a `mut` qualifier on the reference is not reproduced.
pub(crate) fn codegen_ref(
    reference: &syn::ExprReference,
    loop_level: usize,
    variable_analyses: &mut CodeAnalysis,
) -> TokenStream {
    let referent = codegen_expr(&reference.expr, loop_level, variable_analyses);

    // Start from the bare `&` and append the expanded referent after it.
    let mut tokens = quote::quote! { & };
    tokens.extend(referent);
    tokens
}
/// Codegen for expressions
/// There are many variants of expression, treated differently
pub(crate) fn codegen_expr(
@ -84,6 +96,8 @@ pub(crate) fn codegen_expr(
syn::Expr::MethodCall(call) => codegen_expr_method_call(call),
syn::Expr::Index(index) => codegen_index(index, loop_level, variable_analyses),
syn::Expr::Paren(paren) => codegen_expr(&paren.expr, loop_level, variable_analyses),
syn::Expr::Array(array) => codegen_array_lit(array),
syn::Expr::Reference(reference) => codegen_ref(reference, loop_level, variable_analyses),
_ => panic!("Codegen: Unsupported {:?}", expr),
}
}

View File

@ -34,12 +34,7 @@ pub(crate) fn codegen_closure(
}
/// Codegen for a function call
/// Supports:
/// func()
/// func::<T>()
/// T::func()
///
/// Should map:
/// Maps
/// [A[::<...>]?::]^* func[::<...>] (args)
/// to
/// [A[::<...>]?::]^* func_expand[::<...>] (context, args)

View File

@ -19,6 +19,19 @@ pub(crate) fn codegen_lit(lit: &syn::ExprLit) -> TokenStream {
}
}
/// Codegen for arrays of literals
///
/// Every element must be a literal expression; each one is expanded with
/// [`codegen_lit`] and the results are emitted as `[ a, b, c, ]` (each element
/// followed by a comma). Any non-literal element hits `todo!`.
pub(crate) fn codegen_array_lit(array: &syn::ExprArray) -> TokenStream {
    // Collect eagerly so an unsupported element panics before any output is built.
    let elements: Vec<TokenStream> = array
        .elems
        .iter()
        .map(|element| match element {
            syn::Expr::Lit(lit) => codegen_lit(lit),
            _ => todo!("Codegen: Only arrays of literals are supported"),
        })
        .collect();

    quote::quote! { [ #(#elements,)* ] }
}
/// Codegen for a local declaration (let ...)
/// Supports:
/// let x = ...

View File

@ -81,12 +81,7 @@ impl core::fmt::Display for CompilationSettings {
}
match self.vectorization {
Some(vectorization) => match vectorization {
Vectorization::Vec4 => f.write_str("v4"),
Vectorization::Vec3 => f.write_str("v3"),
Vectorization::Vec2 => f.write_str("v2"),
Vectorization::Scalar => f.write_str("v1"),
}?,
Some(vectorization) => f.write_fmt(format_args!("v{}", vectorization))?,
None => f.write_str("vn")?,
};
@ -154,7 +149,7 @@ impl InputInfo {
item,
visibility: _,
} => *item,
InputInfo::Scalar { elem, size: _ } => Item::Scalar(*elem),
InputInfo::Scalar { elem, size: _ } => Item::new(*elem),
}
}
}
@ -252,7 +247,7 @@ impl Compilation {
named.push((
"info".to_string(),
Binding {
item: Item::Scalar(Elem::UInt),
item: Item::new(Elem::UInt),
visibility: Visibility::Read,
location: Location::Storage,
size: None, // We avoid putting the length here since it will force a new kernel
@ -300,7 +295,7 @@ impl Compilation {
self.named_bindings.push((
format!("scalars_{}", elem),
Binding {
item: Item::Scalar(elem),
item: Item::new(elem),
visibility: Visibility::Read,
location: Location::Storage,
size: Some(size),
@ -440,11 +435,9 @@ impl Compilation {
}
fn bool_item(ty: Item) -> Item {
match ty {
Item::Vec4(elem) => Item::Vec4(bool_elem(elem)),
Item::Vec3(elem) => Item::Vec3(bool_elem(elem)),
Item::Vec2(elem) => Item::Vec2(bool_elem(elem)),
Item::Scalar(elem) => Item::Scalar(bool_elem(elem)),
Item {
elem: bool_elem(ty.elem),
vectorization: ty.vectorization,
}
}

View File

@ -94,7 +94,7 @@ impl RangeLoop {
func: F,
) {
let mut scope = parent_scope.child();
let index_ty = Item::Scalar(Elem::UInt);
let index_ty = Item::new(Elem::UInt);
let i = scope.create_local_undeclared(index_ty);
func(i, &mut scope);

View File

@ -1,4 +1,7 @@
use crate::codegen::dialect::{macros::cpa, Item, Scope, Variable, Vectorization};
use crate::{
branch::range,
codegen::dialect::{macros::cpa, Scope, Variable, Vectorization},
};
use serde::{Deserialize, Serialize};
/// Assign value to a variable based on a given condition.
@ -19,14 +22,15 @@ impl ConditionalAssign {
let rhs = self.rhs;
let out = self.out;
let index_var = |scope: &mut Scope, var: Variable, index: usize| match var.item() {
Item::Scalar(_) => var,
_ => {
let out = scope.create_local(var.item().elem());
cpa!(scope, out = var[index]);
out
}
};
let index_var =
|scope: &mut Scope, var: Variable, index: usize| match var.item().vectorization == 1 {
true => var,
false => {
let out = scope.create_local(var.item().elem());
cpa!(scope, out = var[index]);
out
}
};
let mut assign_index = |index: usize| {
let cond = index_var(scope, cond, index);
@ -42,29 +46,20 @@ impl ConditionalAssign {
}));
};
match out.item() {
Item::Vec4(_) => {
assign_index(0);
assign_index(1);
assign_index(2);
assign_index(3);
}
Item::Vec3(_) => {
assign_index(0);
assign_index(1);
assign_index(2);
}
Item::Vec2(_) => {
assign_index(0);
assign_index(1);
}
Item::Scalar(_) => {
let vectorization = out.item().vectorization;
match vectorization == 1 {
true => {
cpa!(scope, if (cond).then(|scope| {
cpa!(scope, out = lhs);
}).else(|scope| {
cpa!(scope, out = rhs);
}));
}
false => {
for i in range(0u32, vectorization as u32, true) {
assign_index(i);
}
}
};
}

View File

@ -19,8 +19,8 @@ impl CheckedIndex {
let lhs = self.lhs;
let rhs = self.rhs;
let out = self.out;
let array_len = scope.create_local(Item::Scalar(crate::dialect::Elem::UInt));
let inside_bound = scope.create_local(Item::Scalar(crate::dialect::Elem::Bool));
let array_len = scope.create_local(Item::new(crate::dialect::Elem::UInt));
let inside_bound = scope.create_local(Item::new(crate::dialect::Elem::Bool));
cpa!(scope, array_len = len(lhs));
cpa!(scope, inside_bound = rhs < array_len);
@ -56,8 +56,8 @@ impl CheckedIndexAssign {
let lhs = self.lhs;
let rhs = self.rhs;
let out = self.out;
let array_len = scope.create_local(Item::Scalar(Elem::UInt));
let inside_bound = scope.create_local(Item::Scalar(Elem::Bool));
let array_len = scope.create_local(Item::new(Elem::UInt));
let inside_bound = scope.create_local(Item::new(Elem::Bool));
cpa!(scope, array_len = len(out));
cpa!(scope, inside_bound = lhs < array_len);

View File

@ -140,17 +140,11 @@ impl IndexOffsetGlobalWithLayout {
#[allow(missing_docs)]
pub fn expand(self, scope: &mut Scope) {
let layout = self.layout;
let index_item_ty = Item::Scalar(Elem::UInt);
let index_item_ty = Item::new(Elem::UInt);
let offset_ref = self.position;
let zero: Variable = 0u32.into();
let vectorization_factor: Variable = match self.tensors[0].item() {
Item::Vec4(_) => 4u32,
Item::Vec3(_) => 3u32,
Item::Vec2(_) => 2u32,
Item::Scalar(_) => 1u32,
}
.into();
let vectorization_factor: u8 = self.tensors[0].item().vectorization;
let vectorization_factor: Variable = (vectorization_factor as u32).into();
for index in self.indexes.iter() {
cpa!(scope, index = zero);
}

View File

@ -336,11 +336,9 @@ impl Scope {
position: Variable,
) -> Variable {
let item_global = match item.elem() {
Elem::Bool => match item {
Item::Vec4(_) => Item::Vec4(Elem::UInt),
Item::Vec3(_) => Item::Vec3(Elem::UInt),
Item::Vec2(_) => Item::Vec2(Elem::UInt),
Item::Scalar(_) => Item::Scalar(Elem::UInt),
Elem::Bool => Item {
elem: Elem::UInt,
vectorization: item.vectorization,
},
_ => item,
};

View File

@ -1,4 +1,4 @@
use super::Scope;
use super::{Scope, Vectorization};
use crate::WORKGROUP_DEFAULT;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
@ -44,7 +44,7 @@ pub enum Elem {
impl From<Elem> for Item {
fn from(val: Elem) -> Self {
Item::Scalar(val)
Item::new(val)
}
}
@ -81,22 +81,30 @@ impl Display for Elem {
}
#[derive(Debug, Clone, PartialEq, Eq, Copy, Serialize, Deserialize, Hash)]
#[allow(missing_docs)]
pub enum Item {
Vec4(Elem),
Vec3(Elem),
Vec2(Elem),
Scalar(Elem),
pub struct Item {
pub elem: Elem,
pub vectorization: Vectorization,
}
impl Item {
/// Fetch the elem of the item.
pub fn elem(&self) -> Elem {
match self {
Self::Vec4(elem) => *elem,
Self::Vec3(elem) => *elem,
Self::Vec2(elem) => *elem,
Self::Scalar(elem) => *elem,
self.elem
}
/// Create a new item without vectorization
pub fn new(elem: Elem) -> Self {
Self {
elem,
vectorization: 1,
}
}
/// Create a new item with vectorization
pub fn vectorized(elem: Elem, vectorization: Vectorization) -> Self {
Self {
elem,
vectorization,
}
}
}

View File

@ -69,30 +69,30 @@ impl Variable {
match self {
Variable::GlobalInputArray(_, item) => *item,
Variable::GlobalOutputArray(_, item) => *item,
Variable::GlobalScalar(_, elem) => Item::Scalar(*elem),
Variable::GlobalScalar(_, elem) => Item::new(*elem),
Variable::Local(_, item, _) => *item,
Variable::LocalScalar(_, elem, _) => Item::Scalar(*elem),
Variable::ConstantScalar(_, elem) => Item::Scalar(*elem),
Variable::LocalScalar(_, elem, _) => Item::new(*elem),
Variable::ConstantScalar(_, elem) => Item::new(*elem),
Variable::SharedMemory(_, item, _) => *item,
Variable::LocalArray(_, item, _, _) => *item,
Variable::Id => Item::Scalar(Elem::UInt),
Variable::Rank => Item::Scalar(Elem::UInt),
Variable::LocalInvocationIndex => Item::Scalar(Elem::UInt),
Variable::LocalInvocationIdX => Item::Scalar(Elem::UInt),
Variable::LocalInvocationIdY => Item::Scalar(Elem::UInt),
Variable::LocalInvocationIdZ => Item::Scalar(Elem::UInt),
Variable::WorkgroupIdX => Item::Scalar(Elem::UInt),
Variable::WorkgroupIdY => Item::Scalar(Elem::UInt),
Variable::WorkgroupIdZ => Item::Scalar(Elem::UInt),
Variable::GlobalInvocationIdX => Item::Scalar(Elem::UInt),
Variable::GlobalInvocationIdY => Item::Scalar(Elem::UInt),
Variable::GlobalInvocationIdZ => Item::Scalar(Elem::UInt),
Variable::WorkgroupSizeX => Item::Scalar(Elem::UInt),
Variable::WorkgroupSizeY => Item::Scalar(Elem::UInt),
Variable::WorkgroupSizeZ => Item::Scalar(Elem::UInt),
Variable::NumWorkgroupsX => Item::Scalar(Elem::UInt),
Variable::NumWorkgroupsY => Item::Scalar(Elem::UInt),
Variable::NumWorkgroupsZ => Item::Scalar(Elem::UInt),
Variable::Id => Item::new(Elem::UInt),
Variable::Rank => Item::new(Elem::UInt),
Variable::LocalInvocationIndex => Item::new(Elem::UInt),
Variable::LocalInvocationIdX => Item::new(Elem::UInt),
Variable::LocalInvocationIdY => Item::new(Elem::UInt),
Variable::LocalInvocationIdZ => Item::new(Elem::UInt),
Variable::WorkgroupIdX => Item::new(Elem::UInt),
Variable::WorkgroupIdY => Item::new(Elem::UInt),
Variable::WorkgroupIdZ => Item::new(Elem::UInt),
Variable::GlobalInvocationIdX => Item::new(Elem::UInt),
Variable::GlobalInvocationIdY => Item::new(Elem::UInt),
Variable::GlobalInvocationIdZ => Item::new(Elem::UInt),
Variable::WorkgroupSizeX => Item::new(Elem::UInt),
Variable::WorkgroupSizeY => Item::new(Elem::UInt),
Variable::WorkgroupSizeZ => Item::new(Elem::UInt),
Variable::NumWorkgroupsX => Item::new(Elem::UInt),
Variable::NumWorkgroupsY => Item::new(Elem::UInt),
Variable::NumWorkgroupsZ => Item::new(Elem::UInt),
}
}
}

View File

@ -1,19 +1,6 @@
use super::{BinaryOperator, ClampOperator, Item, Operation, Operator, UnaryOperator, Variable};
/// Define a vectorization scheme.
#[allow(dead_code)]
#[derive(Copy, Clone, Debug, Default, Hash)]
pub enum Vectorization {
/// Use vec4 for vectorization.
Vec4,
/// Use vec3 for vectorization.
Vec3,
/// Use vec2 for vectorization.
Vec2,
/// Don't vectorize.
#[default]
Scalar,
}
pub type Vectorization = u8;
impl Operation {
pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self {
@ -169,21 +156,14 @@ impl Variable {
}
impl Item {
pub(crate) fn vectorize(&self, vectorize: Vectorization) -> Item {
match vectorize {
Vectorization::Vec4 => Item::Vec4(self.elem()),
Vectorization::Vec3 => Item::Vec3(self.elem()),
Vectorization::Vec2 => Item::Vec2(self.elem()),
Vectorization::Scalar => Item::Scalar(self.elem()),
pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Item {
Item {
elem: self.elem,
vectorization,
}
}
pub(crate) fn vectorized_size(&self, vectorize: Vectorization, size: u32) -> u32 {
match vectorize {
Vectorization::Vec4 => size / 4,
Vectorization::Vec3 => size / 3,
Vectorization::Vec2 => size / 2,
Vectorization::Scalar => size,
}
size / (vectorize as u32)
}
}

View File

@ -41,7 +41,7 @@ pub fn range_expand<F>(
}
} else {
let mut child = context.child();
let index_ty = Item::Scalar(Elem::UInt);
let index_ty = Item::new(Elem::UInt);
let i = child.scope.borrow_mut().create_local_undeclared(index_ty);
let i = ExpandElement::new(Rc::new(i));

View File

@ -1,4 +1,4 @@
use crate::dialect::Elem;
use crate::dialect::{Elem, Vectorization};
use crate::language::{CubeContext, CubeType, ExpandElement, PrimitiveVariable};
@ -34,10 +34,14 @@ impl Bool {
impl PrimitiveVariable for Bool {
type Primitive = bool;
fn into_elem() -> Elem {
fn as_elem() -> Elem {
Elem::Bool
}
/// Return the vectorization factor stored on this value.
fn vectorization(&self) -> Vectorization {
self.vectorization
}
fn to_f64(&self) -> f64 {
match self.val {
true => 1.,
@ -52,4 +56,13 @@ impl PrimitiveVariable for Bool {
fn from_i64(val: i64) -> Self {
Self::from_f64(val as f64)
}
/// Build a CPU-side `Bool` from a slice of integers.
///
/// Only the first element is kept as the value (`true` when it is strictly
/// positive); the slice length is recorded as the vectorization factor.
///
/// # Panics
/// Panics if `vec` is empty or holds more than `u8::MAX` elements.
fn from_i64_vec(vec: &[i64]) -> Self {
    // `as u8` wraps silently above 255, which would record a wrong factor:
    // reject such inputs explicitly instead.
    assert!(
        vec.len() <= u8::MAX as usize,
        "Vectorization factor {} does not fit in a u8",
        vec.len()
    );
    Self {
        // We take only one value, because type implements copy and we can't copy an unknown sized vec
        // For debugging prefer unvectorized types
        val: *vec.first().expect("Should be at least one value") > 0,
        vectorization: vec.len() as u8,
    }
}
}

View File

@ -8,7 +8,7 @@ pub trait Cast: PrimitiveVariable {
context: &mut CubeContext,
value: <Self as CubeType>::ExpandType,
) -> <Self as CubeType>::ExpandType {
let new_var = context.create_local(Item::Scalar(<Self as PrimitiveVariable>::into_elem()));
let new_var = context.create_local(Item::new(<Self as PrimitiveVariable>::as_elem()));
assign::expand(context, value, new_var.clone());
new_var
}

View File

@ -1,4 +1,4 @@
use crate::dialect::{Elem, FloatKind, Variable};
use crate::dialect::{Elem, FloatKind, Variable, Vectorization};
use crate::language::{CubeContext, CubeType, ExpandElement, Numeric, PrimitiveVariable};
use std::rc::Rc;
@ -13,7 +13,7 @@ macro_rules! impl_float {
#[derive(Clone, Copy)]
pub struct $type {
pub val: <Self as PrimitiveVariable>::Primitive,
pub vectorization: usize,
pub vectorization: u8,
}
impl CubeType for $type {
@ -24,10 +24,14 @@ macro_rules! impl_float {
type Primitive = f64;
/// Return the element type to use on GPU
fn into_elem() -> Elem {
fn as_elem() -> Elem {
Elem::Float(FloatKind::$type)
}
/// Return the vectorization factor stored on this value.
fn vectorization(&self) -> Vectorization {
    // The field is a `u8`, the same type the return alias resolves to in
    // this change, so no conversion is needed.
    self.vectorization
}
fn to_f64(&self) -> f64 {
self.val
}
@ -39,6 +43,16 @@ macro_rules! impl_float {
fn from_i64(val: i64) -> Self {
Self::new(val as f64)
}
/// Build a CPU-side value from a slice of integers.
///
/// Only the first element is kept (cast to the primitive type); the slice
/// length is recorded as the vectorization factor.
///
/// # Panics
/// Panics if `vec` is empty or holds more than `u8::MAX` elements.
fn from_i64_vec(vec: &[i64]) -> Self {
    // `as u8` wraps silently above 255, which would record a wrong factor:
    // reject such inputs explicitly instead.
    assert!(
        vec.len() <= u8::MAX as usize,
        "Vectorization factor {} does not fit in a u8",
        vec.len()
    );
    Self {
        // We take only one value, because type implements copy and we can't copy an unknown sized vec
        // When using CPU-side values for debugging kernels, prefer using unvectorized types
        val: *vec.first().expect("Should be at least one value")
            as <Self as PrimitiveVariable>::Primitive,
        vectorization: vec.len() as u8,
    }
}
}
impl Numeric for $type {}
@ -55,7 +69,7 @@ macro_rules! impl_float {
_context: &mut CubeContext,
val: <Self as PrimitiveVariable>::Primitive,
) -> <Self as CubeType>::ExpandType {
let new_var = Variable::ConstantScalar(val as f64, Self::into_elem());
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
ExpandElement::new(Rc::new(new_var))
}
}

View File

@ -1,4 +1,4 @@
use crate::dialect::{Elem, IntKind, Variable};
use crate::dialect::{Elem, IntKind, Variable, Vectorization};
use crate::language::{CubeContext, CubeType, ExpandElement, Numeric, PrimitiveVariable};
use std::rc::Rc;
@ -13,7 +13,7 @@ macro_rules! impl_int {
#[derive(Clone, Copy)]
pub struct $type {
pub val: <Self as PrimitiveVariable>::Primitive,
pub vectorization: usize,
pub vectorization: u8,
}
impl CubeType for $type {
@ -23,10 +23,14 @@ macro_rules! impl_int {
impl PrimitiveVariable for $type {
type Primitive = i64;
fn into_elem() -> Elem {
fn as_elem() -> Elem {
Elem::Int(IntKind::$type)
}
/// Return the vectorization factor stored on this value.
fn vectorization(&self) -> Vectorization {
    // The field is a `u8`, the same type the return alias resolves to in
    // this change, so no conversion is needed.
    self.vectorization
}
fn to_f64(&self) -> f64 {
self.val as f64
}
@ -38,6 +42,16 @@ macro_rules! impl_int {
fn from_i64(val: i64) -> Self {
Self::new(val)
}
/// Build a CPU-side value from a slice of integers.
///
/// Only the first element is kept (cast to the primitive type); the slice
/// length is recorded as the vectorization factor.
///
/// # Panics
/// Panics if `vec` is empty or holds more than `u8::MAX` elements.
fn from_i64_vec(vec: &[i64]) -> Self {
    // `as u8` wraps silently above 255, which would record a wrong factor:
    // reject such inputs explicitly instead.
    assert!(
        vec.len() <= u8::MAX as usize,
        "Vectorization factor {} does not fit in a u8",
        vec.len()
    );
    Self {
        // We take only one value, because type implements copy and we can't copy an unknown sized vec
        // For debugging prefer unvectorized types
        val: *vec.first().expect("Should be at least one value")
            as <Self as PrimitiveVariable>::Primitive,
        vectorization: vec.len() as u8,
    }
}
}
impl Numeric for $type {}
@ -53,7 +67,7 @@ macro_rules! impl_int {
_context: &mut CubeContext,
val: <Self as PrimitiveVariable>::Primitive,
) -> <Self as CubeType>::ExpandType {
let new_var = Variable::ConstantScalar(val as f64, Self::into_elem());
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
ExpandElement::new(Rc::new(new_var))
}
}

View File

@ -1,4 +1,5 @@
use crate::dialect::Variable;
use crate::dialect::{Item, Variable};
use crate::index_assign;
use crate::language::{CubeContext, CubeType, ExpandElement, PrimitiveVariable};
use std::rc::Rc;
@ -24,9 +25,25 @@ pub trait Numeric:
<Self as PrimitiveVariable>::from_i64(val)
}
/// Expand version of lit
/// Expand version of from_int
fn from_int_expand(_context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType {
let new_var = Variable::ConstantScalar(val as f64, Self::into_elem());
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
ExpandElement::new(Rc::new(new_var))
}
/// Create a CPU-side value from a slice of integers by delegating to
/// `PrimitiveVariable::from_i64_vec`.
fn from_vec(vec: &[i64]) -> Self {
<Self as PrimitiveVariable>::from_i64_vec(vec)
}
/// Expand version of from_vec
///
/// Creates a local variable whose item has this type's element and a
/// vectorization factor equal to `vec.len()`, then registers one index
/// assignment per element of `vec`, in order.
///
/// # Panics
/// Panics if `vec` holds more than `u8::MAX` elements.
fn from_vec_expand(context: &mut CubeContext, vec: &[i64]) -> <Self as CubeType>::ExpandType {
    // `as u8` wraps silently above 255, which would declare a local with the
    // wrong vectorization factor: reject such inputs explicitly instead.
    assert!(
        vec.len() <= u8::MAX as usize,
        "Vectorization factor {} does not fit in a u8",
        vec.len()
    );

    let mut new_var = context.create_local(Item {
        elem: Self::as_elem(),
        vectorization: vec.len() as u8,
    });
    for (i, element) in vec.iter().enumerate() {
        // `index_assign::expand` hands the array back, so it is threaded
        // through the loop.
        new_var = index_assign::expand(context, new_var, i.into(), (*element).into());
    }
    new_var
}
}

View File

@ -1,6 +1,6 @@
use std::rc::Rc;
use crate::dialect::{Elem, Variable};
use crate::dialect::{Elem, Variable, Vectorization};
use crate::language::{CubeType, ExpandElement};
/// Form of CubeType that encapsulates all primitive types:
@ -9,12 +9,15 @@ pub trait PrimitiveVariable: CubeType<ExpandType = ExpandElement> {
type Primitive;
/// Return the element type to use on GPU
fn into_elem() -> Elem;
fn as_elem() -> Elem;
fn vectorization(&self) -> Vectorization;
// For easy CPU-side casting
fn to_f64(&self) -> f64;
fn from_f64(val: f64) -> Self;
fn from_i64(val: i64) -> Self;
fn from_i64_vec(vec: &[i64]) -> Self;
}
macro_rules! impl_into_expand_element {

View File

@ -1,6 +1,6 @@
use std::rc::Rc;
use crate::dialect::{Elem, Variable};
use crate::dialect::{Elem, Variable, Vectorization};
use crate::language::{CubeContext, CubeType, ExpandElement, Numeric, PrimitiveVariable};
#[derive(Clone, Copy)]
@ -18,10 +18,14 @@ impl CubeType for UInt {
impl PrimitiveVariable for UInt {
type Primitive = u32;
fn into_elem() -> Elem {
fn as_elem() -> Elem {
Elem::UInt
}
/// Return the vectorization factor stored on this value.
fn vectorization(&self) -> Vectorization {
self.vectorization
}
fn to_f64(&self) -> f64 {
self.val as f64
}
@ -33,6 +37,16 @@ impl PrimitiveVariable for UInt {
fn from_i64(val: i64) -> Self {
Self::new(val as <Self as PrimitiveVariable>::Primitive)
}
/// Build a CPU-side `UInt` from a slice of integers.
///
/// Only the first element is kept (cast to the primitive type); the slice
/// length is recorded as the vectorization factor.
///
/// # Panics
/// Panics if `vec` is empty or holds more than `u8::MAX` elements.
fn from_i64_vec(vec: &[i64]) -> Self {
    // `as u8` wraps silently above 255, which would record a wrong factor:
    // reject such inputs explicitly instead.
    assert!(
        vec.len() <= u8::MAX as usize,
        "Vectorization factor {} does not fit in a u8",
        vec.len()
    );
    Self {
        // We take only one value, because type implements copy and we can't copy an unknown sized vec
        // For debugging prefer unvectorized types
        val: *vec.first().expect("Should be at least one value")
            as <Self as PrimitiveVariable>::Primitive,
        vectorization: vec.len() as u8,
    }
}
}
impl Numeric for UInt {}
@ -48,7 +62,7 @@ impl UInt {
_context: &mut CubeContext,
val: <Self as PrimitiveVariable>::Primitive,
) -> <Self as CubeType>::ExpandType {
let new_var = Variable::ConstantScalar(val as f64, Self::into_elem());
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
ExpandElement::new(Rc::new(new_var))
}
}

View File

@ -26,12 +26,13 @@ pub mod index_assign {
array: ExpandElement,
index: ExpandElement,
value: ExpandElement,
) {
) -> ExpandElement {
context.register(Operator::IndexAssign(BinaryOperator {
lhs: *index,
rhs: *value,
out: *array,
}))
}));
array
}
impl<E: CubeType, I: Into<UInt>> core::ops::IndexMut<I> for Array<E> {

View File

@ -1,4 +1,4 @@
use crate::dialect::{BinaryOperator, Elem, Item, Operator, Variable};
use crate::dialect::{BinaryOperator, Elem, Item, Operator, Variable, Vectorization};
use crate::language::{CubeContext, ExpandElement};
pub(crate) fn binary_expand<F>(
@ -14,6 +14,8 @@ where
let rhs: Variable = *rhs;
let item = lhs.item();
check_vectorization(item.vectorization, rhs.item().vectorization);
let out = context.create_local(item);
let out_var = *out;
@ -39,13 +41,15 @@ where
{
let lhs: Variable = *lhs;
let rhs: Variable = *rhs;
let item = lhs.item();
let out_item = match lhs.item() {
Item::Vec4(_) => Item::Vec4(Elem::Bool),
Item::Vec3(_) => Item::Vec3(Elem::Bool),
Item::Vec2(_) => Item::Vec2(Elem::Bool),
Item::Scalar(_) => Item::Scalar(Elem::Bool),
check_vectorization(item.vectorization, rhs.item().vectorization);
let out_item = Item {
elem: Elem::Bool,
vectorization: item.vectorization,
};
let out = context.create_local(out_item);
let out_var = *out;
@ -82,3 +86,13 @@ where
lhs
}
/// Validate that two operands of a binary operation can be combined.
///
/// The pair is accepted when either side has a vectorization factor of 1 or
/// when both factors are equal.
///
/// # Panics
/// Panics when both factors differ from 1 and from each other.
fn check_vectorization(lhs: Vectorization, rhs: Vectorization) {
    let has_scalar_operand = lhs == 1 || rhs == 1;
    assert!(
        has_scalar_operand || lhs == rhs,
        "Tried to perform binary operation on different vectorization schemes."
    );
}

View File

@ -151,109 +151,109 @@ mod tests {
cast_test!(
cube_float_to_float_test,
float_to_float_expand,
Item::Scalar(F32::into_elem())
Item::new(F32::as_elem())
);
cast_test!(
cube_float_to_int_test,
float_to_int_expand,
Item::Scalar(F32::into_elem()),
Item::Scalar(I32::into_elem())
Item::new(F32::as_elem()),
Item::new(I32::as_elem())
);
cast_test!(
cube_float_to_uint_test,
float_to_uint_expand,
Item::Scalar(F32::into_elem()),
Item::Scalar(Elem::UInt)
Item::new(F32::as_elem()),
Item::new(Elem::UInt)
);
cast_test!(
cube_float_to_bool_test,
float_to_bool_expand,
Item::Scalar(F32::into_elem()),
Item::Scalar(Elem::Bool)
Item::new(F32::as_elem()),
Item::new(Elem::Bool)
);
cast_test!(
cube_int_to_float_test,
int_to_float_expand,
Item::Scalar(I32::into_elem()),
Item::Scalar(F32::into_elem())
Item::new(I32::as_elem()),
Item::new(F32::as_elem())
);
cast_test!(
cube_int_to_int_test,
int_to_int_expand,
Item::Scalar(I32::into_elem())
Item::new(I32::as_elem())
);
cast_test!(
cube_int_to_uint_test,
int_to_uint_expand,
Item::Scalar(I32::into_elem()),
Item::Scalar(Elem::UInt)
Item::new(I32::as_elem()),
Item::new(Elem::UInt)
);
cast_test!(
cube_int_to_bool_test,
int_to_bool_expand,
Item::Scalar(I32::into_elem()),
Item::Scalar(Elem::Bool)
Item::new(I32::as_elem()),
Item::new(Elem::Bool)
);
cast_test!(
cube_uint_to_float_test,
uint_to_float_expand,
Item::Scalar(Elem::UInt),
Item::Scalar(F32::into_elem())
Item::new(Elem::UInt),
Item::new(F32::as_elem())
);
cast_test!(
cube_uint_to_int_test,
uint_to_int_expand,
Item::Scalar(Elem::UInt),
Item::Scalar(I32::into_elem())
Item::new(Elem::UInt),
Item::new(I32::as_elem())
);
cast_test!(
cube_uint_to_uint_test,
uint_to_uint_expand,
Item::Scalar(Elem::UInt)
Item::new(Elem::UInt)
);
cast_test!(
cube_uint_to_bool_test,
uint_to_bool_expand,
Item::Scalar(Elem::UInt),
Item::Scalar(Elem::Bool)
Item::new(Elem::UInt),
Item::new(Elem::Bool)
);
cast_test!(
cube_bool_to_float_test,
bool_to_float_expand,
Item::Scalar(Elem::Bool),
Item::Scalar(F32::into_elem())
Item::new(Elem::Bool),
Item::new(F32::as_elem())
);
cast_test!(
cube_bool_to_int_test,
bool_to_int_expand,
Item::Scalar(Elem::Bool),
Item::Scalar(I32::into_elem())
Item::new(Elem::Bool),
Item::new(I32::as_elem())
);
cast_test!(
cube_bool_to_uint_test,
bool_to_uint_expand,
Item::Scalar(Elem::Bool),
Item::Scalar(Elem::UInt)
Item::new(Elem::Bool),
Item::new(Elem::UInt)
);
cast_test!(
cube_bool_to_bool_test,
bool_to_bool_expand,
Item::Scalar(Elem::Bool)
Item::new(Elem::Bool)
);
fn inline_macro_ref_cast(from_item: Item, to_item: Item) -> String {

View File

@ -35,7 +35,7 @@ mod tests {
#[test]
fn cube_cast_float_kind_test() {
let mut context = CubeContext::root();
let item = Item::Scalar(F64::into_elem());
let item = Item::new(F64::as_elem());
let input = context.create_local(item);
@ -48,7 +48,7 @@ mod tests {
#[test]
fn cube_cast_int_kind_test() {
let mut context = CubeContext::root();
let item = Item::Scalar(I32::into_elem());
let item = Item::new(I32::as_elem());
let input = context.create_local(item);
@ -61,7 +61,7 @@ mod tests {
#[test]
fn cube_cast_numeric_kind_test() {
let mut context = CubeContext::root();
let item = Item::Scalar(I32::into_elem());
let item = Item::new(I32::as_elem());
let input = context.create_local(item);
@ -74,7 +74,7 @@ mod tests {
#[test]
fn cube_cast_kind_numeric_test() {
let mut context = CubeContext::root();
let item = Item::Scalar(I32::into_elem());
let item = Item::new(I32::as_elem());
let input = context.create_local(item);
@ -86,8 +86,8 @@ mod tests {
fn inline_macro_ref_float() -> String {
let mut context = CubeContext::root();
let float_64 = Item::Scalar(F64::into_elem());
let float_32 = Item::Scalar(F32::into_elem());
let float_64 = Item::new(F64::as_elem());
let float_32 = Item::new(F32::as_elem());
let input = context.create_local(float_64);
let mut scope = context.into_scope();
@ -104,8 +104,8 @@ mod tests {
fn inline_macro_ref_int() -> String {
let mut context = CubeContext::root();
let int_32 = Item::Scalar(I32::into_elem());
let int_64 = Item::Scalar(I64::into_elem());
let int_32 = Item::new(I32::as_elem());
let int_64 = Item::new(I64::as_elem());
let input = context.create_local(int_32);
let mut scope = context.into_scope();

View File

@ -25,8 +25,8 @@ mod tests {
let mut context = CubeContext::root();
let unroll = true;
let lhs = context.create_local(Item::Scalar(ElemType::into_elem()));
let rhs = context.create_local(Item::Scalar(ElemType::into_elem()));
let lhs = context.create_local(Item::new(ElemType::as_elem()));
let rhs = context.create_local(Item::new(ElemType::as_elem()));
let end = 4u32.into();
for_loop_expand::<ElemType>(&mut context, lhs, rhs, end, unroll);
@ -40,8 +40,8 @@ mod tests {
let mut context = CubeContext::root();
let unroll = false;
let lhs = context.create_local(Item::Scalar(ElemType::into_elem()));
let rhs = context.create_local(Item::Scalar(ElemType::into_elem()));
let lhs = context.create_local(Item::new(ElemType::as_elem()));
let rhs = context.create_local(Item::new(ElemType::as_elem()));
let end = 4u32.into();
for_loop_expand::<ElemType>(&mut context, lhs, rhs, end, unroll);
@ -52,7 +52,7 @@ mod tests {
fn inline_macro_ref(unroll: bool) -> String {
let mut context = CubeContext::root();
let item = Item::Scalar(ElemType::into_elem());
let item = Item::new(ElemType::as_elem());
let lhs = context.create_local(item);
let rhs = context.create_local(item);

View File

@ -55,12 +55,12 @@ mod tests {
#[test]
fn cube_call_equivalent_to_no_call_no_arg_test() {
let mut caller_context = CubeContext::root();
let x = caller_context.create_local(Item::Scalar(Elem::UInt));
let x = caller_context.create_local(Item::new(Elem::UInt));
caller_no_arg_expand(&mut caller_context, x);
let caller_scope = caller_context.into_scope();
let mut no_call_context = CubeContext::root();
let x = no_call_context.create_local(Item::Scalar(Elem::UInt));
let x = no_call_context.create_local(Item::new(Elem::UInt));
no_call_no_arg_expand(&mut no_call_context, x);
let no_call_scope = no_call_context.into_scope();
@ -74,12 +74,12 @@ mod tests {
fn cube_call_equivalent_to_no_call_with_arg_test() {
let mut caller_context = CubeContext::root();
let x = caller_context.create_local(Item::Scalar(Elem::UInt));
let x = caller_context.create_local(Item::new(Elem::UInt));
caller_with_arg_expand(&mut caller_context, x);
let caller_scope = caller_context.into_scope();
let mut no_call_context = CubeContext::root();
let x = no_call_context.create_local(Item::Scalar(Elem::UInt));
let x = no_call_context.create_local(Item::new(Elem::UInt));
no_call_with_arg_expand(&mut no_call_context, x);
let no_call_scope = no_call_context.into_scope();
@ -93,12 +93,12 @@ mod tests {
fn cube_call_equivalent_to_no_call_with_generics_test() {
let mut caller_context = CubeContext::root();
type ElemType = I64;
let x = caller_context.create_local(Item::Scalar(ElemType::into_elem()));
let x = caller_context.create_local(Item::new(ElemType::as_elem()));
caller_with_generics_expand::<ElemType>(&mut caller_context, x);
let caller_scope = caller_context.into_scope();
let mut no_call_context = CubeContext::root();
let x = no_call_context.create_local(Item::Scalar(ElemType::into_elem()));
let x = no_call_context.create_local(Item::new(ElemType::as_elem()));
no_call_with_generics_expand::<ElemType>(&mut no_call_context, x);
let no_call_scope = no_call_context.into_scope();

View File

@ -14,7 +14,7 @@ mod tests {
fn cube_generic_float_test() {
let mut context = CubeContext::root();
let lhs = context.create_local(Item::Scalar(F32::into_elem()));
let lhs = context.create_local(Item::new(F32::as_elem()));
generic_kernel_expand::<F32>(&mut context, lhs);
let scope = context.into_scope();
@ -26,7 +26,7 @@ mod tests {
fn cube_generic_int_test() {
let mut context = CubeContext::root();
let lhs = context.create_local(Item::Scalar(I32::into_elem()));
let lhs = context.create_local(Item::new(I32::as_elem()));
generic_kernel_expand::<I32>(&mut context, lhs);
let scope = context.into_scope();
@ -36,7 +36,7 @@ mod tests {
fn inline_macro_ref_float() -> String {
let mut context = CubeContext::root();
let item = Item::Scalar(F32::into_elem());
let item = Item::new(F32::as_elem());
let lhs = context.create_local(item);
let mut scope = context.into_scope();
@ -48,7 +48,7 @@ mod tests {
fn inline_macro_ref_int() -> String {
let mut context = CubeContext::root();
let item = Item::Scalar(I32::into_elem());
let item = Item::new(I32::as_elem());
let lhs = context.create_local(item);
let mut scope = context.into_scope();

View File

@ -22,7 +22,7 @@ mod tests {
fn cube_if_test() {
let mut context = CubeContext::root();
let lhs = context.create_local(Item::Scalar(ElemType::into_elem()));
let lhs = context.create_local(Item::new(ElemType::as_elem()));
if_greater_expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();
@ -32,11 +32,11 @@ mod tests {
fn inline_macro_ref() -> String {
let mut context = CubeContext::root();
let item = Item::Scalar(ElemType::into_elem());
let item = Item::new(ElemType::as_elem());
let lhs = context.create_local(item);
let mut scope = context.into_scope();
let cond = scope.create_local(Item::Scalar(Elem::Bool));
let cond = scope.create_local(Item::new(Elem::Bool));
let lhs: Variable = lhs.into();
let y = scope.create_local(item);

View File

@ -24,7 +24,7 @@ mod tests {
fn cube_if_else_test() {
let mut context = CubeContext::root();
let lhs = context.create_local(Item::Scalar(ElemType::into_elem()));
let lhs = context.create_local(Item::new(ElemType::as_elem()));
if_then_else_expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();
@ -34,11 +34,11 @@ mod tests {
fn inline_macro_ref() -> String {
let mut context = CubeContext::root();
let item = Item::Scalar(ElemType::into_elem());
let item = Item::new(ElemType::as_elem());
let lhs = context.create_local(item);
let mut scope = context.into_scope();
let cond = scope.create_local(Item::Scalar(Elem::Bool));
let cond = scope.create_local(Item::new(Elem::Bool));
let lhs: Variable = lhs.into();
let y = scope.create_local(item);

View File

@ -20,7 +20,7 @@ mod tests {
fn cube_literal_test() {
let mut context = CubeContext::root();
let lhs = context.create_local(Item::Scalar(ElemType::into_elem()));
let lhs = context.create_local(Item::new(ElemType::as_elem()));
literal_expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();
@ -32,7 +32,7 @@ mod tests {
fn cube_literal_float_no_decimal_test() {
let mut context = CubeContext::root();
let lhs = context.create_local(Item::Scalar(ElemType::into_elem()));
let lhs = context.create_local(Item::new(ElemType::as_elem()));
literal_float_no_decimals_expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();
@ -42,7 +42,7 @@ mod tests {
fn inline_macro_ref() -> String {
let mut context = CubeContext::root();
let item = Item::Scalar(ElemType::into_elem());
let item = Item::new(ElemType::as_elem());
let lhs = context.create_local(item);
let mut scope = context.into_scope();

View File

@ -32,7 +32,7 @@ mod tests {
fn cube_while_test() {
let mut context = CubeContext::root();
let lhs = context.create_local(Item::Scalar(ElemType::into_elem()));
let lhs = context.create_local(Item::new(ElemType::as_elem()));
while_not_expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();
@ -44,7 +44,7 @@ mod tests {
fn cube_loop_break_test() {
let mut context = CubeContext::root();
let lhs = context.create_local(Item::Scalar(ElemType::into_elem()));
let lhs = context.create_local(Item::new(ElemType::as_elem()));
manual_loop_break_expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();
@ -54,11 +54,11 @@ mod tests {
fn inline_macro_ref() -> String {
let mut context = CubeContext::root();
let item = Item::Scalar(ElemType::into_elem());
let item = Item::new(ElemType::as_elem());
let lhs = context.create_local(item);
let mut scope = context.into_scope();
let cond = scope.create_local(Item::Scalar(Elem::Bool));
let cond = scope.create_local(Item::new(Elem::Bool));
let lhs: Variable = lhs.into();
let rhs = scope.create_local(item);

View File

@ -11,3 +11,4 @@ mod module_import;
mod parenthesis;
mod reuse;
mod r#trait;
mod vectorization;

View File

@ -33,12 +33,12 @@ mod tests {
#[test]
fn cube_call_equivalent_to_no_call_no_arg_test() {
let mut caller_context = CubeContext::root();
let x = caller_context.create_local(Item::Scalar(ElemType::into_elem()));
let x = caller_context.create_local(Item::new(ElemType::as_elem()));
here::caller_expand::<ElemType>(&mut caller_context, x);
let caller_scope = caller_context.into_scope();
let mut no_call_context = CubeContext::root();
let x = no_call_context.create_local(Item::Scalar(ElemType::into_elem()));
let x = no_call_context.create_local(Item::new(ElemType::as_elem()));
here::no_call_ref_expand::<ElemType>(&mut no_call_context, x);
let no_call_scope = no_call_context.into_scope();

View File

@ -20,9 +20,9 @@ mod tests {
fn cube_parenthesis_priority_test() {
let mut context = CubeContext::root();
let x = context.create_local(Item::Scalar(ElemType::into_elem()));
let y = context.create_local(Item::Scalar(ElemType::into_elem()));
let z = context.create_local(Item::Scalar(ElemType::into_elem()));
let x = context.create_local(Item::new(ElemType::as_elem()));
let y = context.create_local(Item::new(ElemType::as_elem()));
let z = context.create_local(Item::new(ElemType::as_elem()));
parenthesis_expand::<ElemType>(&mut context, x, y, z);
let scope = context.into_scope();
@ -32,7 +32,7 @@ mod tests {
fn inline_macro_ref() -> String {
let mut context = CubeContext::root();
let item = Item::Scalar(ElemType::into_elem());
let item = Item::new(ElemType::as_elem());
let x = context.create_local(item);
let y = context.create_local(item);
let z = context.create_local(item);

View File

@ -32,7 +32,7 @@ mod tests {
fn cube_reuse_assign_test() {
let mut context = CubeContext::root();
let x = context.create_local(Item::Scalar(ElemType::into_elem()));
let x = context.create_local(Item::new(ElemType::as_elem()));
reuse_expand::<ElemType>(&mut context, x);
let scope = context.into_scope();
@ -44,7 +44,7 @@ mod tests {
fn cube_reuse_incr_test() {
let mut context = CubeContext::root();
let x = context.create_local(Item::Scalar(ElemType::into_elem()));
let x = context.create_local(Item::new(ElemType::as_elem()));
reuse_incr_expand::<ElemType>(&mut context, x);
let scope = context.into_scope();
@ -54,11 +54,11 @@ mod tests {
fn inline_macro_ref_assign() -> String {
let mut context = CubeContext::root();
let item = Item::Scalar(ElemType::into_elem());
let item = Item::new(ElemType::as_elem());
let x = context.create_local(item);
let mut scope = context.into_scope();
let cond = scope.create_local(Item::Scalar(Elem::Bool));
let cond = scope.create_local(Item::new(Elem::Bool));
let x: Variable = x.into();
let tmp = scope.create_local(item);
@ -80,11 +80,11 @@ mod tests {
fn inline_macro_ref_incr() -> String {
let mut context = CubeContext::root();
let item = Item::Scalar(ElemType::into_elem());
let item = Item::new(ElemType::as_elem());
let x = context.create_local(item);
let mut scope = context.into_scope();
let cond = scope.create_local(Item::Scalar(Elem::Bool));
let cond = scope.create_local(Item::new(Elem::Bool));
let x: Variable = x.into();
cpa!(

View File

@ -114,8 +114,8 @@ mod tests {
fn cube_strategy_trait_add_test() {
let mut context = CubeContext::root();
let x = context.create_local(Item::Scalar(ElemType::into_elem()));
let y = context.create_local(Item::Scalar(ElemType::into_elem()));
let x = context.create_local(Item::new(ElemType::as_elem()));
let y = context.create_local(Item::new(ElemType::as_elem()));
with_strategy_trait_expand::<AddStrategy, ElemType>(&mut context, x, y);
let scope = context.into_scope();
@ -130,8 +130,8 @@ mod tests {
fn cube_strategy_trait_sub_test() {
let mut context = CubeContext::root();
let x = context.create_local(Item::Scalar(ElemType::into_elem()));
let y = context.create_local(Item::Scalar(ElemType::into_elem()));
let x = context.create_local(Item::new(ElemType::as_elem()));
let y = context.create_local(Item::new(ElemType::as_elem()));
with_strategy_trait_expand::<SubStrategy, ElemType>(&mut context, x, y);
let scope = context.into_scope();
@ -146,8 +146,8 @@ mod tests {
fn cube_two_strategy_traits_test() {
let mut context = CubeContext::root();
let x = context.create_local(Item::Scalar(ElemType::into_elem()));
let y = context.create_local(Item::Scalar(ElemType::into_elem()));
let x = context.create_local(Item::new(ElemType::as_elem()));
let y = context.create_local(Item::new(ElemType::as_elem()));
two_strategy_traits_expand::<SubStrategy, AddStrategy, ElemType>(&mut context, x, y);
let scope = context.into_scope();
@ -159,8 +159,8 @@ mod tests {
fn cube_trait_generic_method_test() {
let mut context = CubeContext::root();
let x = context.create_local(Item::Scalar(ElemType::into_elem()));
let y = context.create_local(Item::Scalar(ElemType::into_elem()));
let x = context.create_local(Item::new(ElemType::as_elem()));
let y = context.create_local(Item::new(ElemType::as_elem()));
with_trait_generic_method_expand::<AddStrategy, ElemType>(&mut context, x, y);
let scope = context.into_scope();
@ -173,7 +173,7 @@ mod tests {
fn inline_macro_ref_one(is_add_strategy: bool) -> String {
let mut context = CubeContext::root();
let item = Item::Scalar(ElemType::into_elem());
let item = Item::new(ElemType::as_elem());
let x = context.create_local(item);
let y = context.create_local(item);
@ -192,7 +192,7 @@ mod tests {
fn inline_macro_ref_two() -> String {
let mut context = CubeContext::root();
let item = Item::Scalar(ElemType::into_elem());
let item = Item::new(ElemType::as_elem());
let x = context.create_local(item);
let y = context.create_local(item);

View File

@ -0,0 +1,67 @@
use burn_cube::{cube, Numeric};
/// Cube kernel fixture for vectorization checks on an arithmetic binary operation.
///
/// Adds `lhs` to a value built with `T::from_vec(&[4, 5])` and discards the result;
/// only the expansion itself is of interest.
///
/// The tests in this file expand it with an `lhs` item of vectorization 2
/// (expected to succeed) and of vectorization 4 (expected to panic).
// NOTE(review): presumably `from_vec(&[4, 5])` yields a 2-wide vectorized value —
// confirm against the `Numeric::from_vec` definition.
#[cube]
pub fn vectorization_binary<T: Numeric>(lhs: T) {
let _ = lhs + T::from_vec(&[4, 5]);
}
/// Cube kernel fixture for vectorization checks on a comparison operation.
///
/// Compares a value built with `T::from_vec(&[4, 5])` against `rhs` using `>` and
/// discards the result; only the expansion itself is of interest.
///
/// The tests in this file expand it with an `rhs` item of vectorization 2 and 1
/// (both expected to succeed) and of vectorization 4 (expected to panic).
// NOTE(review): presumably `from_vec(&[4, 5])` yields a 2-wide vectorized value —
// confirm against the `Numeric::from_vec` definition.
#[cube]
pub fn vectorization_cmp<T: Numeric>(rhs: T) {
let _ = T::from_vec(&[4, 5]) > rhs;
}
mod tests {
    use burn_cube::{dialect::Item, CubeContext, PrimitiveVariable, F32};

    use crate::language::vectorization::{vectorization_binary_expand, vectorization_cmp_expand};

    type ElemType = F32;

    /// Expanding the binary kernel with an operand of vectorization 2 must not panic.
    #[test]
    fn cube_vectorization_binary_op_with_same_scheme_does_not_fail() {
        let mut ctx = CubeContext::root();
        let item = Item::vectorized(ElemType::as_elem(), 2);
        let lhs = ctx.create_local(item);

        vectorization_binary_expand::<ElemType>(&mut ctx, lhs);
    }

    /// Expanding the binary kernel with an operand of vectorization 4 must panic.
    #[test]
    #[should_panic]
    fn cube_vectorization_binary_op_with_different_scheme_fails() {
        let mut ctx = CubeContext::root();
        let item = Item::vectorized(ElemType::as_elem(), 4);
        let lhs = ctx.create_local(item);

        vectorization_binary_expand::<ElemType>(&mut ctx, lhs);
    }

    /// Expanding the comparison kernel with an operand of vectorization 2 must not panic.
    #[test]
    fn cube_vectorization_cmp_op_with_same_scheme_does_not_fail() {
        let mut ctx = CubeContext::root();
        let item = Item::vectorized(ElemType::as_elem(), 2);
        let rhs = ctx.create_local(item);

        vectorization_cmp_expand::<ElemType>(&mut ctx, rhs);
    }

    /// Expanding the comparison kernel with an operand of vectorization 4 must panic.
    #[test]
    #[should_panic]
    fn cube_vectorization_cmp_op_with_different_scheme_fails() {
        let mut ctx = CubeContext::root();
        let item = Item::vectorized(ElemType::as_elem(), 4);
        let rhs = ctx.create_local(item);

        vectorization_cmp_expand::<ElemType>(&mut ctx, rhs);
    }

    /// Expanding the comparison kernel with an operand of vectorization 1 must not panic.
    #[test]
    fn cube_vectorization_can_be_broadcasted() {
        let mut ctx = CubeContext::root();
        let item = Item::vectorized(ElemType::as_elem(), 1);
        let rhs = ctx.create_local(item);

        vectorization_cmp_expand::<ElemType>(&mut ctx, rhs);
    }
}

View File

@ -1,7 +1,6 @@
use burn_cube::{
calculate_num_elems_dyn_rank,
dialect::{Vectorization, WorkgroupSize},
elemwise_workgroup, CompilationInfo, CompilationSettings,
calculate_num_elems_dyn_rank, dialect::WorkgroupSize, elemwise_workgroup, CompilationInfo,
CompilationSettings,
};
use burn_tensor::repr::TensorDescription;
@ -56,12 +55,10 @@ impl<R: JitRuntime> FusionKernelFactory<R> for ElementWiseKernelFactory<R> {
);
if vectorize_4 {
settings = settings.vectorize(Vectorization::Vec4);
settings = settings.vectorize(4);
factor = 4;
}
if !vectorize_4 && vectorize_2 {
settings = settings.vectorize(Vectorization::Vec2);
} else if vectorize_2 {
settings = settings.vectorize(2);
factor = 2;
}

View File

@ -47,7 +47,7 @@ impl TraceBuilder {
false => {
// New input
let index = self.inputs.len() as u16;
let item = Item::Scalar(elem);
let item = Item::new(elem);
let local = self.scope.read_array(index, item, position);
self.inputs.push((tensor.clone(), local));
@ -56,7 +56,7 @@ impl TraceBuilder {
true => match self.output_to_local.get(&tensor.id) {
// Is a local variable.
Some(local_index) => {
Variable::Local(*local_index, Item::Scalar(elem), self.scope.depth)
Variable::Local(*local_index, Item::new(elem), self.scope.depth)
}
// Isn't an operation output variable, so must be an existing input.
None => self
@ -84,10 +84,10 @@ impl TraceBuilder {
// Output already registered as a local variable.
if let Some(index) = self.output_to_local.get(&tensor.id) {
return Variable::Local(*index, Item::Scalar(elem), self.scope.depth);
return Variable::Local(*index, Item::new(elem), self.scope.depth);
}
let variable = self.scope.create_local(Item::Scalar(elem));
let variable = self.scope.create_local(Item::new(elem));
let local_index = variable.index().unwrap();
self.output_to_local.insert(tensor.id, local_index);
variable

View File

@ -45,7 +45,7 @@ impl Trace {
.inputs
.iter()
.map(|(_tensor, elem, _)| InputInfo::Array {
item: Item::Scalar(*elem),
item: Item::new(*elem),
visibility: Visibility::Read,
})
.collect::<Vec<_>>();
@ -56,7 +56,7 @@ impl Trace {
.zip(self.locals.iter())
.map(
|((_tensor, elem, index_ref), local)| OutputInfo::ArrayWrite {
item: Item::Scalar(*elem),
item: Item::new(*elem),
local: *local,
position: *index_ref,
},

View File

@ -65,15 +65,15 @@ macro_rules! binary {
let local = scope.last_local_index().unwrap().index().unwrap();
let lhs = burn_cube::InputInfo::Array {
item: burn_cube::dialect::Item::Scalar(I::cube_elem()),
item: burn_cube::dialect::Item::new(I::cube_elem()),
visibility: burn_cube::dialect::Visibility::Read,
};
let rhs = burn_cube::InputInfo::Array {
item: burn_cube::dialect::Item::Scalar(I::cube_elem()),
item: burn_cube::dialect::Item::new(I::cube_elem()),
visibility: burn_cube::dialect::Visibility::Read,
};
let out = burn_cube::OutputInfo::ArrayWrite {
item: burn_cube::dialect::Item::Scalar(O::cube_elem()),
item: burn_cube::dialect::Item::new(O::cube_elem()),
local,
position,
};

View File

@ -56,7 +56,7 @@ pub(crate) struct BoolCastEagerKernel<R: JitRuntime, EO: JitElement> {
impl<R: JitRuntime, EO: JitElement> GpuComputeShaderPhase for BoolCastEagerKernel<R, EO> {
fn compile(&self) -> ComputeShader {
let mut scope = Scope::root();
let item_input = Item::Scalar(Elem::Bool);
let item_input = Item::new(Elem::Bool);
let item_output = EO::cube_elem().into();
let tensor = Variable::GlobalInputArray(0, item_input);

View File

@ -125,14 +125,14 @@ pub(crate) fn gather_shader_information(
// Registers used in the compute pass
let results = scope.create_local_array(elem, results_size);
let register_m = scope.create_local(Item::Vec4(elem));
let register_n = scope.create_local(Item::Vec4(elem));
let register_m = scope.create_local(Item::vectorized(elem, 4));
let register_n = scope.create_local(Item::vectorized(elem, 4));
let shared_lhs = scope.create_shared(
Item::Vec4(elem),
Item::vectorized(elem, 4),
shader.config.block_size_m as u32 * shader.config.block_size_k as u32 / 4u32,
);
let shared_rhs = scope.create_shared(
Item::Vec4(elem),
Item::vectorized(elem, 4),
shader.config.block_size_k as u32 * shader.config.block_size_n as u32 / 4u32,
);

View File

@ -268,7 +268,7 @@ impl<R: JitRuntime, E: JitElement> GpuComputeShaderPhase
let mut scope = Scope::root();
let item = E::cube_elem().into();
let indices = Variable::GlobalInputArray(0, Item::Scalar(Elem::Int(IntKind::I32)));
let indices = Variable::GlobalInputArray(0, Item::new(Elem::Int(IntKind::I32)));
let grad = Variable::GlobalInputArray(1, item);
let output = Variable::GlobalOutputArray(0, item);
@ -283,7 +283,7 @@ impl<R: JitRuntime, E: JitElement> GpuComputeShaderPhase
.expand(&mut scope);
let indices = InputInfo::Array {
item: Item::Scalar(Elem::Int(IntKind::I32)),
item: Item::new(Elem::Int(IntKind::I32)),
visibility: Visibility::Read,
};

View File

@ -184,7 +184,7 @@ impl<P: PoolStrategy, R: JitRuntime, E: JitElement> GpuComputeShaderPhase
let indices = if P::with_indices() {
Some(Variable::GlobalOutputArray(
1,
Item::Scalar(Elem::Int(IntKind::I32)),
Item::new(Elem::Int(IntKind::I32)),
))
} else {
None
@ -216,7 +216,7 @@ impl<P: PoolStrategy, R: JitRuntime, E: JitElement> GpuComputeShaderPhase
vec![
output,
OutputInfo::Array {
item: Item::Scalar(Elem::Int(IntKind::I32)),
item: Item::new(Elem::Int(IntKind::I32)),
},
]
} else {

View File

@ -70,11 +70,11 @@ macro_rules! unary {
let local = scope.last_local_index().unwrap().index().unwrap();
let input = burn_cube::InputInfo::Array {
item: burn_cube::dialect::Item::Scalar(E::cube_elem()),
item: burn_cube::dialect::Item::new(E::cube_elem()),
visibility: burn_cube::dialect::Visibility::Read,
};
let out = burn_cube::OutputInfo::ArrayWrite {
item: burn_cube::dialect::Item::Scalar(E::cube_elem()),
item: burn_cube::dialect::Item::new(E::cube_elem()),
local,
position: burn_cube::dialect::Variable::Id,
};
@ -146,7 +146,7 @@ macro_rules! unary {
let local = scope.last_local_index().unwrap().index().unwrap();
let input = burn_cube::InputInfo::Array {
item: burn_cube::dialect::Item::Scalar(E::cube_elem()),
item: burn_cube::dialect::Item::new(E::cube_elem()),
visibility: burn_cube::dialect::Visibility::Read,
};
let scalars = burn_cube::InputInfo::Scalar {
@ -154,7 +154,7 @@ macro_rules! unary {
size: $num,
};
let out = burn_cube::OutputInfo::ArrayWrite {
item: burn_cube::dialect::Item::Scalar(E::cube_elem()),
item: burn_cube::dialect::Item::new(E::cube_elem()),
local,
position: burn_cube::dialect::Variable::Id,
};

View File

@ -295,7 +295,7 @@ where
fn int_abs<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, D> {
unary!(
operation: |scope: &mut Scope, elem: Elem, position: Variable| Operator::Abs(UnaryOperator {
input: scope.read_array(0, Item::Scalar(elem), position),
input: scope.read_array(0, Item::new(elem), position),
out: scope.create_local(elem),
}),
runtime: R,

View File

@ -89,11 +89,13 @@ impl WgslCompiler {
}
fn compile_item(item: cube::Item) -> Item {
match item {
cube::Item::Vec4(elem) => wgsl::Item::Vec4(Self::compile_elem(elem)),
cube::Item::Vec3(elem) => wgsl::Item::Vec3(Self::compile_elem(elem)),
cube::Item::Vec2(elem) => wgsl::Item::Vec2(Self::compile_elem(elem)),
cube::Item::Scalar(elem) => wgsl::Item::Scalar(Self::compile_elem(elem)),
let elem = Self::compile_elem(item.elem);
match item.vectorization {
1 => wgsl::Item::Scalar(elem),
2 => wgsl::Item::Vec2(elem),
3 => wgsl::Item::Vec3(elem),
4 => wgsl::Item::Vec4(elem),
_ => panic!("Unsupported vectorizations scheme {:?}", item.vectorization),
}
}
@ -101,7 +103,7 @@ impl WgslCompiler {
match value {
cube::Elem::Float(f) => match f {
cube::FloatKind::F16 => panic!("f16 is not yet supported"),
cube::FloatKind::BF16 => panic!("f64 is not a valid WgpuElement"),
cube::FloatKind::BF16 => panic!("bf16 is not a valid WgpuElement"),
cube::FloatKind::F32 => wgsl::Elem::F32,
cube::FloatKind::F64 => panic!("f64 is not a valid WgpuElement"),
},