mirror of https://github.com/tracel-ai/burn.git
Refactor/cube/vectorization (#1781)
This commit is contained in:
parent
499ff0dd26
commit
76fe0ed881
|
@ -239,6 +239,15 @@ impl CodeAnalysisBuilder {
|
|||
}
|
||||
syn::Expr::Break(_) => {}
|
||||
syn::Expr::Paren(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
|
||||
syn::Expr::Array(expr) => {
|
||||
for element in expr.elems.iter() {
|
||||
match element {
|
||||
syn::Expr::Lit(_) => {}
|
||||
_ => todo!("Analysis: only array of literals is supported"),
|
||||
}
|
||||
}
|
||||
}
|
||||
syn::Expr::Reference(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
|
||||
_ => todo!("Analysis: unsupported expr {expr:?}"),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,7 +6,10 @@ use super::{
|
|||
branch::{codegen_break, codegen_for_loop, codegen_if, codegen_loop, codegen_while_loop},
|
||||
function::{codegen_call, codegen_closure, codegen_expr_method_call},
|
||||
operation::codegen_binary,
|
||||
variable::{codegen_assign, codegen_index, codegen_lit, codegen_local, codegen_path_rhs},
|
||||
variable::{
|
||||
codegen_array_lit, codegen_assign, codegen_index, codegen_lit, codegen_local,
|
||||
codegen_path_rhs,
|
||||
},
|
||||
};
|
||||
|
||||
/// Codegen for a statement (generally one line)
|
||||
|
@ -59,6 +62,15 @@ pub(crate) fn codegen_expr_block(
|
|||
codegen_block(&block.block, loop_level, variable_analyses)
|
||||
}
|
||||
|
||||
pub(crate) fn codegen_ref(
|
||||
reference: &syn::ExprReference,
|
||||
loop_level: usize,
|
||||
variable_analyses: &mut CodeAnalysis,
|
||||
) -> TokenStream {
|
||||
let inner = codegen_expr(&reference.expr, loop_level, variable_analyses);
|
||||
quote::quote! { & #inner }
|
||||
}
|
||||
|
||||
/// Codegen for expressions
|
||||
/// There are many variants of expression, treated differently
|
||||
pub(crate) fn codegen_expr(
|
||||
|
@ -84,6 +96,8 @@ pub(crate) fn codegen_expr(
|
|||
syn::Expr::MethodCall(call) => codegen_expr_method_call(call),
|
||||
syn::Expr::Index(index) => codegen_index(index, loop_level, variable_analyses),
|
||||
syn::Expr::Paren(paren) => codegen_expr(&paren.expr, loop_level, variable_analyses),
|
||||
syn::Expr::Array(array) => codegen_array_lit(array),
|
||||
syn::Expr::Reference(reference) => codegen_ref(reference, loop_level, variable_analyses),
|
||||
_ => panic!("Codegen: Unsupported {:?}", expr),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,12 +34,7 @@ pub(crate) fn codegen_closure(
|
|||
}
|
||||
|
||||
/// Codegen for a function call
|
||||
/// Supports:
|
||||
/// func()
|
||||
/// func::<T>()
|
||||
/// T::func()
|
||||
///
|
||||
/// Should map:
|
||||
/// Maps
|
||||
/// [A[::<...>]?::]^* func[::<...>] (args)
|
||||
/// to
|
||||
/// [A[::<...>]?::]^* func_expand[::<...>] (context, args)
|
||||
|
|
|
@ -19,6 +19,19 @@ pub(crate) fn codegen_lit(lit: &syn::ExprLit) -> TokenStream {
|
|||
}
|
||||
}
|
||||
|
||||
/// Codegen for arrays of literals
|
||||
pub(crate) fn codegen_array_lit(array: &syn::ExprArray) -> TokenStream {
|
||||
let mut tokens = quote::quote! {};
|
||||
for element in array.elems.iter() {
|
||||
let token = match element {
|
||||
syn::Expr::Lit(lit) => codegen_lit(lit),
|
||||
_ => todo!("Codegen: Only arrays of literals are supported"),
|
||||
};
|
||||
tokens.extend(quote::quote! { #token, });
|
||||
}
|
||||
quote::quote! { [ #tokens ] }
|
||||
}
|
||||
|
||||
/// Codegen for a local declaration (let ...)
|
||||
/// Supports:
|
||||
/// let x = ...
|
||||
|
|
|
@ -81,12 +81,7 @@ impl core::fmt::Display for CompilationSettings {
|
|||
}
|
||||
|
||||
match self.vectorization {
|
||||
Some(vectorization) => match vectorization {
|
||||
Vectorization::Vec4 => f.write_str("v4"),
|
||||
Vectorization::Vec3 => f.write_str("v3"),
|
||||
Vectorization::Vec2 => f.write_str("v2"),
|
||||
Vectorization::Scalar => f.write_str("v1"),
|
||||
}?,
|
||||
Some(vectorization) => f.write_fmt(format_args!("v{}", vectorization))?,
|
||||
None => f.write_str("vn")?,
|
||||
};
|
||||
|
||||
|
@ -154,7 +149,7 @@ impl InputInfo {
|
|||
item,
|
||||
visibility: _,
|
||||
} => *item,
|
||||
InputInfo::Scalar { elem, size: _ } => Item::Scalar(*elem),
|
||||
InputInfo::Scalar { elem, size: _ } => Item::new(*elem),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -252,7 +247,7 @@ impl Compilation {
|
|||
named.push((
|
||||
"info".to_string(),
|
||||
Binding {
|
||||
item: Item::Scalar(Elem::UInt),
|
||||
item: Item::new(Elem::UInt),
|
||||
visibility: Visibility::Read,
|
||||
location: Location::Storage,
|
||||
size: None, // We avoid putting the length here since it will force a new kernel
|
||||
|
@ -300,7 +295,7 @@ impl Compilation {
|
|||
self.named_bindings.push((
|
||||
format!("scalars_{}", elem),
|
||||
Binding {
|
||||
item: Item::Scalar(elem),
|
||||
item: Item::new(elem),
|
||||
visibility: Visibility::Read,
|
||||
location: Location::Storage,
|
||||
size: Some(size),
|
||||
|
@ -440,11 +435,9 @@ impl Compilation {
|
|||
}
|
||||
|
||||
fn bool_item(ty: Item) -> Item {
|
||||
match ty {
|
||||
Item::Vec4(elem) => Item::Vec4(bool_elem(elem)),
|
||||
Item::Vec3(elem) => Item::Vec3(bool_elem(elem)),
|
||||
Item::Vec2(elem) => Item::Vec2(bool_elem(elem)),
|
||||
Item::Scalar(elem) => Item::Scalar(bool_elem(elem)),
|
||||
Item {
|
||||
elem: bool_elem(ty.elem),
|
||||
vectorization: ty.vectorization,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -94,7 +94,7 @@ impl RangeLoop {
|
|||
func: F,
|
||||
) {
|
||||
let mut scope = parent_scope.child();
|
||||
let index_ty = Item::Scalar(Elem::UInt);
|
||||
let index_ty = Item::new(Elem::UInt);
|
||||
let i = scope.create_local_undeclared(index_ty);
|
||||
|
||||
func(i, &mut scope);
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
use crate::codegen::dialect::{macros::cpa, Item, Scope, Variable, Vectorization};
|
||||
use crate::{
|
||||
branch::range,
|
||||
codegen::dialect::{macros::cpa, Scope, Variable, Vectorization},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Assign value to a variable based on a given condition.
|
||||
|
@ -19,14 +22,15 @@ impl ConditionalAssign {
|
|||
let rhs = self.rhs;
|
||||
let out = self.out;
|
||||
|
||||
let index_var = |scope: &mut Scope, var: Variable, index: usize| match var.item() {
|
||||
Item::Scalar(_) => var,
|
||||
_ => {
|
||||
let out = scope.create_local(var.item().elem());
|
||||
cpa!(scope, out = var[index]);
|
||||
out
|
||||
}
|
||||
};
|
||||
let index_var =
|
||||
|scope: &mut Scope, var: Variable, index: usize| match var.item().vectorization == 1 {
|
||||
true => var,
|
||||
false => {
|
||||
let out = scope.create_local(var.item().elem());
|
||||
cpa!(scope, out = var[index]);
|
||||
out
|
||||
}
|
||||
};
|
||||
|
||||
let mut assign_index = |index: usize| {
|
||||
let cond = index_var(scope, cond, index);
|
||||
|
@ -42,29 +46,20 @@ impl ConditionalAssign {
|
|||
}));
|
||||
};
|
||||
|
||||
match out.item() {
|
||||
Item::Vec4(_) => {
|
||||
assign_index(0);
|
||||
assign_index(1);
|
||||
assign_index(2);
|
||||
assign_index(3);
|
||||
}
|
||||
Item::Vec3(_) => {
|
||||
assign_index(0);
|
||||
assign_index(1);
|
||||
assign_index(2);
|
||||
}
|
||||
Item::Vec2(_) => {
|
||||
assign_index(0);
|
||||
assign_index(1);
|
||||
}
|
||||
Item::Scalar(_) => {
|
||||
let vectorization = out.item().vectorization;
|
||||
match vectorization == 1 {
|
||||
true => {
|
||||
cpa!(scope, if (cond).then(|scope| {
|
||||
cpa!(scope, out = lhs);
|
||||
}).else(|scope| {
|
||||
cpa!(scope, out = rhs);
|
||||
}));
|
||||
}
|
||||
false => {
|
||||
for i in range(0u32, vectorization as u32, true) {
|
||||
assign_index(i);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -19,8 +19,8 @@ impl CheckedIndex {
|
|||
let lhs = self.lhs;
|
||||
let rhs = self.rhs;
|
||||
let out = self.out;
|
||||
let array_len = scope.create_local(Item::Scalar(crate::dialect::Elem::UInt));
|
||||
let inside_bound = scope.create_local(Item::Scalar(crate::dialect::Elem::Bool));
|
||||
let array_len = scope.create_local(Item::new(crate::dialect::Elem::UInt));
|
||||
let inside_bound = scope.create_local(Item::new(crate::dialect::Elem::Bool));
|
||||
|
||||
cpa!(scope, array_len = len(lhs));
|
||||
cpa!(scope, inside_bound = rhs < array_len);
|
||||
|
@ -56,8 +56,8 @@ impl CheckedIndexAssign {
|
|||
let lhs = self.lhs;
|
||||
let rhs = self.rhs;
|
||||
let out = self.out;
|
||||
let array_len = scope.create_local(Item::Scalar(Elem::UInt));
|
||||
let inside_bound = scope.create_local(Item::Scalar(Elem::Bool));
|
||||
let array_len = scope.create_local(Item::new(Elem::UInt));
|
||||
let inside_bound = scope.create_local(Item::new(Elem::Bool));
|
||||
|
||||
cpa!(scope, array_len = len(out));
|
||||
cpa!(scope, inside_bound = lhs < array_len);
|
||||
|
|
|
@ -140,17 +140,11 @@ impl IndexOffsetGlobalWithLayout {
|
|||
#[allow(missing_docs)]
|
||||
pub fn expand(self, scope: &mut Scope) {
|
||||
let layout = self.layout;
|
||||
let index_item_ty = Item::Scalar(Elem::UInt);
|
||||
let index_item_ty = Item::new(Elem::UInt);
|
||||
let offset_ref = self.position;
|
||||
let zero: Variable = 0u32.into();
|
||||
let vectorization_factor: Variable = match self.tensors[0].item() {
|
||||
Item::Vec4(_) => 4u32,
|
||||
Item::Vec3(_) => 3u32,
|
||||
Item::Vec2(_) => 2u32,
|
||||
Item::Scalar(_) => 1u32,
|
||||
}
|
||||
.into();
|
||||
|
||||
let vectorization_factor: u8 = self.tensors[0].item().vectorization;
|
||||
let vectorization_factor: Variable = (vectorization_factor as u32).into();
|
||||
for index in self.indexes.iter() {
|
||||
cpa!(scope, index = zero);
|
||||
}
|
||||
|
|
|
@ -336,11 +336,9 @@ impl Scope {
|
|||
position: Variable,
|
||||
) -> Variable {
|
||||
let item_global = match item.elem() {
|
||||
Elem::Bool => match item {
|
||||
Item::Vec4(_) => Item::Vec4(Elem::UInt),
|
||||
Item::Vec3(_) => Item::Vec3(Elem::UInt),
|
||||
Item::Vec2(_) => Item::Vec2(Elem::UInt),
|
||||
Item::Scalar(_) => Item::Scalar(Elem::UInt),
|
||||
Elem::Bool => Item {
|
||||
elem: Elem::UInt,
|
||||
vectorization: item.vectorization,
|
||||
},
|
||||
_ => item,
|
||||
};
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use super::Scope;
|
||||
use super::{Scope, Vectorization};
|
||||
use crate::WORKGROUP_DEFAULT;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Display;
|
||||
|
@ -44,7 +44,7 @@ pub enum Elem {
|
|||
|
||||
impl From<Elem> for Item {
|
||||
fn from(val: Elem) -> Self {
|
||||
Item::Scalar(val)
|
||||
Item::new(val)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -81,22 +81,30 @@ impl Display for Elem {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Copy, Serialize, Deserialize, Hash)]
|
||||
#[allow(missing_docs)]
|
||||
pub enum Item {
|
||||
Vec4(Elem),
|
||||
Vec3(Elem),
|
||||
Vec2(Elem),
|
||||
Scalar(Elem),
|
||||
pub struct Item {
|
||||
pub elem: Elem,
|
||||
pub vectorization: Vectorization,
|
||||
}
|
||||
|
||||
impl Item {
|
||||
/// Fetch the elem of the item.
|
||||
pub fn elem(&self) -> Elem {
|
||||
match self {
|
||||
Self::Vec4(elem) => *elem,
|
||||
Self::Vec3(elem) => *elem,
|
||||
Self::Vec2(elem) => *elem,
|
||||
Self::Scalar(elem) => *elem,
|
||||
self.elem
|
||||
}
|
||||
|
||||
/// Create a new item without vectorization
|
||||
pub fn new(elem: Elem) -> Self {
|
||||
Self {
|
||||
elem,
|
||||
vectorization: 1,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new item with vectorization
|
||||
pub fn vectorized(elem: Elem, vectorization: Vectorization) -> Self {
|
||||
Self {
|
||||
elem,
|
||||
vectorization,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -69,30 +69,30 @@ impl Variable {
|
|||
match self {
|
||||
Variable::GlobalInputArray(_, item) => *item,
|
||||
Variable::GlobalOutputArray(_, item) => *item,
|
||||
Variable::GlobalScalar(_, elem) => Item::Scalar(*elem),
|
||||
Variable::GlobalScalar(_, elem) => Item::new(*elem),
|
||||
Variable::Local(_, item, _) => *item,
|
||||
Variable::LocalScalar(_, elem, _) => Item::Scalar(*elem),
|
||||
Variable::ConstantScalar(_, elem) => Item::Scalar(*elem),
|
||||
Variable::LocalScalar(_, elem, _) => Item::new(*elem),
|
||||
Variable::ConstantScalar(_, elem) => Item::new(*elem),
|
||||
Variable::SharedMemory(_, item, _) => *item,
|
||||
Variable::LocalArray(_, item, _, _) => *item,
|
||||
Variable::Id => Item::Scalar(Elem::UInt),
|
||||
Variable::Rank => Item::Scalar(Elem::UInt),
|
||||
Variable::LocalInvocationIndex => Item::Scalar(Elem::UInt),
|
||||
Variable::LocalInvocationIdX => Item::Scalar(Elem::UInt),
|
||||
Variable::LocalInvocationIdY => Item::Scalar(Elem::UInt),
|
||||
Variable::LocalInvocationIdZ => Item::Scalar(Elem::UInt),
|
||||
Variable::WorkgroupIdX => Item::Scalar(Elem::UInt),
|
||||
Variable::WorkgroupIdY => Item::Scalar(Elem::UInt),
|
||||
Variable::WorkgroupIdZ => Item::Scalar(Elem::UInt),
|
||||
Variable::GlobalInvocationIdX => Item::Scalar(Elem::UInt),
|
||||
Variable::GlobalInvocationIdY => Item::Scalar(Elem::UInt),
|
||||
Variable::GlobalInvocationIdZ => Item::Scalar(Elem::UInt),
|
||||
Variable::WorkgroupSizeX => Item::Scalar(Elem::UInt),
|
||||
Variable::WorkgroupSizeY => Item::Scalar(Elem::UInt),
|
||||
Variable::WorkgroupSizeZ => Item::Scalar(Elem::UInt),
|
||||
Variable::NumWorkgroupsX => Item::Scalar(Elem::UInt),
|
||||
Variable::NumWorkgroupsY => Item::Scalar(Elem::UInt),
|
||||
Variable::NumWorkgroupsZ => Item::Scalar(Elem::UInt),
|
||||
Variable::Id => Item::new(Elem::UInt),
|
||||
Variable::Rank => Item::new(Elem::UInt),
|
||||
Variable::LocalInvocationIndex => Item::new(Elem::UInt),
|
||||
Variable::LocalInvocationIdX => Item::new(Elem::UInt),
|
||||
Variable::LocalInvocationIdY => Item::new(Elem::UInt),
|
||||
Variable::LocalInvocationIdZ => Item::new(Elem::UInt),
|
||||
Variable::WorkgroupIdX => Item::new(Elem::UInt),
|
||||
Variable::WorkgroupIdY => Item::new(Elem::UInt),
|
||||
Variable::WorkgroupIdZ => Item::new(Elem::UInt),
|
||||
Variable::GlobalInvocationIdX => Item::new(Elem::UInt),
|
||||
Variable::GlobalInvocationIdY => Item::new(Elem::UInt),
|
||||
Variable::GlobalInvocationIdZ => Item::new(Elem::UInt),
|
||||
Variable::WorkgroupSizeX => Item::new(Elem::UInt),
|
||||
Variable::WorkgroupSizeY => Item::new(Elem::UInt),
|
||||
Variable::WorkgroupSizeZ => Item::new(Elem::UInt),
|
||||
Variable::NumWorkgroupsX => Item::new(Elem::UInt),
|
||||
Variable::NumWorkgroupsY => Item::new(Elem::UInt),
|
||||
Variable::NumWorkgroupsZ => Item::new(Elem::UInt),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,19 +1,6 @@
|
|||
use super::{BinaryOperator, ClampOperator, Item, Operation, Operator, UnaryOperator, Variable};
|
||||
|
||||
/// Define a vectorization scheme.
|
||||
#[allow(dead_code)]
|
||||
#[derive(Copy, Clone, Debug, Default, Hash)]
|
||||
pub enum Vectorization {
|
||||
/// Use vec4 for vectorization.
|
||||
Vec4,
|
||||
/// Use vec3 for vectorization.
|
||||
Vec3,
|
||||
/// Use vec2 for vectorization.
|
||||
Vec2,
|
||||
/// Don't vectorize.
|
||||
#[default]
|
||||
Scalar,
|
||||
}
|
||||
pub type Vectorization = u8;
|
||||
|
||||
impl Operation {
|
||||
pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self {
|
||||
|
@ -169,21 +156,14 @@ impl Variable {
|
|||
}
|
||||
|
||||
impl Item {
|
||||
pub(crate) fn vectorize(&self, vectorize: Vectorization) -> Item {
|
||||
match vectorize {
|
||||
Vectorization::Vec4 => Item::Vec4(self.elem()),
|
||||
Vectorization::Vec3 => Item::Vec3(self.elem()),
|
||||
Vectorization::Vec2 => Item::Vec2(self.elem()),
|
||||
Vectorization::Scalar => Item::Scalar(self.elem()),
|
||||
pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Item {
|
||||
Item {
|
||||
elem: self.elem,
|
||||
vectorization,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn vectorized_size(&self, vectorize: Vectorization, size: u32) -> u32 {
|
||||
match vectorize {
|
||||
Vectorization::Vec4 => size / 4,
|
||||
Vectorization::Vec3 => size / 3,
|
||||
Vectorization::Vec2 => size / 2,
|
||||
Vectorization::Scalar => size,
|
||||
}
|
||||
size / (vectorize as u32)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -41,7 +41,7 @@ pub fn range_expand<F>(
|
|||
}
|
||||
} else {
|
||||
let mut child = context.child();
|
||||
let index_ty = Item::Scalar(Elem::UInt);
|
||||
let index_ty = Item::new(Elem::UInt);
|
||||
let i = child.scope.borrow_mut().create_local_undeclared(index_ty);
|
||||
let i = ExpandElement::new(Rc::new(i));
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use crate::dialect::Elem;
|
||||
use crate::dialect::{Elem, Vectorization};
|
||||
|
||||
use crate::language::{CubeContext, CubeType, ExpandElement, PrimitiveVariable};
|
||||
|
||||
|
@ -34,10 +34,14 @@ impl Bool {
|
|||
impl PrimitiveVariable for Bool {
|
||||
type Primitive = bool;
|
||||
|
||||
fn into_elem() -> Elem {
|
||||
fn as_elem() -> Elem {
|
||||
Elem::Bool
|
||||
}
|
||||
|
||||
fn vectorization(&self) -> Vectorization {
|
||||
self.vectorization
|
||||
}
|
||||
|
||||
fn to_f64(&self) -> f64 {
|
||||
match self.val {
|
||||
true => 1.,
|
||||
|
@ -52,4 +56,13 @@ impl PrimitiveVariable for Bool {
|
|||
fn from_i64(val: i64) -> Self {
|
||||
Self::from_f64(val as f64)
|
||||
}
|
||||
|
||||
fn from_i64_vec(vec: &[i64]) -> Self {
|
||||
Self {
|
||||
// We take only one value, because type implements copy and we can't copy an unknown sized vec
|
||||
// For debugging prefer unvectorized types
|
||||
val: *vec.first().expect("Should be at least one value") > 0,
|
||||
vectorization: vec.len() as u8,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ pub trait Cast: PrimitiveVariable {
|
|||
context: &mut CubeContext,
|
||||
value: <Self as CubeType>::ExpandType,
|
||||
) -> <Self as CubeType>::ExpandType {
|
||||
let new_var = context.create_local(Item::Scalar(<Self as PrimitiveVariable>::into_elem()));
|
||||
let new_var = context.create_local(Item::new(<Self as PrimitiveVariable>::as_elem()));
|
||||
assign::expand(context, value, new_var.clone());
|
||||
new_var
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use crate::dialect::{Elem, FloatKind, Variable};
|
||||
use crate::dialect::{Elem, FloatKind, Variable, Vectorization};
|
||||
use crate::language::{CubeContext, CubeType, ExpandElement, Numeric, PrimitiveVariable};
|
||||
use std::rc::Rc;
|
||||
|
||||
|
@ -13,7 +13,7 @@ macro_rules! impl_float {
|
|||
#[derive(Clone, Copy)]
|
||||
pub struct $type {
|
||||
pub val: <Self as PrimitiveVariable>::Primitive,
|
||||
pub vectorization: usize,
|
||||
pub vectorization: u8,
|
||||
}
|
||||
|
||||
impl CubeType for $type {
|
||||
|
@ -24,10 +24,14 @@ macro_rules! impl_float {
|
|||
type Primitive = f64;
|
||||
|
||||
/// Return the element type to use on GPU
|
||||
fn into_elem() -> Elem {
|
||||
fn as_elem() -> Elem {
|
||||
Elem::Float(FloatKind::$type)
|
||||
}
|
||||
|
||||
fn vectorization(&self) -> Vectorization {
|
||||
self.vectorization.into()
|
||||
}
|
||||
|
||||
fn to_f64(&self) -> f64 {
|
||||
self.val
|
||||
}
|
||||
|
@ -39,6 +43,16 @@ macro_rules! impl_float {
|
|||
fn from_i64(val: i64) -> Self {
|
||||
Self::new(val as f64)
|
||||
}
|
||||
|
||||
fn from_i64_vec(vec: &[i64]) -> Self {
|
||||
Self {
|
||||
// We take only one value, because type implements copy and we can't copy an unknown sized vec
|
||||
// When using CPU-side values for debugging kernels, prefer using unvectorized types
|
||||
val: *vec.first().expect("Should be at least one value")
|
||||
as <Self as PrimitiveVariable>::Primitive,
|
||||
vectorization: vec.len() as u8,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Numeric for $type {}
|
||||
|
@ -55,7 +69,7 @@ macro_rules! impl_float {
|
|||
_context: &mut CubeContext,
|
||||
val: <Self as PrimitiveVariable>::Primitive,
|
||||
) -> <Self as CubeType>::ExpandType {
|
||||
let new_var = Variable::ConstantScalar(val as f64, Self::into_elem());
|
||||
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
|
||||
ExpandElement::new(Rc::new(new_var))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use crate::dialect::{Elem, IntKind, Variable};
|
||||
use crate::dialect::{Elem, IntKind, Variable, Vectorization};
|
||||
use crate::language::{CubeContext, CubeType, ExpandElement, Numeric, PrimitiveVariable};
|
||||
use std::rc::Rc;
|
||||
|
||||
|
@ -13,7 +13,7 @@ macro_rules! impl_int {
|
|||
#[derive(Clone, Copy)]
|
||||
pub struct $type {
|
||||
pub val: <Self as PrimitiveVariable>::Primitive,
|
||||
pub vectorization: usize,
|
||||
pub vectorization: u8,
|
||||
}
|
||||
|
||||
impl CubeType for $type {
|
||||
|
@ -23,10 +23,14 @@ macro_rules! impl_int {
|
|||
impl PrimitiveVariable for $type {
|
||||
type Primitive = i64;
|
||||
|
||||
fn into_elem() -> Elem {
|
||||
fn as_elem() -> Elem {
|
||||
Elem::Int(IntKind::$type)
|
||||
}
|
||||
|
||||
fn vectorization(&self) -> Vectorization {
|
||||
self.vectorization.into()
|
||||
}
|
||||
|
||||
fn to_f64(&self) -> f64 {
|
||||
self.val as f64
|
||||
}
|
||||
|
@ -38,6 +42,16 @@ macro_rules! impl_int {
|
|||
fn from_i64(val: i64) -> Self {
|
||||
Self::new(val)
|
||||
}
|
||||
|
||||
fn from_i64_vec(vec: &[i64]) -> Self {
|
||||
Self {
|
||||
// We take only one value, because type implements copy and we can't copy an unknown sized vec
|
||||
// For debugging prefer unvectorized types
|
||||
val: *vec.first().expect("Should be at least one value")
|
||||
as <Self as PrimitiveVariable>::Primitive,
|
||||
vectorization: vec.len() as u8,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Numeric for $type {}
|
||||
|
@ -53,7 +67,7 @@ macro_rules! impl_int {
|
|||
_context: &mut CubeContext,
|
||||
val: <Self as PrimitiveVariable>::Primitive,
|
||||
) -> <Self as CubeType>::ExpandType {
|
||||
let new_var = Variable::ConstantScalar(val as f64, Self::into_elem());
|
||||
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
|
||||
ExpandElement::new(Rc::new(new_var))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use crate::dialect::Variable;
|
||||
use crate::dialect::{Item, Variable};
|
||||
use crate::index_assign;
|
||||
use crate::language::{CubeContext, CubeType, ExpandElement, PrimitiveVariable};
|
||||
use std::rc::Rc;
|
||||
|
||||
|
@ -24,9 +25,25 @@ pub trait Numeric:
|
|||
<Self as PrimitiveVariable>::from_i64(val)
|
||||
}
|
||||
|
||||
/// Expand version of lit
|
||||
/// Expand version of from_int
|
||||
fn from_int_expand(_context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType {
|
||||
let new_var = Variable::ConstantScalar(val as f64, Self::into_elem());
|
||||
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
|
||||
ExpandElement::new(Rc::new(new_var))
|
||||
}
|
||||
|
||||
fn from_vec(vec: &[i64]) -> Self {
|
||||
<Self as PrimitiveVariable>::from_i64_vec(vec)
|
||||
}
|
||||
|
||||
fn from_vec_expand(context: &mut CubeContext, vec: &[i64]) -> <Self as CubeType>::ExpandType {
|
||||
let mut new_var = context.create_local(Item {
|
||||
elem: Self::as_elem(),
|
||||
vectorization: (vec.len() as u8),
|
||||
});
|
||||
for (i, element) in vec.iter().enumerate() {
|
||||
new_var = index_assign::expand(context, new_var, i.into(), (*element).into());
|
||||
}
|
||||
|
||||
new_var
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use std::rc::Rc;
|
||||
|
||||
use crate::dialect::{Elem, Variable};
|
||||
use crate::dialect::{Elem, Variable, Vectorization};
|
||||
use crate::language::{CubeType, ExpandElement};
|
||||
|
||||
/// Form of CubeType that encapsulates all primitive types:
|
||||
|
@ -9,12 +9,15 @@ pub trait PrimitiveVariable: CubeType<ExpandType = ExpandElement> {
|
|||
type Primitive;
|
||||
|
||||
/// Return the element type to use on GPU
|
||||
fn into_elem() -> Elem;
|
||||
fn as_elem() -> Elem;
|
||||
fn vectorization(&self) -> Vectorization;
|
||||
|
||||
// For easy CPU-side casting
|
||||
fn to_f64(&self) -> f64;
|
||||
fn from_f64(val: f64) -> Self;
|
||||
fn from_i64(val: i64) -> Self;
|
||||
|
||||
fn from_i64_vec(vec: &[i64]) -> Self;
|
||||
}
|
||||
|
||||
macro_rules! impl_into_expand_element {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use std::rc::Rc;
|
||||
|
||||
use crate::dialect::{Elem, Variable};
|
||||
use crate::dialect::{Elem, Variable, Vectorization};
|
||||
use crate::language::{CubeContext, CubeType, ExpandElement, Numeric, PrimitiveVariable};
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
|
@ -18,10 +18,14 @@ impl CubeType for UInt {
|
|||
impl PrimitiveVariable for UInt {
|
||||
type Primitive = u32;
|
||||
|
||||
fn into_elem() -> Elem {
|
||||
fn as_elem() -> Elem {
|
||||
Elem::UInt
|
||||
}
|
||||
|
||||
fn vectorization(&self) -> Vectorization {
|
||||
self.vectorization
|
||||
}
|
||||
|
||||
fn to_f64(&self) -> f64 {
|
||||
self.val as f64
|
||||
}
|
||||
|
@ -33,6 +37,16 @@ impl PrimitiveVariable for UInt {
|
|||
fn from_i64(val: i64) -> Self {
|
||||
Self::new(val as <Self as PrimitiveVariable>::Primitive)
|
||||
}
|
||||
|
||||
fn from_i64_vec(vec: &[i64]) -> Self {
|
||||
Self {
|
||||
// We take only one value, because type implements copy and we can't copy an unknown sized vec
|
||||
// For debugging prefer unvectorized types
|
||||
val: *vec.first().expect("Should be at least one value")
|
||||
as <Self as PrimitiveVariable>::Primitive,
|
||||
vectorization: vec.len() as u8,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Numeric for UInt {}
|
||||
|
@ -48,7 +62,7 @@ impl UInt {
|
|||
_context: &mut CubeContext,
|
||||
val: <Self as PrimitiveVariable>::Primitive,
|
||||
) -> <Self as CubeType>::ExpandType {
|
||||
let new_var = Variable::ConstantScalar(val as f64, Self::into_elem());
|
||||
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
|
||||
ExpandElement::new(Rc::new(new_var))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,12 +26,13 @@ pub mod index_assign {
|
|||
array: ExpandElement,
|
||||
index: ExpandElement,
|
||||
value: ExpandElement,
|
||||
) {
|
||||
) -> ExpandElement {
|
||||
context.register(Operator::IndexAssign(BinaryOperator {
|
||||
lhs: *index,
|
||||
rhs: *value,
|
||||
out: *array,
|
||||
}))
|
||||
}));
|
||||
array
|
||||
}
|
||||
|
||||
impl<E: CubeType, I: Into<UInt>> core::ops::IndexMut<I> for Array<E> {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use crate::dialect::{BinaryOperator, Elem, Item, Operator, Variable};
|
||||
use crate::dialect::{BinaryOperator, Elem, Item, Operator, Variable, Vectorization};
|
||||
use crate::language::{CubeContext, ExpandElement};
|
||||
|
||||
pub(crate) fn binary_expand<F>(
|
||||
|
@ -14,6 +14,8 @@ where
|
|||
let rhs: Variable = *rhs;
|
||||
|
||||
let item = lhs.item();
|
||||
check_vectorization(item.vectorization, rhs.item().vectorization);
|
||||
|
||||
let out = context.create_local(item);
|
||||
let out_var = *out;
|
||||
|
||||
|
@ -39,13 +41,15 @@ where
|
|||
{
|
||||
let lhs: Variable = *lhs;
|
||||
let rhs: Variable = *rhs;
|
||||
let item = lhs.item();
|
||||
|
||||
let out_item = match lhs.item() {
|
||||
Item::Vec4(_) => Item::Vec4(Elem::Bool),
|
||||
Item::Vec3(_) => Item::Vec3(Elem::Bool),
|
||||
Item::Vec2(_) => Item::Vec2(Elem::Bool),
|
||||
Item::Scalar(_) => Item::Scalar(Elem::Bool),
|
||||
check_vectorization(item.vectorization, rhs.item().vectorization);
|
||||
|
||||
let out_item = Item {
|
||||
elem: Elem::Bool,
|
||||
vectorization: item.vectorization,
|
||||
};
|
||||
|
||||
let out = context.create_local(out_item);
|
||||
let out_var = *out;
|
||||
|
||||
|
@ -82,3 +86,13 @@ where
|
|||
|
||||
lhs
|
||||
}
|
||||
|
||||
fn check_vectorization(lhs: Vectorization, rhs: Vectorization) {
|
||||
if lhs == 1 || rhs == 1 {
|
||||
return;
|
||||
}
|
||||
assert!(
|
||||
lhs == rhs,
|
||||
"Tried to perform binary operation on different vectorization schemes."
|
||||
);
|
||||
}
|
||||
|
|
|
@ -151,109 +151,109 @@ mod tests {
|
|||
cast_test!(
|
||||
cube_float_to_float_test,
|
||||
float_to_float_expand,
|
||||
Item::Scalar(F32::into_elem())
|
||||
Item::new(F32::as_elem())
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_float_to_int_test,
|
||||
float_to_int_expand,
|
||||
Item::Scalar(F32::into_elem()),
|
||||
Item::Scalar(I32::into_elem())
|
||||
Item::new(F32::as_elem()),
|
||||
Item::new(I32::as_elem())
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_float_to_uint_test,
|
||||
float_to_uint_expand,
|
||||
Item::Scalar(F32::into_elem()),
|
||||
Item::Scalar(Elem::UInt)
|
||||
Item::new(F32::as_elem()),
|
||||
Item::new(Elem::UInt)
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_float_to_bool_test,
|
||||
float_to_bool_expand,
|
||||
Item::Scalar(F32::into_elem()),
|
||||
Item::Scalar(Elem::Bool)
|
||||
Item::new(F32::as_elem()),
|
||||
Item::new(Elem::Bool)
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_int_to_float_test,
|
||||
int_to_float_expand,
|
||||
Item::Scalar(I32::into_elem()),
|
||||
Item::Scalar(F32::into_elem())
|
||||
Item::new(I32::as_elem()),
|
||||
Item::new(F32::as_elem())
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_int_to_int_test,
|
||||
int_to_int_expand,
|
||||
Item::Scalar(I32::into_elem())
|
||||
Item::new(I32::as_elem())
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_int_to_uint_test,
|
||||
int_to_uint_expand,
|
||||
Item::Scalar(I32::into_elem()),
|
||||
Item::Scalar(Elem::UInt)
|
||||
Item::new(I32::as_elem()),
|
||||
Item::new(Elem::UInt)
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_int_to_bool_test,
|
||||
int_to_bool_expand,
|
||||
Item::Scalar(I32::into_elem()),
|
||||
Item::Scalar(Elem::Bool)
|
||||
Item::new(I32::as_elem()),
|
||||
Item::new(Elem::Bool)
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_uint_to_float_test,
|
||||
uint_to_float_expand,
|
||||
Item::Scalar(Elem::UInt),
|
||||
Item::Scalar(F32::into_elem())
|
||||
Item::new(Elem::UInt),
|
||||
Item::new(F32::as_elem())
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_uint_to_int_test,
|
||||
uint_to_int_expand,
|
||||
Item::Scalar(Elem::UInt),
|
||||
Item::Scalar(I32::into_elem())
|
||||
Item::new(Elem::UInt),
|
||||
Item::new(I32::as_elem())
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_uint_to_uint_test,
|
||||
uint_to_uint_expand,
|
||||
Item::Scalar(Elem::UInt)
|
||||
Item::new(Elem::UInt)
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_uint_to_bool_test,
|
||||
uint_to_bool_expand,
|
||||
Item::Scalar(Elem::UInt),
|
||||
Item::Scalar(Elem::Bool)
|
||||
Item::new(Elem::UInt),
|
||||
Item::new(Elem::Bool)
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_bool_to_float_test,
|
||||
bool_to_float_expand,
|
||||
Item::Scalar(Elem::Bool),
|
||||
Item::Scalar(F32::into_elem())
|
||||
Item::new(Elem::Bool),
|
||||
Item::new(F32::as_elem())
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_bool_to_int_test,
|
||||
bool_to_int_expand,
|
||||
Item::Scalar(Elem::Bool),
|
||||
Item::Scalar(I32::into_elem())
|
||||
Item::new(Elem::Bool),
|
||||
Item::new(I32::as_elem())
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_bool_to_uint_test,
|
||||
bool_to_uint_expand,
|
||||
Item::Scalar(Elem::Bool),
|
||||
Item::Scalar(Elem::UInt)
|
||||
Item::new(Elem::Bool),
|
||||
Item::new(Elem::UInt)
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_bool_to_bool_test,
|
||||
bool_to_bool_expand,
|
||||
Item::Scalar(Elem::Bool)
|
||||
Item::new(Elem::Bool)
|
||||
);
|
||||
|
||||
fn inline_macro_ref_cast(from_item: Item, to_item: Item) -> String {
|
||||
|
|
|
@ -35,7 +35,7 @@ mod tests {
|
|||
#[test]
|
||||
fn cube_cast_float_kind_test() {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::Scalar(F64::into_elem());
|
||||
let item = Item::new(F64::as_elem());
|
||||
|
||||
let input = context.create_local(item);
|
||||
|
||||
|
@ -48,7 +48,7 @@ mod tests {
|
|||
#[test]
|
||||
fn cube_cast_int_kind_test() {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::Scalar(I32::into_elem());
|
||||
let item = Item::new(I32::as_elem());
|
||||
|
||||
let input = context.create_local(item);
|
||||
|
||||
|
@ -61,7 +61,7 @@ mod tests {
|
|||
#[test]
|
||||
fn cube_cast_numeric_kind_test() {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::Scalar(I32::into_elem());
|
||||
let item = Item::new(I32::as_elem());
|
||||
|
||||
let input = context.create_local(item);
|
||||
|
||||
|
@ -74,7 +74,7 @@ mod tests {
|
|||
#[test]
|
||||
fn cube_cast_kind_numeric_test() {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::Scalar(I32::into_elem());
|
||||
let item = Item::new(I32::as_elem());
|
||||
|
||||
let input = context.create_local(item);
|
||||
|
||||
|
@ -86,8 +86,8 @@ mod tests {
|
|||
|
||||
fn inline_macro_ref_float() -> String {
|
||||
let mut context = CubeContext::root();
|
||||
let float_64 = Item::Scalar(F64::into_elem());
|
||||
let float_32 = Item::Scalar(F32::into_elem());
|
||||
let float_64 = Item::new(F64::as_elem());
|
||||
let float_32 = Item::new(F32::as_elem());
|
||||
let input = context.create_local(float_64);
|
||||
|
||||
let mut scope = context.into_scope();
|
||||
|
@ -104,8 +104,8 @@ mod tests {
|
|||
|
||||
fn inline_macro_ref_int() -> String {
|
||||
let mut context = CubeContext::root();
|
||||
let int_32 = Item::Scalar(I32::into_elem());
|
||||
let int_64 = Item::Scalar(I64::into_elem());
|
||||
let int_32 = Item::new(I32::as_elem());
|
||||
let int_64 = Item::new(I64::as_elem());
|
||||
let input = context.create_local(int_32);
|
||||
|
||||
let mut scope = context.into_scope();
|
||||
|
|
|
@ -25,8 +25,8 @@ mod tests {
|
|||
let mut context = CubeContext::root();
|
||||
let unroll = true;
|
||||
|
||||
let lhs = context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let rhs = context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
let rhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
let end = 4u32.into();
|
||||
|
||||
for_loop_expand::<ElemType>(&mut context, lhs, rhs, end, unroll);
|
||||
|
@ -40,8 +40,8 @@ mod tests {
|
|||
let mut context = CubeContext::root();
|
||||
let unroll = false;
|
||||
|
||||
let lhs = context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let rhs = context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
let rhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
let end = 4u32.into();
|
||||
|
||||
for_loop_expand::<ElemType>(&mut context, lhs, rhs, end, unroll);
|
||||
|
@ -52,7 +52,7 @@ mod tests {
|
|||
|
||||
fn inline_macro_ref(unroll: bool) -> String {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::Scalar(ElemType::into_elem());
|
||||
let item = Item::new(ElemType::as_elem());
|
||||
|
||||
let lhs = context.create_local(item);
|
||||
let rhs = context.create_local(item);
|
||||
|
|
|
@ -55,12 +55,12 @@ mod tests {
|
|||
#[test]
|
||||
fn cube_call_equivalent_to_no_call_no_arg_test() {
|
||||
let mut caller_context = CubeContext::root();
|
||||
let x = caller_context.create_local(Item::Scalar(Elem::UInt));
|
||||
let x = caller_context.create_local(Item::new(Elem::UInt));
|
||||
caller_no_arg_expand(&mut caller_context, x);
|
||||
let caller_scope = caller_context.into_scope();
|
||||
|
||||
let mut no_call_context = CubeContext::root();
|
||||
let x = no_call_context.create_local(Item::Scalar(Elem::UInt));
|
||||
let x = no_call_context.create_local(Item::new(Elem::UInt));
|
||||
no_call_no_arg_expand(&mut no_call_context, x);
|
||||
let no_call_scope = no_call_context.into_scope();
|
||||
|
||||
|
@ -74,12 +74,12 @@ mod tests {
|
|||
fn cube_call_equivalent_to_no_call_with_arg_test() {
|
||||
let mut caller_context = CubeContext::root();
|
||||
|
||||
let x = caller_context.create_local(Item::Scalar(Elem::UInt));
|
||||
let x = caller_context.create_local(Item::new(Elem::UInt));
|
||||
caller_with_arg_expand(&mut caller_context, x);
|
||||
let caller_scope = caller_context.into_scope();
|
||||
|
||||
let mut no_call_context = CubeContext::root();
|
||||
let x = no_call_context.create_local(Item::Scalar(Elem::UInt));
|
||||
let x = no_call_context.create_local(Item::new(Elem::UInt));
|
||||
no_call_with_arg_expand(&mut no_call_context, x);
|
||||
let no_call_scope = no_call_context.into_scope();
|
||||
|
||||
|
@ -93,12 +93,12 @@ mod tests {
|
|||
fn cube_call_equivalent_to_no_call_with_generics_test() {
|
||||
let mut caller_context = CubeContext::root();
|
||||
type ElemType = I64;
|
||||
let x = caller_context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let x = caller_context.create_local(Item::new(ElemType::as_elem()));
|
||||
caller_with_generics_expand::<ElemType>(&mut caller_context, x);
|
||||
let caller_scope = caller_context.into_scope();
|
||||
|
||||
let mut no_call_context = CubeContext::root();
|
||||
let x = no_call_context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let x = no_call_context.create_local(Item::new(ElemType::as_elem()));
|
||||
no_call_with_generics_expand::<ElemType>(&mut no_call_context, x);
|
||||
let no_call_scope = no_call_context.into_scope();
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ mod tests {
|
|||
fn cube_generic_float_test() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
let lhs = context.create_local(Item::Scalar(F32::into_elem()));
|
||||
let lhs = context.create_local(Item::new(F32::as_elem()));
|
||||
|
||||
generic_kernel_expand::<F32>(&mut context, lhs);
|
||||
let scope = context.into_scope();
|
||||
|
@ -26,7 +26,7 @@ mod tests {
|
|||
fn cube_generic_int_test() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
let lhs = context.create_local(Item::Scalar(I32::into_elem()));
|
||||
let lhs = context.create_local(Item::new(I32::as_elem()));
|
||||
|
||||
generic_kernel_expand::<I32>(&mut context, lhs);
|
||||
let scope = context.into_scope();
|
||||
|
@ -36,7 +36,7 @@ mod tests {
|
|||
|
||||
fn inline_macro_ref_float() -> String {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::Scalar(F32::into_elem());
|
||||
let item = Item::new(F32::as_elem());
|
||||
let lhs = context.create_local(item);
|
||||
|
||||
let mut scope = context.into_scope();
|
||||
|
@ -48,7 +48,7 @@ mod tests {
|
|||
|
||||
fn inline_macro_ref_int() -> String {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::Scalar(I32::into_elem());
|
||||
let item = Item::new(I32::as_elem());
|
||||
let lhs = context.create_local(item);
|
||||
|
||||
let mut scope = context.into_scope();
|
||||
|
|
|
@ -22,7 +22,7 @@ mod tests {
|
|||
fn cube_if_test() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
let lhs = context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
if_greater_expand::<ElemType>(&mut context, lhs);
|
||||
let scope = context.into_scope();
|
||||
|
@ -32,11 +32,11 @@ mod tests {
|
|||
|
||||
fn inline_macro_ref() -> String {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::Scalar(ElemType::into_elem());
|
||||
let item = Item::new(ElemType::as_elem());
|
||||
let lhs = context.create_local(item);
|
||||
|
||||
let mut scope = context.into_scope();
|
||||
let cond = scope.create_local(Item::Scalar(Elem::Bool));
|
||||
let cond = scope.create_local(Item::new(Elem::Bool));
|
||||
let lhs: Variable = lhs.into();
|
||||
let y = scope.create_local(item);
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ mod tests {
|
|||
fn cube_if_else_test() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
let lhs = context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
if_then_else_expand::<ElemType>(&mut context, lhs);
|
||||
let scope = context.into_scope();
|
||||
|
@ -34,11 +34,11 @@ mod tests {
|
|||
|
||||
fn inline_macro_ref() -> String {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::Scalar(ElemType::into_elem());
|
||||
let item = Item::new(ElemType::as_elem());
|
||||
let lhs = context.create_local(item);
|
||||
|
||||
let mut scope = context.into_scope();
|
||||
let cond = scope.create_local(Item::Scalar(Elem::Bool));
|
||||
let cond = scope.create_local(Item::new(Elem::Bool));
|
||||
let lhs: Variable = lhs.into();
|
||||
let y = scope.create_local(item);
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ mod tests {
|
|||
fn cube_literal_test() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
let lhs = context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
literal_expand::<ElemType>(&mut context, lhs);
|
||||
let scope = context.into_scope();
|
||||
|
@ -32,7 +32,7 @@ mod tests {
|
|||
fn cube_literal_float_no_decimal_test() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
let lhs = context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
literal_float_no_decimals_expand::<ElemType>(&mut context, lhs);
|
||||
let scope = context.into_scope();
|
||||
|
@ -42,7 +42,7 @@ mod tests {
|
|||
|
||||
fn inline_macro_ref() -> String {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::Scalar(ElemType::into_elem());
|
||||
let item = Item::new(ElemType::as_elem());
|
||||
let lhs = context.create_local(item);
|
||||
|
||||
let mut scope = context.into_scope();
|
||||
|
|
|
@ -32,7 +32,7 @@ mod tests {
|
|||
fn cube_while_test() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
let lhs = context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
while_not_expand::<ElemType>(&mut context, lhs);
|
||||
let scope = context.into_scope();
|
||||
|
@ -44,7 +44,7 @@ mod tests {
|
|||
fn cube_loop_break_test() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
let lhs = context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
manual_loop_break_expand::<ElemType>(&mut context, lhs);
|
||||
let scope = context.into_scope();
|
||||
|
@ -54,11 +54,11 @@ mod tests {
|
|||
|
||||
fn inline_macro_ref() -> String {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::Scalar(ElemType::into_elem());
|
||||
let item = Item::new(ElemType::as_elem());
|
||||
let lhs = context.create_local(item);
|
||||
|
||||
let mut scope = context.into_scope();
|
||||
let cond = scope.create_local(Item::Scalar(Elem::Bool));
|
||||
let cond = scope.create_local(Item::new(Elem::Bool));
|
||||
let lhs: Variable = lhs.into();
|
||||
let rhs = scope.create_local(item);
|
||||
|
||||
|
|
|
@ -11,3 +11,4 @@ mod module_import;
|
|||
mod parenthesis;
|
||||
mod reuse;
|
||||
mod r#trait;
|
||||
mod vectorization;
|
||||
|
|
|
@ -33,12 +33,12 @@ mod tests {
|
|||
#[test]
|
||||
fn cube_call_equivalent_to_no_call_no_arg_test() {
|
||||
let mut caller_context = CubeContext::root();
|
||||
let x = caller_context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let x = caller_context.create_local(Item::new(ElemType::as_elem()));
|
||||
here::caller_expand::<ElemType>(&mut caller_context, x);
|
||||
let caller_scope = caller_context.into_scope();
|
||||
|
||||
let mut no_call_context = CubeContext::root();
|
||||
let x = no_call_context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let x = no_call_context.create_local(Item::new(ElemType::as_elem()));
|
||||
here::no_call_ref_expand::<ElemType>(&mut no_call_context, x);
|
||||
let no_call_scope = no_call_context.into_scope();
|
||||
|
||||
|
|
|
@ -20,9 +20,9 @@ mod tests {
|
|||
fn cube_parenthesis_priority_test() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
let x = context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let y = context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let z = context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let x = context.create_local(Item::new(ElemType::as_elem()));
|
||||
let y = context.create_local(Item::new(ElemType::as_elem()));
|
||||
let z = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
parenthesis_expand::<ElemType>(&mut context, x, y, z);
|
||||
let scope = context.into_scope();
|
||||
|
@ -32,7 +32,7 @@ mod tests {
|
|||
|
||||
fn inline_macro_ref() -> String {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::Scalar(ElemType::into_elem());
|
||||
let item = Item::new(ElemType::as_elem());
|
||||
let x = context.create_local(item);
|
||||
let y = context.create_local(item);
|
||||
let z = context.create_local(item);
|
||||
|
|
|
@ -32,7 +32,7 @@ mod tests {
|
|||
fn cube_reuse_assign_test() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
let x = context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let x = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
reuse_expand::<ElemType>(&mut context, x);
|
||||
let scope = context.into_scope();
|
||||
|
@ -44,7 +44,7 @@ mod tests {
|
|||
fn cube_reuse_incr_test() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
let x = context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let x = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
reuse_incr_expand::<ElemType>(&mut context, x);
|
||||
let scope = context.into_scope();
|
||||
|
@ -54,11 +54,11 @@ mod tests {
|
|||
|
||||
fn inline_macro_ref_assign() -> String {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::Scalar(ElemType::into_elem());
|
||||
let item = Item::new(ElemType::as_elem());
|
||||
let x = context.create_local(item);
|
||||
|
||||
let mut scope = context.into_scope();
|
||||
let cond = scope.create_local(Item::Scalar(Elem::Bool));
|
||||
let cond = scope.create_local(Item::new(Elem::Bool));
|
||||
let x: Variable = x.into();
|
||||
let tmp = scope.create_local(item);
|
||||
|
||||
|
@ -80,11 +80,11 @@ mod tests {
|
|||
|
||||
fn inline_macro_ref_incr() -> String {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::Scalar(ElemType::into_elem());
|
||||
let item = Item::new(ElemType::as_elem());
|
||||
let x = context.create_local(item);
|
||||
|
||||
let mut scope = context.into_scope();
|
||||
let cond = scope.create_local(Item::Scalar(Elem::Bool));
|
||||
let cond = scope.create_local(Item::new(Elem::Bool));
|
||||
let x: Variable = x.into();
|
||||
|
||||
cpa!(
|
||||
|
|
|
@ -114,8 +114,8 @@ mod tests {
|
|||
fn cube_strategy_trait_add_test() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
let x = context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let y = context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let x = context.create_local(Item::new(ElemType::as_elem()));
|
||||
let y = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
with_strategy_trait_expand::<AddStrategy, ElemType>(&mut context, x, y);
|
||||
let scope = context.into_scope();
|
||||
|
@ -130,8 +130,8 @@ mod tests {
|
|||
fn cube_strategy_trait_sub_test() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
let x = context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let y = context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let x = context.create_local(Item::new(ElemType::as_elem()));
|
||||
let y = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
with_strategy_trait_expand::<SubStrategy, ElemType>(&mut context, x, y);
|
||||
let scope = context.into_scope();
|
||||
|
@ -146,8 +146,8 @@ mod tests {
|
|||
fn cube_two_strategy_traits_test() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
let x = context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let y = context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let x = context.create_local(Item::new(ElemType::as_elem()));
|
||||
let y = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
two_strategy_traits_expand::<SubStrategy, AddStrategy, ElemType>(&mut context, x, y);
|
||||
let scope = context.into_scope();
|
||||
|
@ -159,8 +159,8 @@ mod tests {
|
|||
fn cube_trait_generic_method_test() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
let x = context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let y = context.create_local(Item::Scalar(ElemType::into_elem()));
|
||||
let x = context.create_local(Item::new(ElemType::as_elem()));
|
||||
let y = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
with_trait_generic_method_expand::<AddStrategy, ElemType>(&mut context, x, y);
|
||||
let scope = context.into_scope();
|
||||
|
@ -173,7 +173,7 @@ mod tests {
|
|||
|
||||
fn inline_macro_ref_one(is_add_strategy: bool) -> String {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::Scalar(ElemType::into_elem());
|
||||
let item = Item::new(ElemType::as_elem());
|
||||
let x = context.create_local(item);
|
||||
let y = context.create_local(item);
|
||||
|
||||
|
@ -192,7 +192,7 @@ mod tests {
|
|||
|
||||
fn inline_macro_ref_two() -> String {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::Scalar(ElemType::into_elem());
|
||||
let item = Item::new(ElemType::as_elem());
|
||||
let x = context.create_local(item);
|
||||
let y = context.create_local(item);
|
||||
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
use burn_cube::{cube, Numeric};
|
||||
|
||||
#[cube]
|
||||
pub fn vectorization_binary<T: Numeric>(lhs: T) {
|
||||
let _ = lhs + T::from_vec(&[4, 5]);
|
||||
}
|
||||
|
||||
#[cube]
|
||||
pub fn vectorization_cmp<T: Numeric>(rhs: T) {
|
||||
let _ = T::from_vec(&[4, 5]) > rhs;
|
||||
}
|
||||
|
||||
mod tests {
|
||||
|
||||
use burn_cube::{dialect::Item, CubeContext, PrimitiveVariable, F32};
|
||||
|
||||
use crate::language::vectorization::{vectorization_binary_expand, vectorization_cmp_expand};
|
||||
|
||||
type ElemType = F32;
|
||||
|
||||
#[test]
|
||||
fn cube_vectorization_binary_op_with_same_scheme_does_not_fail() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 2));
|
||||
|
||||
vectorization_binary_expand::<ElemType>(&mut context, lhs);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn cube_vectorization_binary_op_with_different_scheme_fails() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 4));
|
||||
|
||||
vectorization_binary_expand::<ElemType>(&mut context, lhs);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cube_vectorization_cmp_op_with_same_scheme_does_not_fail() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 2));
|
||||
|
||||
vectorization_cmp_expand::<ElemType>(&mut context, lhs);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn cube_vectorization_cmp_op_with_different_scheme_fails() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 4));
|
||||
|
||||
vectorization_cmp_expand::<ElemType>(&mut context, lhs);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cube_vectorization_can_be_broadcasted() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 1));
|
||||
|
||||
vectorization_cmp_expand::<ElemType>(&mut context, lhs);
|
||||
}
|
||||
}
|
|
@ -1,7 +1,6 @@
|
|||
use burn_cube::{
|
||||
calculate_num_elems_dyn_rank,
|
||||
dialect::{Vectorization, WorkgroupSize},
|
||||
elemwise_workgroup, CompilationInfo, CompilationSettings,
|
||||
calculate_num_elems_dyn_rank, dialect::WorkgroupSize, elemwise_workgroup, CompilationInfo,
|
||||
CompilationSettings,
|
||||
};
|
||||
use burn_tensor::repr::TensorDescription;
|
||||
|
||||
|
@ -56,12 +55,10 @@ impl<R: JitRuntime> FusionKernelFactory<R> for ElementWiseKernelFactory<R> {
|
|||
);
|
||||
|
||||
if vectorize_4 {
|
||||
settings = settings.vectorize(Vectorization::Vec4);
|
||||
settings = settings.vectorize(4);
|
||||
factor = 4;
|
||||
}
|
||||
|
||||
if !vectorize_4 && vectorize_2 {
|
||||
settings = settings.vectorize(Vectorization::Vec2);
|
||||
} else if vectorize_2 {
|
||||
settings = settings.vectorize(2);
|
||||
factor = 2;
|
||||
}
|
||||
|
||||
|
|
|
@ -47,7 +47,7 @@ impl TraceBuilder {
|
|||
false => {
|
||||
// New input
|
||||
let index = self.inputs.len() as u16;
|
||||
let item = Item::Scalar(elem);
|
||||
let item = Item::new(elem);
|
||||
|
||||
let local = self.scope.read_array(index, item, position);
|
||||
self.inputs.push((tensor.clone(), local));
|
||||
|
@ -56,7 +56,7 @@ impl TraceBuilder {
|
|||
true => match self.output_to_local.get(&tensor.id) {
|
||||
// Is a local variable.
|
||||
Some(local_index) => {
|
||||
Variable::Local(*local_index, Item::Scalar(elem), self.scope.depth)
|
||||
Variable::Local(*local_index, Item::new(elem), self.scope.depth)
|
||||
}
|
||||
// Isn't an operation output variable, so must be an existing input.
|
||||
None => self
|
||||
|
@ -84,10 +84,10 @@ impl TraceBuilder {
|
|||
|
||||
// Output already registered as a local variable.
|
||||
if let Some(index) = self.output_to_local.get(&tensor.id) {
|
||||
return Variable::Local(*index, Item::Scalar(elem), self.scope.depth);
|
||||
return Variable::Local(*index, Item::new(elem), self.scope.depth);
|
||||
}
|
||||
|
||||
let variable = self.scope.create_local(Item::Scalar(elem));
|
||||
let variable = self.scope.create_local(Item::new(elem));
|
||||
let local_index = variable.index().unwrap();
|
||||
self.output_to_local.insert(tensor.id, local_index);
|
||||
variable
|
||||
|
|
|
@ -45,7 +45,7 @@ impl Trace {
|
|||
.inputs
|
||||
.iter()
|
||||
.map(|(_tensor, elem, _)| InputInfo::Array {
|
||||
item: Item::Scalar(*elem),
|
||||
item: Item::new(*elem),
|
||||
visibility: Visibility::Read,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
@ -56,7 +56,7 @@ impl Trace {
|
|||
.zip(self.locals.iter())
|
||||
.map(
|
||||
|((_tensor, elem, index_ref), local)| OutputInfo::ArrayWrite {
|
||||
item: Item::Scalar(*elem),
|
||||
item: Item::new(*elem),
|
||||
local: *local,
|
||||
position: *index_ref,
|
||||
},
|
||||
|
|
|
@ -65,15 +65,15 @@ macro_rules! binary {
|
|||
let local = scope.last_local_index().unwrap().index().unwrap();
|
||||
|
||||
let lhs = burn_cube::InputInfo::Array {
|
||||
item: burn_cube::dialect::Item::Scalar(I::cube_elem()),
|
||||
item: burn_cube::dialect::Item::new(I::cube_elem()),
|
||||
visibility: burn_cube::dialect::Visibility::Read,
|
||||
};
|
||||
let rhs = burn_cube::InputInfo::Array {
|
||||
item: burn_cube::dialect::Item::Scalar(I::cube_elem()),
|
||||
item: burn_cube::dialect::Item::new(I::cube_elem()),
|
||||
visibility: burn_cube::dialect::Visibility::Read,
|
||||
};
|
||||
let out = burn_cube::OutputInfo::ArrayWrite {
|
||||
item: burn_cube::dialect::Item::Scalar(O::cube_elem()),
|
||||
item: burn_cube::dialect::Item::new(O::cube_elem()),
|
||||
local,
|
||||
position,
|
||||
};
|
||||
|
|
|
@ -56,7 +56,7 @@ pub(crate) struct BoolCastEagerKernel<R: JitRuntime, EO: JitElement> {
|
|||
impl<R: JitRuntime, EO: JitElement> GpuComputeShaderPhase for BoolCastEagerKernel<R, EO> {
|
||||
fn compile(&self) -> ComputeShader {
|
||||
let mut scope = Scope::root();
|
||||
let item_input = Item::Scalar(Elem::Bool);
|
||||
let item_input = Item::new(Elem::Bool);
|
||||
let item_output = EO::cube_elem().into();
|
||||
|
||||
let tensor = Variable::GlobalInputArray(0, item_input);
|
||||
|
|
|
@ -125,14 +125,14 @@ pub(crate) fn gather_shader_information(
|
|||
|
||||
// Registers used in the compute pass
|
||||
let results = scope.create_local_array(elem, results_size);
|
||||
let register_m = scope.create_local(Item::Vec4(elem));
|
||||
let register_n = scope.create_local(Item::Vec4(elem));
|
||||
let register_m = scope.create_local(Item::vectorized(elem, 4));
|
||||
let register_n = scope.create_local(Item::vectorized(elem, 4));
|
||||
let shared_lhs = scope.create_shared(
|
||||
Item::Vec4(elem),
|
||||
Item::vectorized(elem, 4),
|
||||
shader.config.block_size_m as u32 * shader.config.block_size_k as u32 / 4u32,
|
||||
);
|
||||
let shared_rhs = scope.create_shared(
|
||||
Item::Vec4(elem),
|
||||
Item::vectorized(elem, 4),
|
||||
shader.config.block_size_k as u32 * shader.config.block_size_n as u32 / 4u32,
|
||||
);
|
||||
|
||||
|
|
|
@ -268,7 +268,7 @@ impl<R: JitRuntime, E: JitElement> GpuComputeShaderPhase
|
|||
let mut scope = Scope::root();
|
||||
let item = E::cube_elem().into();
|
||||
|
||||
let indices = Variable::GlobalInputArray(0, Item::Scalar(Elem::Int(IntKind::I32)));
|
||||
let indices = Variable::GlobalInputArray(0, Item::new(Elem::Int(IntKind::I32)));
|
||||
let grad = Variable::GlobalInputArray(1, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
|
||||
|
@ -283,7 +283,7 @@ impl<R: JitRuntime, E: JitElement> GpuComputeShaderPhase
|
|||
.expand(&mut scope);
|
||||
|
||||
let indices = InputInfo::Array {
|
||||
item: Item::Scalar(Elem::Int(IntKind::I32)),
|
||||
item: Item::new(Elem::Int(IntKind::I32)),
|
||||
visibility: Visibility::Read,
|
||||
};
|
||||
|
||||
|
|
|
@ -184,7 +184,7 @@ impl<P: PoolStrategy, R: JitRuntime, E: JitElement> GpuComputeShaderPhase
|
|||
let indices = if P::with_indices() {
|
||||
Some(Variable::GlobalOutputArray(
|
||||
1,
|
||||
Item::Scalar(Elem::Int(IntKind::I32)),
|
||||
Item::new(Elem::Int(IntKind::I32)),
|
||||
))
|
||||
} else {
|
||||
None
|
||||
|
@ -216,7 +216,7 @@ impl<P: PoolStrategy, R: JitRuntime, E: JitElement> GpuComputeShaderPhase
|
|||
vec![
|
||||
output,
|
||||
OutputInfo::Array {
|
||||
item: Item::Scalar(Elem::Int(IntKind::I32)),
|
||||
item: Item::new(Elem::Int(IntKind::I32)),
|
||||
},
|
||||
]
|
||||
} else {
|
||||
|
|
|
@ -70,11 +70,11 @@ macro_rules! unary {
|
|||
let local = scope.last_local_index().unwrap().index().unwrap();
|
||||
|
||||
let input = burn_cube::InputInfo::Array {
|
||||
item: burn_cube::dialect::Item::Scalar(E::cube_elem()),
|
||||
item: burn_cube::dialect::Item::new(E::cube_elem()),
|
||||
visibility: burn_cube::dialect::Visibility::Read,
|
||||
};
|
||||
let out = burn_cube::OutputInfo::ArrayWrite {
|
||||
item: burn_cube::dialect::Item::Scalar(E::cube_elem()),
|
||||
item: burn_cube::dialect::Item::new(E::cube_elem()),
|
||||
local,
|
||||
position: burn_cube::dialect::Variable::Id,
|
||||
};
|
||||
|
@ -146,7 +146,7 @@ macro_rules! unary {
|
|||
let local = scope.last_local_index().unwrap().index().unwrap();
|
||||
|
||||
let input = burn_cube::InputInfo::Array {
|
||||
item: burn_cube::dialect::Item::Scalar(E::cube_elem()),
|
||||
item: burn_cube::dialect::Item::new(E::cube_elem()),
|
||||
visibility: burn_cube::dialect::Visibility::Read,
|
||||
};
|
||||
let scalars = burn_cube::InputInfo::Scalar {
|
||||
|
@ -154,7 +154,7 @@ macro_rules! unary {
|
|||
size: $num,
|
||||
};
|
||||
let out = burn_cube::OutputInfo::ArrayWrite {
|
||||
item: burn_cube::dialect::Item::Scalar(E::cube_elem()),
|
||||
item: burn_cube::dialect::Item::new(E::cube_elem()),
|
||||
local,
|
||||
position: burn_cube::dialect::Variable::Id,
|
||||
};
|
||||
|
|
|
@ -295,7 +295,7 @@ where
|
|||
fn int_abs<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, D> {
|
||||
unary!(
|
||||
operation: |scope: &mut Scope, elem: Elem, position: Variable| Operator::Abs(UnaryOperator {
|
||||
input: scope.read_array(0, Item::Scalar(elem), position),
|
||||
input: scope.read_array(0, Item::new(elem), position),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
runtime: R,
|
||||
|
|
|
@ -89,11 +89,13 @@ impl WgslCompiler {
|
|||
}
|
||||
|
||||
fn compile_item(item: cube::Item) -> Item {
|
||||
match item {
|
||||
cube::Item::Vec4(elem) => wgsl::Item::Vec4(Self::compile_elem(elem)),
|
||||
cube::Item::Vec3(elem) => wgsl::Item::Vec3(Self::compile_elem(elem)),
|
||||
cube::Item::Vec2(elem) => wgsl::Item::Vec2(Self::compile_elem(elem)),
|
||||
cube::Item::Scalar(elem) => wgsl::Item::Scalar(Self::compile_elem(elem)),
|
||||
let elem = Self::compile_elem(item.elem);
|
||||
match item.vectorization {
|
||||
1 => wgsl::Item::Scalar(elem),
|
||||
2 => wgsl::Item::Vec2(elem),
|
||||
3 => wgsl::Item::Vec3(elem),
|
||||
4 => wgsl::Item::Vec4(elem),
|
||||
_ => panic!("Unsupported vectorizations scheme {:?}", item.vectorization),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -101,7 +103,7 @@ impl WgslCompiler {
|
|||
match value {
|
||||
cube::Elem::Float(f) => match f {
|
||||
cube::FloatKind::F16 => panic!("f16 is not yet supported"),
|
||||
cube::FloatKind::BF16 => panic!("f64 is not a valid WgpuElement"),
|
||||
cube::FloatKind::BF16 => panic!("bf16 is not a valid WgpuElement"),
|
||||
cube::FloatKind::F32 => wgsl::Elem::F32,
|
||||
cube::FloatKind::F64 => panic!("f64 is not a valid WgpuElement"),
|
||||
},
|
||||
|
|
Loading…
Reference in New Issue