diff --git a/crates/burn-cube-macros/src/codegen/base.rs b/crates/burn-cube-macros/src/codegen/base.rs index d33ee788c..72727b3ab 100644 --- a/crates/burn-cube-macros/src/codegen/base.rs +++ b/crates/burn-cube-macros/src/codegen/base.rs @@ -97,7 +97,9 @@ pub(crate) fn codegen_expr( syn::Expr::Break(_) => codegen_break(), syn::Expr::Return(return_expr) => codegen_return(return_expr), syn::Expr::If(expr_if) => codegen_if(expr_if, loop_level, variable_analyses), - syn::Expr::MethodCall(call) => codegen_expr_method_call(call), + syn::Expr::MethodCall(call) => { + codegen_expr_method_call(call, loop_level, variable_analyses) + } syn::Expr::Index(index) => codegen_index(index, loop_level, variable_analyses), syn::Expr::Paren(paren) => codegen_expr(&paren.expr, loop_level, variable_analyses), syn::Expr::Array(array) => codegen_array_lit(array), diff --git a/crates/burn-cube-macros/src/codegen/function.rs b/crates/burn-cube-macros/src/codegen/function.rs index 2f1e2cb9f..95da2e76e 100644 --- a/crates/burn-cube-macros/src/codegen/function.rs +++ b/crates/burn-cube-macros/src/codegen/function.rs @@ -1,12 +1,27 @@ use proc_macro2::TokenStream; use quote::quote_spanned; -use syn::{spanned::Spanned, AngleBracketedGenericArguments, Ident, PathArguments}; +use syn::{ + punctuated::Punctuated, spanned::Spanned, AngleBracketedGenericArguments, Expr, Ident, + PathArguments, Token, +}; use crate::{analysis::CodeAnalysis, codegen::base::codegen_expr}; /// Codegen for method call -pub(crate) fn codegen_expr_method_call(call: &syn::ExprMethodCall) -> TokenStream { - quote::quote!( #call ) +/// Supports [expr].method(args) +pub(crate) fn codegen_expr_method_call( + call: &syn::ExprMethodCall, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let receiver = codegen_expr(&call.receiver, loop_level, variable_analyses); + let method_expand = syn::Ident::new( + format!("{}_expand", call.method).as_str(), + proc_macro2::Span::call_site(), + ); + let args = 
codegen_args(&call.args, loop_level, variable_analyses); + + quote::quote!( #receiver . #method_expand ( #args )) } /// Codegen for a closure @@ -110,12 +125,7 @@ pub(crate) fn parse_function_call( quote::quote! {#code} } "unwrap_or_else" => { - let mut args = quote::quote! {}; - args.extend(quote::quote! { context, }); - for argument in call.args.iter() { - let arg = codegen_expr(argument, loop_level, variable_analyses); - args.extend(quote::quote! { #arg, }); - } + let args = codegen_args(&call.args, loop_level, variable_analyses); // Codegen quote::quote! { @@ -131,12 +141,7 @@ pub(crate) fn parse_function_call( (tokens, true) } else { - let mut args = quote::quote! {}; - args.extend(quote::quote! { context, }); - for argument in call.args.iter() { - let arg = codegen_expr(argument, loop_level, variable_analyses); - args.extend(quote::quote! { #arg, }); - } + let args = codegen_args(&call.args, loop_level, variable_analyses); // Codegen let tokens = quote::quote! { @@ -146,3 +151,17 @@ pub(crate) fn parse_function_call( (tokens, false) } } + +fn codegen_args( + args: &Punctuated<Expr, Token![,]>, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let mut arg_tokens = quote::quote! {}; + arg_tokens.extend(quote::quote! { context, }); + for argument in args.iter() { + let arg_token = codegen_expr(argument, loop_level, variable_analyses); + arg_tokens.extend(quote::quote! 
{ #arg_token, }); + } + arg_tokens +} diff --git a/crates/burn-cube/src/language/element/tensor.rs b/crates/burn-cube/src/language/element/tensor.rs index 2f6b24b42..76f988170 100644 --- a/crates/burn-cube/src/language/element/tensor.rs +++ b/crates/burn-cube/src/language/element/tensor.rs @@ -47,56 +47,49 @@ impl<'a, R: Runtime> ArgSettings<R> for TensorHandle<'a, R> { impl<T: CubeType> Tensor<T> { /// Obtain the stride of input at dimension dim - pub fn stride(_input: Tensor<T>, _dim: u32) -> UInt { + pub fn stride(self, _dim: i32) -> UInt { unexpanded!() } - /// Obtain the stride of input at dimension dim - pub fn stride_expand( - context: &mut CubeContext, - input: <Tensor<T> as CubeType>::ExpandType, - dim: u32, - ) -> ExpandElement { + /// Obtain the shape of input at dimension dim + pub fn shape(self, _dim: i32) -> UInt { + unexpanded!() + } + + /// Obtain the array length of input + pub fn len(self) -> UInt { + unexpanded!() + } +} + +impl ExpandElement { + // Expanded version of Tensor::stride + pub fn stride_expand(self, context: &mut CubeContext, dim: i32) -> ExpandElement { let out = context.create_local(Item::new(Elem::UInt)); context.register(Metadata::Stride { - dim: dim.into(), - var: input.into(), + dim: (dim as u32).into(), + var: self.into(), out: out.clone().into(), }); out } - /// Obtain the shape of input at dimension dim - pub fn shape(_input: Tensor<T>, _dim: u32) -> UInt { - unexpanded!() - } - - /// Obtain the shape of input at dimension dim - pub fn shape_expand( - context: &mut CubeContext, - input: <Tensor<T> as CubeType>::ExpandType, - dim: u32, - ) -> ExpandElement { + // Expanded version of Tensor::shape + pub fn shape_expand(self, context: &mut CubeContext, dim: i32) -> ExpandElement { let out = context.create_local(Item::new(Elem::UInt)); context.register(Metadata::Shape { - dim: dim.into(), - var: input.into(), + dim: (dim as u32).into(), + var: self.into(), out: out.clone().into(), }); out } - pub fn len(_input: Self) -> UInt { - unexpanded!() - } - - pub fn len_expand( - 
context: &mut CubeContext, - input: <Tensor<T> as CubeType>::ExpandType, - ) -> ExpandElement { + // Expanded version of Tensor::len + pub fn len_expand(self, context: &mut CubeContext) -> ExpandElement { let out = context.create_local(Item::new(Elem::UInt)); context.register(Metadata::ArrayLength { - var: input.into(), + var: self.into(), out: out.clone().into(), }); out diff --git a/crates/burn-cube/tests/language/tensor.rs b/crates/burn-cube/tests/language/tensor.rs index 976cdd26c..525ae7cb2 100644 --- a/crates/burn-cube/tests/language/tensor.rs +++ b/crates/burn-cube/tests/language/tensor.rs @@ -2,10 +2,9 @@ use burn_cube::{cube, Numeric, Tensor}; #[cube] fn kernel<T: Numeric>(input: Tensor<T>) { - // TODO: not the prettiest to be forced to put T even if useless - let _shape = Tensor::<T>::shape(input, 1u32); - let _stride = Tensor::<T>::stride(input, 1u32); - let _length = Tensor::<T>::len(input); + let _shape = input.shape(1); + let _stride = input.stride(1); + let _length = input.len(); } mod tests { diff --git a/crates/burn-jit/src/kernel/conv/conv2d.rs b/crates/burn-jit/src/kernel/conv/conv2d.rs index 60659fbd2..f36885dc9 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d.rs @@ -31,29 +31,23 @@ fn kernel( kernel_size_0_unroll: Comptime<Option<UInt>>, kernel_size_1_unroll: Comptime<Option<UInt>>, ) { - if AbsoluteIndex::get() >= Tensor::<F>::len(output) { + if AbsoluteIndex::get() >= output.len() { return; } - let in_channels = Tensor::<F>::shape(weight, 1u32); + let in_channels = weight.shape(1); - let kernel_size_0 = - Comptime::unwrap_or_else(kernel_size_0_unroll, || Tensor::<F>::shape(weight, 2u32)); + let kernel_size_0 = Comptime::unwrap_or_else(kernel_size_0_unroll, || weight.shape(2)); let unroll_0 = Comptime::is_some(kernel_size_0_unroll); - let kernel_size_1 = - Comptime::unwrap_or_else(kernel_size_1_unroll, || Tensor::<F>::shape(weight, 3u32)); + let kernel_size_1 = Comptime::unwrap_or_else(kernel_size_1_unroll, || weight.shape(3)); let unroll_1 = 
Comptime::is_some(kernel_size_1_unroll); - let b = - AbsoluteIndex::get() / Tensor::<F>::stride(output, 0u32) % Tensor::<F>::shape(output, 0u32); - let oc = - AbsoluteIndex::get() / Tensor::<F>::stride(output, 1u32) % Tensor::<F>::shape(output, 1u32); - let oh = - AbsoluteIndex::get() / Tensor::<F>::stride(output, 2u32) % Tensor::<F>::shape(output, 2u32); - let ow = - AbsoluteIndex::get() / Tensor::<F>::stride(output, 3u32) % Tensor::<F>::shape(output, 3u32); + let b = AbsoluteIndex::get() / output.stride(0) % output.shape(0); + let oc = AbsoluteIndex::get() / output.stride(1) % output.shape(1); + let oh = AbsoluteIndex::get() / output.stride(2) % output.shape(2); + let ow = AbsoluteIndex::get() / output.stride(3) % output.shape(3); - let g = (Tensor::<F>::shape(weight, 0u32) + oc) % groups; + let g = (weight.shape(0) + oc) % groups; let ic_start = in_channels * g; let ic_end = ic_start + in_channels; let mut sum = bias[oc]; @@ -61,23 +55,23 @@ fn kernel( let ih_base = oh * conv_stride_0; let iw_base = ow * conv_stride_1; - let weight_stride_1 = Tensor::<F>::stride(weight, 1u32); - let weight_stride_2 = Tensor::<F>::stride(weight, 2u32); - let weight_stride_3 = Tensor::<F>::stride(weight, 3u32); + let weight_stride_1 = weight.stride(1); + let weight_stride_2 = weight.stride(2); + let weight_stride_3 = weight.stride(3); - let input_stride_1 = Tensor::<F>::stride(input, 1u32); - let input_stride_2 = Tensor::<F>::stride(input, 2u32); - let input_stride_3 = Tensor::<F>::stride(input, 3u32); - let input_shape_2 = Tensor::<F>::shape(input, 2u32); - let input_shape_3 = Tensor::<F>::shape(input, 3u32); + let input_stride_1 = input.stride(1); + let input_stride_2 = input.stride(2); + let input_stride_3 = input.stride(3); + let input_shape_2 = input.shape(2); + let input_shape_3 = input.shape(3); let border_top = padding_0; let border_left = padding_1; let border_bottom = input_shape_2 + padding_0; let border_right = input_shape_3 + padding_1; - let index_input_0 = b * Tensor::<F>::stride(input, 0u32); - let index_weight_0 = oc 
* Tensor::<F>::stride(weight, 0u32); + let index_input_0 = b * input.stride(0); + let index_weight_0 = oc * weight.stride(0); for ic in range(ic_start, ic_end, Comptime::new(false)) { let index_input_1 = ic * input_stride_1;