mirror of https://github.com/tracel-ai/burn.git
Cube: support method call + prettier tensor metadata (#1829)
This commit is contained in:
parent
fd54a8b470
commit
e61b026918
|
@ -97,7 +97,9 @@ pub(crate) fn codegen_expr(
|
|||
syn::Expr::Break(_) => codegen_break(),
|
||||
syn::Expr::Return(return_expr) => codegen_return(return_expr),
|
||||
syn::Expr::If(expr_if) => codegen_if(expr_if, loop_level, variable_analyses),
|
||||
syn::Expr::MethodCall(call) => codegen_expr_method_call(call),
|
||||
syn::Expr::MethodCall(call) => {
|
||||
codegen_expr_method_call(call, loop_level, variable_analyses)
|
||||
}
|
||||
syn::Expr::Index(index) => codegen_index(index, loop_level, variable_analyses),
|
||||
syn::Expr::Paren(paren) => codegen_expr(&paren.expr, loop_level, variable_analyses),
|
||||
syn::Expr::Array(array) => codegen_array_lit(array),
|
||||
|
|
|
@ -1,12 +1,27 @@
|
|||
use proc_macro2::TokenStream;
|
||||
use quote::quote_spanned;
|
||||
use syn::{spanned::Spanned, AngleBracketedGenericArguments, Ident, PathArguments};
|
||||
use syn::{
|
||||
punctuated::Punctuated, spanned::Spanned, AngleBracketedGenericArguments, Expr, Ident,
|
||||
PathArguments, Token,
|
||||
};
|
||||
|
||||
use crate::{analysis::CodeAnalysis, codegen::base::codegen_expr};
|
||||
|
||||
/// Codegen for method call
|
||||
pub(crate) fn codegen_expr_method_call(call: &syn::ExprMethodCall) -> TokenStream {
|
||||
quote::quote!( #call )
|
||||
/// Supports [expr].method(args)
|
||||
pub(crate) fn codegen_expr_method_call(
|
||||
call: &syn::ExprMethodCall,
|
||||
loop_level: usize,
|
||||
variable_analyses: &mut CodeAnalysis,
|
||||
) -> TokenStream {
|
||||
let receiver = codegen_expr(&call.receiver, loop_level, variable_analyses);
|
||||
let method_expand = syn::Ident::new(
|
||||
format!("{}_expand", call.method).as_str(),
|
||||
proc_macro2::Span::call_site(),
|
||||
);
|
||||
let args = codegen_args(&call.args, loop_level, variable_analyses);
|
||||
|
||||
quote::quote!( #receiver . #method_expand ( #args ))
|
||||
}
|
||||
|
||||
/// Codegen for a closure
|
||||
|
@ -110,12 +125,7 @@ pub(crate) fn parse_function_call(
|
|||
quote::quote! {#code}
|
||||
}
|
||||
"unwrap_or_else" => {
|
||||
let mut args = quote::quote! {};
|
||||
args.extend(quote::quote! { context, });
|
||||
for argument in call.args.iter() {
|
||||
let arg = codegen_expr(argument, loop_level, variable_analyses);
|
||||
args.extend(quote::quote! { #arg, });
|
||||
}
|
||||
let args = codegen_args(&call.args, loop_level, variable_analyses);
|
||||
|
||||
// Codegen
|
||||
quote::quote! {
|
||||
|
@ -131,12 +141,7 @@ pub(crate) fn parse_function_call(
|
|||
|
||||
(tokens, true)
|
||||
} else {
|
||||
let mut args = quote::quote! {};
|
||||
args.extend(quote::quote! { context, });
|
||||
for argument in call.args.iter() {
|
||||
let arg = codegen_expr(argument, loop_level, variable_analyses);
|
||||
args.extend(quote::quote! { #arg, });
|
||||
}
|
||||
let args = codegen_args(&call.args, loop_level, variable_analyses);
|
||||
|
||||
// Codegen
|
||||
let tokens = quote::quote! {
|
||||
|
@ -146,3 +151,17 @@ pub(crate) fn parse_function_call(
|
|||
(tokens, false)
|
||||
}
|
||||
}
|
||||
|
||||
fn codegen_args(
|
||||
args: &Punctuated<Expr, Token![,]>,
|
||||
loop_level: usize,
|
||||
variable_analyses: &mut CodeAnalysis,
|
||||
) -> TokenStream {
|
||||
let mut arg_tokens = quote::quote! {};
|
||||
arg_tokens.extend(quote::quote! { context, });
|
||||
for argument in args.iter() {
|
||||
let arg_token = codegen_expr(argument, loop_level, variable_analyses);
|
||||
arg_tokens.extend(quote::quote! { #arg_token, });
|
||||
}
|
||||
arg_tokens
|
||||
}
|
||||
|
|
|
@ -47,56 +47,49 @@ impl<'a, R: Runtime> ArgSettings<R> for TensorHandle<'a, R> {
|
|||
|
||||
impl<T: CubeType> Tensor<T> {
|
||||
/// Obtain the stride of input at dimension dim
|
||||
pub fn stride(_input: Tensor<T>, _dim: u32) -> UInt {
|
||||
pub fn stride(self, _dim: i32) -> UInt {
|
||||
unexpanded!()
|
||||
}
|
||||
|
||||
/// Obtain the stride of input at dimension dim
|
||||
pub fn stride_expand(
|
||||
context: &mut CubeContext,
|
||||
input: <Tensor<T> as CubeType>::ExpandType,
|
||||
dim: u32,
|
||||
) -> ExpandElement {
|
||||
/// Obtain the shape of input at dimension dim
|
||||
pub fn shape(self, _dim: i32) -> UInt {
|
||||
unexpanded!()
|
||||
}
|
||||
|
||||
/// Obtain the array length of input
|
||||
pub fn len(self) -> UInt {
|
||||
unexpanded!()
|
||||
}
|
||||
}
|
||||
|
||||
impl ExpandElement {
|
||||
// Expanded version of Tensor::stride
|
||||
pub fn stride_expand(self, context: &mut CubeContext, dim: i32) -> ExpandElement {
|
||||
let out = context.create_local(Item::new(Elem::UInt));
|
||||
context.register(Metadata::Stride {
|
||||
dim: dim.into(),
|
||||
var: input.into(),
|
||||
dim: (dim as u32).into(),
|
||||
var: self.into(),
|
||||
out: out.clone().into(),
|
||||
});
|
||||
out
|
||||
}
|
||||
|
||||
/// Obtain the shape of input at dimension dim
|
||||
pub fn shape(_input: Tensor<T>, _dim: u32) -> UInt {
|
||||
unexpanded!()
|
||||
}
|
||||
|
||||
/// Obtain the shape of input at dimension dim
|
||||
pub fn shape_expand(
|
||||
context: &mut CubeContext,
|
||||
input: <Tensor<T> as CubeType>::ExpandType,
|
||||
dim: u32,
|
||||
) -> ExpandElement {
|
||||
// Expanded version of Tensor::shape
|
||||
pub fn shape_expand(self, context: &mut CubeContext, dim: i32) -> ExpandElement {
|
||||
let out = context.create_local(Item::new(Elem::UInt));
|
||||
context.register(Metadata::Shape {
|
||||
dim: dim.into(),
|
||||
var: input.into(),
|
||||
dim: (dim as u32).into(),
|
||||
var: self.into(),
|
||||
out: out.clone().into(),
|
||||
});
|
||||
out
|
||||
}
|
||||
|
||||
pub fn len(_input: Self) -> UInt {
|
||||
unexpanded!()
|
||||
}
|
||||
|
||||
pub fn len_expand(
|
||||
context: &mut CubeContext,
|
||||
input: <Tensor<T> as CubeType>::ExpandType,
|
||||
) -> ExpandElement {
|
||||
// Expanded version of Tensor::len
|
||||
pub fn len_expand(self, context: &mut CubeContext) -> ExpandElement {
|
||||
let out = context.create_local(Item::new(Elem::UInt));
|
||||
context.register(Metadata::ArrayLength {
|
||||
var: input.into(),
|
||||
var: self.into(),
|
||||
out: out.clone().into(),
|
||||
});
|
||||
out
|
||||
|
|
|
@ -2,10 +2,9 @@ use burn_cube::{cube, Numeric, Tensor};
|
|||
|
||||
#[cube]
|
||||
fn kernel<T: Numeric>(input: Tensor<T>) {
|
||||
// TODO: not the prettiest to be forced to put T even if useless
|
||||
let _shape = Tensor::<T>::shape(input, 1u32);
|
||||
let _stride = Tensor::<T>::stride(input, 1u32);
|
||||
let _length = Tensor::<T>::len(input);
|
||||
let _shape = input.shape(1);
|
||||
let _stride = input.stride(1);
|
||||
let _length = input.len();
|
||||
}
|
||||
|
||||
mod tests {
|
||||
|
|
|
@ -31,29 +31,23 @@ fn kernel<F: Float>(
|
|||
kernel_size_0_unroll: Comptime<Option<UInt>>,
|
||||
kernel_size_1_unroll: Comptime<Option<UInt>>,
|
||||
) {
|
||||
if AbsoluteIndex::get() >= Tensor::<F>::len(output) {
|
||||
if AbsoluteIndex::get() >= output.len() {
|
||||
return;
|
||||
}
|
||||
|
||||
let in_channels = Tensor::<F>::shape(weight, 1u32);
|
||||
let in_channels = weight.shape(1);
|
||||
|
||||
let kernel_size_0 =
|
||||
Comptime::unwrap_or_else(kernel_size_0_unroll, || Tensor::<F>::shape(weight, 2u32));
|
||||
let kernel_size_0 = Comptime::unwrap_or_else(kernel_size_0_unroll, || weight.shape(2));
|
||||
let unroll_0 = Comptime::is_some(kernel_size_0_unroll);
|
||||
let kernel_size_1 =
|
||||
Comptime::unwrap_or_else(kernel_size_1_unroll, || Tensor::<F>::shape(weight, 3u32));
|
||||
let kernel_size_1 = Comptime::unwrap_or_else(kernel_size_1_unroll, || weight.shape(3));
|
||||
let unroll_1 = Comptime::is_some(kernel_size_1_unroll);
|
||||
|
||||
let b =
|
||||
AbsoluteIndex::get() / Tensor::<F>::stride(output, 0u32) % Tensor::<F>::shape(output, 0u32);
|
||||
let oc =
|
||||
AbsoluteIndex::get() / Tensor::<F>::stride(output, 1u32) % Tensor::<F>::shape(output, 1u32);
|
||||
let oh =
|
||||
AbsoluteIndex::get() / Tensor::<F>::stride(output, 2u32) % Tensor::<F>::shape(output, 2u32);
|
||||
let ow =
|
||||
AbsoluteIndex::get() / Tensor::<F>::stride(output, 3u32) % Tensor::<F>::shape(output, 3u32);
|
||||
let b = AbsoluteIndex::get() / output.stride(0) % output.shape(0);
|
||||
let oc = AbsoluteIndex::get() / output.stride(1) % output.shape(1);
|
||||
let oh = AbsoluteIndex::get() / output.stride(2) % output.shape(2);
|
||||
let ow = AbsoluteIndex::get() / output.stride(3) % output.shape(3);
|
||||
|
||||
let g = (Tensor::<F>::shape(weight, 0u32) + oc) % groups;
|
||||
let g = (weight.shape(0) + oc) % groups;
|
||||
let ic_start = in_channels * g;
|
||||
let ic_end = ic_start + in_channels;
|
||||
let mut sum = bias[oc];
|
||||
|
@ -61,23 +55,23 @@ fn kernel<F: Float>(
|
|||
let ih_base = oh * conv_stride_0;
|
||||
let iw_base = ow * conv_stride_1;
|
||||
|
||||
let weight_stride_1 = Tensor::<F>::stride(weight, 1u32);
|
||||
let weight_stride_2 = Tensor::<F>::stride(weight, 2u32);
|
||||
let weight_stride_3 = Tensor::<F>::stride(weight, 3u32);
|
||||
let weight_stride_1 = weight.stride(1);
|
||||
let weight_stride_2 = weight.stride(2);
|
||||
let weight_stride_3 = weight.stride(3);
|
||||
|
||||
let input_stride_1 = Tensor::<F>::stride(input, 1u32);
|
||||
let input_stride_2 = Tensor::<F>::stride(input, 2u32);
|
||||
let input_stride_3 = Tensor::<F>::stride(input, 3u32);
|
||||
let input_shape_2 = Tensor::<F>::shape(input, 2u32);
|
||||
let input_shape_3 = Tensor::<F>::shape(input, 3u32);
|
||||
let input_stride_1 = input.stride(1);
|
||||
let input_stride_2 = input.stride(2);
|
||||
let input_stride_3 = input.stride(3);
|
||||
let input_shape_2 = input.shape(2);
|
||||
let input_shape_3 = input.shape(3);
|
||||
|
||||
let border_top = padding_0;
|
||||
let border_left = padding_1;
|
||||
let border_bottom = input_shape_2 + padding_0;
|
||||
let border_right = input_shape_3 + padding_1;
|
||||
|
||||
let index_input_0 = b * Tensor::<F>::stride(input, 0u32);
|
||||
let index_weight_0 = oc * Tensor::<F>::stride(weight, 0u32);
|
||||
let index_input_0 = b * input.stride(0);
|
||||
let index_weight_0 = oc * weight.stride(0);
|
||||
|
||||
for ic in range(ic_start, ic_end, Comptime::new(false)) {
|
||||
let index_input_1 = ic * input_stride_1;
|
||||
|
|
Loading…
Reference in New Issue