Cube: support method call + prettier tensor metadata (#1829)

Louis Fortier-Dubois, 2024-05-27 15:18:17 -04:00, committed by GitHub
parent fd54a8b470 · commit e61b026918
5 changed files with 83 additions and 76 deletions
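In short: inside a #[cube] function, tensor metadata is now read with plain method-call syntax instead of associated functions that forced an explicit (and useless) type parameter. A before/after sketch, taken from the test updated below:

    // Before: associated function with a redundant type parameter.
    let _shape = Tensor::<T>::shape(input, 1u32);

    // After: method-call syntax, expanded by the #[cube] macro.
    let _shape = input.shape(1);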

File 1 of 5

@@ -97,7 +97,9 @@ pub(crate) fn codegen_expr(
         syn::Expr::Break(_) => codegen_break(),
         syn::Expr::Return(return_expr) => codegen_return(return_expr),
         syn::Expr::If(expr_if) => codegen_if(expr_if, loop_level, variable_analyses),
-        syn::Expr::MethodCall(call) => codegen_expr_method_call(call),
+        syn::Expr::MethodCall(call) => {
+            codegen_expr_method_call(call, loop_level, variable_analyses)
+        }
         syn::Expr::Index(index) => codegen_index(index, loop_level, variable_analyses),
         syn::Expr::Paren(paren) => codegen_expr(&paren.expr, loop_level, variable_analyses),
         syn::Expr::Array(array) => codegen_array_lit(array),

File 2 of 5

@@ -1,12 +1,27 @@
 use proc_macro2::TokenStream;
 use quote::quote_spanned;
-use syn::{spanned::Spanned, AngleBracketedGenericArguments, Ident, PathArguments};
+use syn::{
+    punctuated::Punctuated, spanned::Spanned, AngleBracketedGenericArguments, Expr, Ident,
+    PathArguments, Token,
+};
 
 use crate::{analysis::CodeAnalysis, codegen::base::codegen_expr};
 
-/// Codegen for method call
-pub(crate) fn codegen_expr_method_call(call: &syn::ExprMethodCall) -> TokenStream {
-    quote::quote!( #call )
+/// Supports [expr].method(args)
+pub(crate) fn codegen_expr_method_call(
+    call: &syn::ExprMethodCall,
+    loop_level: usize,
+    variable_analyses: &mut CodeAnalysis,
+) -> TokenStream {
+    let receiver = codegen_expr(&call.receiver, loop_level, variable_analyses);
+    let method_expand = syn::Ident::new(
+        format!("{}_expand", call.method).as_str(),
+        proc_macro2::Span::call_site(),
+    );
+    let args = codegen_args(&call.args, loop_level, variable_analyses);
+    quote::quote!( #receiver . #method_expand ( #args ))
 }
 
 /// Codegen for a closure
@@ -110,12 +125,7 @@ pub(crate) fn parse_function_call(
             quote::quote! {#code}
         }
         "unwrap_or_else" => {
-            let mut args = quote::quote! {};
-            args.extend(quote::quote! { context, });
-            for argument in call.args.iter() {
-                let arg = codegen_expr(argument, loop_level, variable_analyses);
-                args.extend(quote::quote! { #arg, });
-            }
+            let args = codegen_args(&call.args, loop_level, variable_analyses);
 
             // Codegen
             quote::quote! {
@@ -131,12 +141,7 @@ pub(crate) fn parse_function_call(
         (tokens, true)
     } else {
-        let mut args = quote::quote! {};
-        args.extend(quote::quote! { context, });
-        for argument in call.args.iter() {
-            let arg = codegen_expr(argument, loop_level, variable_analyses);
-            args.extend(quote::quote! { #arg, });
-        }
+        let args = codegen_args(&call.args, loop_level, variable_analyses);
 
         // Codegen
         let tokens = quote::quote! {
@@ -146,3 +151,17 @@ pub(crate) fn parse_function_call(
         (tokens, false)
     }
 }
+
+fn codegen_args(
+    args: &Punctuated<Expr, Token![,]>,
+    loop_level: usize,
+    variable_analyses: &mut CodeAnalysis,
+) -> TokenStream {
+    let mut arg_tokens = quote::quote! {};
+    arg_tokens.extend(quote::quote! { context, });
+    for argument in args.iter() {
+        let arg_token = codegen_expr(argument, loop_level, variable_analyses);
+        arg_tokens.extend(quote::quote! { #arg_token, });
+    }
+    arg_tokens
+}
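For reference, here is roughly what the new method-call codegen produces: the #[cube] macro rewrites receiver.method(args) into a call to a method_expand companion, with the expansion context prepended to the argument list. A sketch of the rewrite (approximate tokens, not exact macro output):

    // Written inside a #[cube] function:
    let s = input.stride(1);

    // What codegen_expr_method_call emits for it, roughly:
    let s = input.stride_expand(context, 1);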

File 3 of 5

@@ -47,56 +47,49 @@ impl<'a, R: Runtime> ArgSettings<R> for TensorHandle<'a, R> {
 impl<T: CubeType> Tensor<T> {
     /// Obtain the stride of input at dimension dim
-    pub fn stride(_input: Tensor<T>, _dim: u32) -> UInt {
+    pub fn stride(self, _dim: i32) -> UInt {
         unexpanded!()
     }
 
-    /// Obtain the stride of input at dimension dim
-    pub fn stride_expand(
-        context: &mut CubeContext,
-        input: <Tensor<T> as CubeType>::ExpandType,
-        dim: u32,
-    ) -> ExpandElement {
+    /// Obtain the shape of input at dimension dim
+    pub fn shape(self, _dim: i32) -> UInt {
+        unexpanded!()
+    }
+
+    /// Obtain the array length of input
+    pub fn len(self) -> UInt {
+        unexpanded!()
+    }
+}
+
+impl ExpandElement {
+    // Expanded version of Tensor::stride
+    pub fn stride_expand(self, context: &mut CubeContext, dim: i32) -> ExpandElement {
         let out = context.create_local(Item::new(Elem::UInt));
         context.register(Metadata::Stride {
-            dim: dim.into(),
-            var: input.into(),
+            dim: (dim as u32).into(),
+            var: self.into(),
             out: out.clone().into(),
         });
         out
     }
 
-    /// Obtain the shape of input at dimension dim
-    pub fn shape(_input: Tensor<T>, _dim: u32) -> UInt {
-        unexpanded!()
-    }
-
-    /// Obtain the shape of input at dimension dim
-    pub fn shape_expand(
-        context: &mut CubeContext,
-        input: <Tensor<T> as CubeType>::ExpandType,
-        dim: u32,
-    ) -> ExpandElement {
+    // Expanded version of Tensor::shape
+    pub fn shape_expand(self, context: &mut CubeContext, dim: i32) -> ExpandElement {
         let out = context.create_local(Item::new(Elem::UInt));
         context.register(Metadata::Shape {
-            dim: dim.into(),
-            var: input.into(),
+            dim: (dim as u32).into(),
+            var: self.into(),
             out: out.clone().into(),
         });
         out
     }
 
-    pub fn len(_input: Self) -> UInt {
-        unexpanded!()
-    }
-
-    pub fn len_expand(
-        context: &mut CubeContext,
-        input: <Tensor<T> as CubeType>::ExpandType,
-    ) -> ExpandElement {
+    // Expanded version of Tensor::len
+    pub fn len_expand(self, context: &mut CubeContext) -> ExpandElement {
         let out = context.create_local(Item::new(Elem::UInt));
         context.register(Metadata::ArrayLength {
-            var: input.into(),
+            var: self.into(),
             out: out.clone().into(),
         });
         out
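A note on the design: moving the _expand functions onto ExpandElement with a self receiver is what makes the macro's receiver.method_expand(context, args) rewrite resolve, since inside an expanded kernel every tensor variable is an ExpandElement. A sketch of the call the macro emits against this impl (variable names hypothetical):

    // input: ExpandElement, context: &mut CubeContext
    let length = input.len_expand(context);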

File 4 of 5

@@ -2,10 +2,9 @@ use burn_cube::{cube, Numeric, Tensor};
 
 #[cube]
 fn kernel<T: Numeric>(input: Tensor<T>) {
-    // TODO: not the prettiest to be forced to put T even if useless
-    let _shape = Tensor::<T>::shape(input, 1u32);
-    let _stride = Tensor::<T>::stride(input, 1u32);
-    let _length = Tensor::<T>::len(input);
+    let _shape = input.shape(1);
+    let _stride = input.stride(1);
+    let _length = input.len();
 }
 
 mod tests {
File 5 of 5

@@ -31,29 +31,23 @@ fn kernel<F: Float>(
     kernel_size_0_unroll: Comptime<Option<UInt>>,
     kernel_size_1_unroll: Comptime<Option<UInt>>,
 ) {
-    if AbsoluteIndex::get() >= Tensor::<F>::len(output) {
+    if AbsoluteIndex::get() >= output.len() {
         return;
     }
 
-    let in_channels = Tensor::<F>::shape(weight, 1u32);
+    let in_channels = weight.shape(1);
-    let kernel_size_0 =
-        Comptime::unwrap_or_else(kernel_size_0_unroll, || Tensor::<F>::shape(weight, 2u32));
+    let kernel_size_0 = Comptime::unwrap_or_else(kernel_size_0_unroll, || weight.shape(2));
     let unroll_0 = Comptime::is_some(kernel_size_0_unroll);
-    let kernel_size_1 =
-        Comptime::unwrap_or_else(kernel_size_1_unroll, || Tensor::<F>::shape(weight, 3u32));
+    let kernel_size_1 = Comptime::unwrap_or_else(kernel_size_1_unroll, || weight.shape(3));
     let unroll_1 = Comptime::is_some(kernel_size_1_unroll);
 
-    let b =
-        AbsoluteIndex::get() / Tensor::<F>::stride(output, 0u32) % Tensor::<F>::shape(output, 0u32);
-    let oc =
-        AbsoluteIndex::get() / Tensor::<F>::stride(output, 1u32) % Tensor::<F>::shape(output, 1u32);
-    let oh =
-        AbsoluteIndex::get() / Tensor::<F>::stride(output, 2u32) % Tensor::<F>::shape(output, 2u32);
-    let ow =
-        AbsoluteIndex::get() / Tensor::<F>::stride(output, 3u32) % Tensor::<F>::shape(output, 3u32);
+    let b = AbsoluteIndex::get() / output.stride(0) % output.shape(0);
+    let oc = AbsoluteIndex::get() / output.stride(1) % output.shape(1);
+    let oh = AbsoluteIndex::get() / output.stride(2) % output.shape(2);
+    let ow = AbsoluteIndex::get() / output.stride(3) % output.shape(3);
 
-    let g = (Tensor::<F>::shape(weight, 0u32) + oc) % groups;
+    let g = (weight.shape(0) + oc) % groups;
     let ic_start = in_channels * g;
     let ic_end = ic_start + in_channels;
     let mut sum = bias[oc];
@@ -61,23 +55,23 @@ fn kernel<F: Float>(
     let ih_base = oh * conv_stride_0;
     let iw_base = ow * conv_stride_1;
 
-    let weight_stride_1 = Tensor::<F>::stride(weight, 1u32);
-    let weight_stride_2 = Tensor::<F>::stride(weight, 2u32);
-    let weight_stride_3 = Tensor::<F>::stride(weight, 3u32);
+    let weight_stride_1 = weight.stride(1);
+    let weight_stride_2 = weight.stride(2);
+    let weight_stride_3 = weight.stride(3);
 
-    let input_stride_1 = Tensor::<F>::stride(input, 1u32);
-    let input_stride_2 = Tensor::<F>::stride(input, 2u32);
-    let input_stride_3 = Tensor::<F>::stride(input, 3u32);
-    let input_shape_2 = Tensor::<F>::shape(input, 2u32);
-    let input_shape_3 = Tensor::<F>::shape(input, 3u32);
+    let input_stride_1 = input.stride(1);
+    let input_stride_2 = input.stride(2);
+    let input_stride_3 = input.stride(3);
+    let input_shape_2 = input.shape(2);
+    let input_shape_3 = input.shape(3);
 
     let border_top = padding_0;
     let border_left = padding_1;
     let border_bottom = input_shape_2 + padding_0;
     let border_right = input_shape_3 + padding_1;
 
-    let index_input_0 = b * Tensor::<F>::stride(input, 0u32);
-    let index_weight_0 = oc * Tensor::<F>::stride(weight, 0u32);
+    let index_input_0 = b * input.stride(0);
+    let index_weight_0 = oc * weight.stride(0);
 
     for ic in range(ic_start, ic_end, Comptime::new(false)) {
         let index_input_1 = ic * input_stride_1;