diff --git a/crates/burn-import/src/burn/ty.rs b/crates/burn-import/src/burn/ty.rs
index 191333d1a..2612b38e2 100644
--- a/crates/burn-import/src/burn/ty.rs
+++ b/crates/burn-import/src/burn/ty.rs
@@ -143,6 +143,10 @@ impl TensorType {
             );
         }
         let formatted_name = Self::format_name(name.as_ref());
+        assert_ne!(
+            dim, 0,
+            "Trying to create TensorType with dim = 0 - should be a Scalar instead!"
+        );
         Self {
             name: Ident::new(&formatted_name, Span::call_site()),
             dim,
@@ -151,15 +155,39 @@ impl TensorType {
         }
     }
 
     pub fn new_float<S: AsRef<str>>(name: S, dim: usize) -> Self {
-        Self::new(name, dim, TensorKind::Float, None)
+        Self::new_float_with_shape(name, dim, None)
+    }
+
+    pub fn new_float_with_shape<S: AsRef<str>>(
+        name: S,
+        dim: usize,
+        shape: Option<Vec<usize>>,
+    ) -> Self {
+        Self::new(name, dim, TensorKind::Float, shape)
     }
 
     pub fn new_int<S: AsRef<str>>(name: S, dim: usize) -> Self {
-        Self::new(name, dim, TensorKind::Int, None)
+        Self::new_int_with_shape(name, dim, None)
+    }
+
+    pub fn new_int_with_shape<S: AsRef<str>>(
+        name: S,
+        dim: usize,
+        shape: Option<Vec<usize>>,
+    ) -> Self {
+        Self::new(name, dim, TensorKind::Int, shape)
     }
 
     pub fn new_bool<S: AsRef<str>>(name: S, dim: usize) -> Self {
-        Self::new(name, dim, TensorKind::Bool, None)
+        Self::new_bool_with_shape(name, dim, None)
+    }
+
+    pub fn new_bool_with_shape<S: AsRef<str>>(
+        name: S,
+        dim: usize,
+        shape: Option<Vec<usize>>,
+    ) -> Self {
+        Self::new(name, dim, TensorKind::Bool, shape)
     }
 
     pub fn ty(&self) -> TokenStream {
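Note (not part of the patch): a minimal sketch of how the new shape-aware constructors are meant to be called, assuming the TensorType from crates/burn-import/src/burn/ty.rs above is in scope. The tensor name and shape values are invented.

// Sketch only: "conv1_out" and the shape [1, 16, 128] are made-up example values.
fn tensor_type_examples() {
    // Rank-3 float tensor with no static shape - behaves exactly like the old new_float.
    let _plain = TensorType::new_float("conv1_out", 3);

    // Same rank, but the static shape recovered from the ONNX graph is preserved.
    let _with_shape = TensorType::new_float_with_shape("conv1_out", 3, Some(vec![1, 16, 128]));

    // Rank 0 now trips the new assert_ne! - rank-0 values should become a Scalar type instead.
    // let _invalid = TensorType::new_float("x", 0);
}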
diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs
index 9635e4155..a77211476 100644
--- a/crates/burn-import/src/onnx/to_burn.rs
+++ b/crates/burn-import/src/onnx/to_burn.rs
@@ -424,12 +424,7 @@ impl ParsedOnnxGraph {
     fn random_uniform_conversion(node: Node) -> RandomUniformNode {
         let output = node.outputs.first().unwrap();
-        // cannot use output.to_tensor_type() here, since it drops the shape info...
-        let output_type = if let Type::Tensor(t) = Type::from(output) {
-            t
-        } else {
-            panic!("RandomUniform output type is no Tensor.");
-        };
+        let output_type = TensorType::from(output);
 
         let high = node
             .attrs
@@ -451,12 +446,7 @@ impl ParsedOnnxGraph {
     fn random_normal_conversion(node: Node) -> RandomNormalNode {
         let output = node.outputs.first().unwrap();
-        // cannot use output.to_tensor_type() here, since it drops the shape info...
-        let output_type = if let Type::Tensor(t) = Type::from(output) {
-            t
-        } else {
-            panic!("RandomNormal output type is no Tensor.");
-        };
+        let output_type = TensorType::from(output);
 
         let mean = node
             .attrs
@@ -480,11 +470,12 @@ impl ParsedOnnxGraph {
         // Additional types needed for ConstantOfShape:
         use crate::burn::node::constant_of_shape::ConstantValue;
 
-        let input = node
-            .inputs
-            .first()
-            .expect("ConstantOfShape requires an input tensor");
-        let output = node.outputs.first().unwrap();
+        let input = Type::from(
+            node.inputs
+                .first()
+                .expect("ConstantOfShape requires an input tensor"),
+        );
+        let output = Type::from(node.outputs.first().unwrap());
 
         // The value of the output elements.Should be a one-element tensor.
         // If not specified, it defaults to a tensor of value 0 and datatype float32
@@ -504,7 +495,7 @@ impl ParsedOnnxGraph {
             })
             .unwrap_or(ConstantValue::Float32(0.0f32));
 
-        ConstantOfShapeNode::new(Type::from(input), Type::from(output), value)
+        ConstantOfShapeNode::new(input, output, value)
     }
 
     fn add_conversion(node: Node) -> BinaryNode {
@@ -1082,18 +1073,27 @@ impl ParsedOnnxGraph {
         UnaryNode::exp(input, output)
     }
 
-    fn expand_conversion(node: Node) -> ExpandNode {
+    fn expand_conversion(mut node: Node) -> ExpandNode {
         let input = TensorType::from(node.inputs.first().unwrap());
-        let mut output = TensorType::from(node.outputs.first().unwrap());
         let shape = expand_config(&node);
 
-        output.dim = match &shape {
-            ExpandShape::Static(s) => s.len(),
-            ExpandShape::Runtime(Type::Shape(s)) => s.dim,
-            ExpandShape::Runtime(Type::Tensor(t)) => t.shape.as_ref().unwrap()[0],
-            _ => panic!("Invalid ExpandShape {shape:?}!"),
-        };
-        ExpandNode::new(input, output, shape)
+        // dim_inference left the dim at zero, so it needs to be filled before converting to TensorType:
+        assert_eq!(
+            node.outputs.len(),
+            1,
+            "ExpandNode must have exactly 1 output!"
+        );
+        let mut output_arg = node.outputs.pop().unwrap();
+        if let ArgType::Tensor(output_arg_tensor) = &mut output_arg.ty {
+            output_arg_tensor.dim = match &shape {
+                ExpandShape::Static(s) => s.len(),
+                ExpandShape::Runtime(Type::Shape(s)) => s.dim,
+                ExpandShape::Runtime(Type::Tensor(t)) => t.shape.as_ref().unwrap()[0],
+                _ => panic!("Invalid ExpandShape {shape:?}!"),
+            };
+        }
+
+        ExpandNode::new(input, TensorType::from(&output_arg), shape)
     }
 
     fn neg_conversion(node: Node) -> UnaryNode {
@@ -1236,18 +1236,21 @@ impl From<&OnnxArgument> for TensorType {
             ArgType::Tensor(OnnxTensorType {
                 elem_type: ElementType::Float16 | ElementType::Float32 | ElementType::Float64,
                 dim,
+                shape,
                 ..
-            }) => TensorType::new_float(arg.name.clone(), *dim),
+            }) => TensorType::new_float_with_shape(arg.name.clone(), *dim, shape.clone()),
             ArgType::Tensor(OnnxTensorType {
                 elem_type: ElementType::Int32 | ElementType::Int64,
                 dim,
+                shape,
                 ..
-            }) => TensorType::new_int(arg.name.clone(), *dim),
+            }) => TensorType::new_int_with_shape(arg.name.clone(), *dim, shape.clone()),
             ArgType::Tensor(OnnxTensorType {
                 elem_type: ElementType::Bool,
                 dim,
+                shape,
                 ..
-            }) => TensorType::new_bool(arg.name.clone(), *dim),
+            }) => TensorType::new_bool_with_shape(arg.name.clone(), *dim, shape.clone()),
             _ => panic!("Can't transform {:?} to tensor.", arg.ty),
         }
     }
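Side note (not part of the patch): in expand_conversion above, the output rank is derived from the shape operand rather than from dim_inference. A rough standalone restatement of that rule, assuming the ExpandShape and Type types used in the hunk; the helper name is hypothetical.

// Sketch of the rank rule used by expand_conversion; not an actual helper in the crate.
fn expand_output_rank(shape: &ExpandShape) -> usize {
    match shape {
        // A static shape literal such as [7, 5, 3] means a rank-3 output.
        ExpandShape::Static(s) => s.len(),
        // A runtime Shape operand of dimension N yields a rank-N output.
        ExpandShape::Runtime(Type::Shape(s)) => s.dim,
        // A runtime 1-D tensor operand: its single static dimension is the output rank.
        ExpandShape::Runtime(Type::Tensor(t)) => t.shape.as_ref().expect("static shape required")[0],
        _ => panic!("Invalid ExpandShape"),
    }
}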
diff --git a/crates/onnx-ir/src/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs
index 97b251e0a..ada2516c0 100644
--- a/crates/onnx-ir/src/dim_inference.rs
+++ b/crates/onnx-ir/src/dim_inference.rs
@@ -1,6 +1,7 @@
 use core::cmp::max;
 use core::panic;
 
+use log::debug;
 use protobuf::Enum;
 
 use crate::{
@@ -12,7 +13,7 @@ use crate::{
 
 /// Infer the dimension of each output tensor and update them.
 pub fn dim_inference(node: &mut Node) {
     match node.node_type {
-        NodeType::Add => same_as_input(node),
+        NodeType::Add => same_as_input_broadcast(node),
         NodeType::ArgMax => argmax_update_outputs(node),
         NodeType::AveragePool1d => same_as_input(node),
         NodeType::AveragePool2d => same_as_input(node),
@@ -21,12 +22,13 @@ pub fn dim_inference(node: &mut Node) {
         NodeType::Clip => same_as_input(node),
         NodeType::Concat => concat_update_outputs(node),
         NodeType::Constant => constant_update_outputs(node),
+        NodeType::ConstantOfShape => constant_of_shape_update_output(node),
         NodeType::Conv1d => conv1d_update_outputs(node),
         NodeType::Conv2d => conv2d_update_outputs(node),
         NodeType::Cos => same_as_input(node),
-        NodeType::Div => same_as_input(node),
+        NodeType::Div => same_as_input_broadcast(node),
         NodeType::Dropout => same_as_input(node),
-        NodeType::Equal => equal_update_outputs(node),
+        NodeType::Equal => elementwise_comparsion_outputs(node),
         NodeType::Erf => same_as_input(node),
         NodeType::Exp => same_as_input(node),
         NodeType::Expand => expand_update_outputs(node),
@@ -34,26 +36,31 @@
         NodeType::Gelu => same_as_input(node),
         NodeType::Gather => gather_update_outputs(node),
         NodeType::GatherElements => same_as_input(node),
+        NodeType::Greater => elementwise_comparsion_outputs(node),
+        NodeType::GreaterOrEqual => elementwise_comparsion_outputs(node),
         NodeType::HardSigmoid => same_as_input(node),
         NodeType::GlobalAveragePool => same_as_input(node),
         NodeType::ConvTranspose2d => conv_transpose2d_update_outputs(node),
         NodeType::LayerNormalization => same_as_input(node),
+        NodeType::LeakyRelu => same_as_input(node),
+        NodeType::Less => elementwise_comparsion_outputs(node),
+        NodeType::LessOrEqual => elementwise_comparsion_outputs(node),
         NodeType::Linear => linear_update_outputs(node),
         NodeType::Log => same_as_input(node),
         NodeType::LogSoftmax => same_as_input(node),
         NodeType::MatMul => matmul_update_outputs(node),
-        NodeType::Min => same_as_input(node),
-        NodeType::Max => same_as_input(node),
+        NodeType::Max => same_as_input_broadcast(node),
         NodeType::MaxPool1d => same_as_input(node),
         NodeType::MaxPool2d => same_as_input(node),
+        NodeType::Min => same_as_input_broadcast(node),
         NodeType::Mul => same_as_input(node),
         NodeType::Neg => same_as_input(node),
         NodeType::Not => same_as_input(node),
         NodeType::Pad => same_as_input(node),
-        NodeType::Greater => greater_update_outputs(node),
-        NodeType::GreaterOrEqual => greater_or_equal_update_outputs(node),
-        NodeType::Less => less_update_outputs(node),
-        NodeType::LessOrEqual => less_or_equal_update_outputs(node),
+        NodeType::PRelu => same_as_input_broadcast(node),
+        NodeType::Pow => same_as_input_broadcast(node),
+        NodeType::RandomNormal => random_update_output(node),
+        NodeType::RandomUniform => random_update_output(node),
         NodeType::Range => range_update_outputs(node),
         NodeType::Reciprocal => same_as_input(node),
         NodeType::ReduceMax => reduce_max_update_outputs(node),
@@ -70,20 +77,14 @@
         NodeType::Sin => same_as_input(node),
         NodeType::Slice => same_as_input(node),
         NodeType::Softmax => same_as_input(node),
+        NodeType::Squeeze => squeeze_update_output(node),
         NodeType::Sqrt => same_as_input(node),
-        NodeType::Sub => sub_update_outputs(node),
-        NodeType::Sum => same_as_input(node),
+        NodeType::Sub => same_as_input_broadcast(node),
+        NodeType::Sum => same_as_input_broadcast(node),
         NodeType::Tanh => same_as_input(node),
         NodeType::Transpose => same_as_input(node),
         NodeType::Unsqueeze => unsqueeze_update_output(node),
-        NodeType::Pow => same_as_input(node),
-        NodeType::LeakyRelu => same_as_input(node),
-        NodeType::PRelu => same_as_input(node),
         NodeType::Where => where_update_outputs(node),
-        NodeType::Squeeze => squeeze_update_output(node),
-        NodeType::RandomUniform => random_update_output(node),
-        NodeType::RandomNormal => random_update_output(node),
-        NodeType::ConstantOfShape => constant_of_shape_update_output(node),
         // Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated.
         _ => temporary_pass_through_stub(node),
     }
 }
@@ -108,6 +109,9 @@ fn constant_update_outputs(node: &mut Node) {
     node.outputs[0].ty = match matched_value {
         Some(value) => match &value {
             // The value is stored in an attribute
+            AttributeValue::Tensor(tensor) if tensor.dim == 0 => {
+                ArgType::Scalar(tensor.elem_type.clone())
+            }
             AttributeValue::Tensor(tensor) => ArgType::Tensor(TensorType {
                 elem_type: tensor.elem_type.clone(),
                 dim: tensor.dim,
@@ -319,54 +323,6 @@ fn reshape_update_outputs(node: &mut Node) {
     }
 }
 
-fn greater_update_outputs(node: &mut Node) {
-    match &node.inputs[0].ty {
-        ArgType::Tensor(tensor) => {
-            node.outputs[0].ty = ArgType::Tensor(TensorType {
-                elem_type: ElementType::Bool,
-                ..tensor.clone()
-            });
-        }
-        _ => panic!("Only tensor input is valid"),
-    }
-}
-
-fn less_update_outputs(node: &mut Node) {
-    match &node.inputs[0].ty {
-        ArgType::Tensor(tensor) => {
-            node.outputs[0].ty = ArgType::Tensor(TensorType {
-                elem_type: ElementType::Bool,
-                ..tensor.clone()
-            });
-        }
-        _ => panic!("Only tensor input is valid"),
-    }
-}
-
-fn greater_or_equal_update_outputs(node: &mut Node) {
-    match &node.inputs[0].ty {
-        ArgType::Tensor(tensor) => {
-            node.outputs[0].ty = ArgType::Tensor(TensorType {
-                elem_type: ElementType::Bool,
-                ..tensor.clone()
-            });
-        }
-        _ => panic!("Only tensor input is valid"),
-    }
-}
-
-fn less_or_equal_update_outputs(node: &mut Node) {
-    match &node.inputs[0].ty {
-        ArgType::Tensor(tensor) => {
-            node.outputs[0].ty = ArgType::Tensor(TensorType {
-                elem_type: ElementType::Bool,
-                ..tensor.clone()
-            });
-        }
-        _ => panic!("Only tensor input is valid"),
-    }
-}
-
 fn reduce_mean_update_outputs(node: &mut Node) {
     if node.inputs.len() != 1 {
         panic!("Mean: multiple inputs are not supported");
     }
@@ -455,18 +411,27 @@ fn squeeze_update_output(node: &mut Node) {
     });
 }
 
-fn sub_update_outputs(node: &mut Node) {
-    node.outputs[0].ty = match (node.inputs[0].ty.clone(), node.inputs[1].ty.clone()) {
-        (ArgType::Scalar(_lhs), ArgType::Scalar(rhs)) => ArgType::Scalar(rhs),
-        (ArgType::Scalar(_lhs), ArgType::Tensor(rhs)) => ArgType::Tensor(rhs),
-        (ArgType::Tensor(lhs), ArgType::Scalar(_rhs)) => ArgType::Tensor(lhs),
-        // Support broadcasting for lhs/rhs
-        (ArgType::Tensor(lhs), ArgType::Tensor(rhs)) if lhs.dim > rhs.dim => ArgType::Tensor(lhs),
-        (ArgType::Tensor(lhs), ArgType::Tensor(rhs)) if lhs.dim <= rhs.dim => ArgType::Tensor(rhs),
-        _ => {
-            panic!("Only tensor-scalar inputs are valid.");
-        }
-    };
+/// Updates the output type for operations that take more than one Tensor/Scalar of the same type
+/// and return the same type while supporting broadcasting
+///
+/// This is mostly the elementwise math operations, i.e. add, sub, mul, max etc.
+fn same_as_input_broadcast(node: &mut Node) {
+    if node.inputs.iter().all(|input| input.ty.is_scalar()) {
+        // If all inputs are scalar, the output is a scalar as well
+        node.outputs[0].ty = node.inputs[0].ty.clone();
+    } else {
+        // else, if any input is a Tensor, use its datatype,
+        // or the input consists solely of Scalars and Shapes,
+        // which should result in another Shape
+        node.outputs[0].ty = node
+            .inputs
+            .iter()
+            .find(|input| input.ty.is_tensor())
+            .map(|input| input.ty.clone())
+            .unwrap_or_else(|| ArgType::Shape(0)); // Shape dim will be set by broadcast calculation
+
+        set_broadcasting_output_shape(node);
+    }
 }
 
 /// Update the output tensor dimension based on the "axes" attribute or the second input
@@ -519,21 +484,34 @@ fn temporary_pass_through_stub(node: &mut Node) {
     node.outputs[0].ty = node.inputs[0].ty.clone();
 }
 
-fn equal_update_outputs(node: &mut Node) {
-    let input1_type = node.inputs[0].ty.clone();
+/// Sets the output for binary operators resulting in a boolean output,
+/// i.e., comparison operators like Equal, Greater, Less, etc.
+///
+/// Support for broadcasting is assumed
+fn elementwise_comparsion_outputs(node: &mut Node) {
+    let input1_type = &node.inputs[0].ty;
+    let input2_type = &node.inputs[1].ty;
 
-    match input1_type {
-        ArgType::Tensor(tensor) => {
-            // if the input is a tensor, the output is a tensor of bool
+    match (input1_type, input2_type) {
+        (ArgType::Tensor(tensor), _) | (_, ArgType::Tensor(tensor)) => {
+            // if one of the inputs is a tensor, the output is a tensor of bool
+            assert_ne!(
+                tensor.dim, 0,
+                "Got a rank 0 Tensor, that should have been a Scalar!"
+            );
             node.outputs[0].ty = ArgType::Tensor(TensorType {
                 elem_type: ElementType::Bool,
-                ..tensor
+                ..tensor.clone()
            });
+            set_broadcasting_output_shape(node);
         }
-        ArgType::Scalar(_) => {
+        (ArgType::Scalar(_), ArgType::Scalar(_)) => {
+            // if both inputs are scalars, the result is a scalar bool
             node.outputs[0].ty = ArgType::Scalar(ElementType::Bool);
         }
-        _ => panic!("Only tensor input is valid"),
+        _ => {
+            panic!("Invalid input types for comparison op: {input1_type:?}, {input2_type:?}")
+        }
     }
 }
@@ -788,11 +766,7 @@ fn reduce_sum_update_outputs(node: &mut Node) {
 }
 
 fn where_update_outputs(node: &mut Node) {
-    match (
-        node.inputs[0].ty.clone(),
-        node.inputs[1].ty.clone(),
-        node.inputs[2].ty.clone(),
-    ) {
+    match (&node.inputs[0].ty, &node.inputs[1].ty, &node.inputs[2].ty) {
         (ArgType::Tensor(condition), ArgType::Tensor(x), ArgType::Tensor(y)) => {
             // With broadcasting support, output dim has to be computed based on the inputs
             node.outputs[0].ty = ArgType::Tensor(TensorType {
                 elem_type: x.elem_type.clone(),
                 dim: max(condition.dim, max(x.dim, y.dim)),
                 ..Default::default()
             });
+            set_broadcasting_output_shape(node);
         }
         _ => panic!("Only tensor input is valid"),
     }
 }
@@ -841,3 +816,68 @@ fn gather_update_outputs(node: &mut Node) {
         ty => panic!("Only tensor/shape input is valid but received: {:?}", ty),
     }
 }
+
+/// If all input shapes are known,
+/// calculates the rank and shape of the output tensor for Operators supporting
+/// [broadcasting](https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md)
+fn set_broadcasting_output_shape(node: &mut Node) {
+    let mut reverse_out_shape: Vec<usize> = vec![1];
+    for (idx, input_type) in node.inputs.iter().enumerate() {
+        match &input_type.ty {
+            ArgType::Tensor(t) => {
+                if let Some(shape) = &t.shape {
+                    for (rev_idx, dimension) in shape.iter().rev().enumerate() {
+                        if let Some(current_out_dim) = reverse_out_shape.get_mut(rev_idx) {
+                            if *dimension == 1 {
+                                // this input's dimension is 1, so it can be broadcast to whatever
+                                // the output dimension already holds
+                                continue;
+                            }
+                            if current_out_dim != dimension && *current_out_dim != 1 {
+                                panic!("Invalid shape for broadcasting - the dimension from the {rev_idx}. to last position has conflicting values {current_out_dim} and {dimension} from different inputs");
+                            }
+                            *current_out_dim = *dimension;
+                        } else {
+                            reverse_out_shape.push(*dimension);
+                        }
+                    }
+                } else {
+                    debug!("Input {idx} has no known shape, cannot predict broadcast result shape");
+                    return;
+                }
+            }
+            ArgType::Scalar(_) => {
+                // reverse_out_shape already starts with [1]
+            }
+            ArgType::Shape(s) => {
+                // Shape is treated like a 1-D Tensor
+                let current_out_dim = &mut reverse_out_shape[0];
+                if *current_out_dim != 1 && *current_out_dim != *s {
+                    panic!("Invalid shape for broadcasting - the last position has conflicting values {current_out_dim} and {} from different inputs", s);
+                }
+                *current_out_dim = *s;
+            }
+        }
+    }
+
+    // If we get to this point without returning, reverse_out_shape will be final
+    let mut out_shape = reverse_out_shape;
+    out_shape.reverse(); // un-reverse it
+
+    match &mut node.outputs[0].ty {
+        ArgType::Tensor(t) => {
+            t.dim = out_shape.len();
+            t.shape = Some(out_shape);
+        }
+        ArgType::Scalar(_) => {
+            if out_shape.len() > 1 || out_shape[0] > 1 {
+                panic!("Output is a Scalar, but broadcasting results in tensor shape {out_shape:?}")
+            }
+        }
+        ArgType::Shape(s) => {
+            if out_shape.len() > 1 {
+                panic!("Output is a Shape, but broadcasting results in higher-rank tensor shape {out_shape:?}")
+            }
+            *s = out_shape[0];
+        }
+    }
+}
diff --git a/crates/onnx-ir/src/ir.rs b/crates/onnx-ir/src/ir.rs
index 52f9ee21e..b0abed195 100644
--- a/crates/onnx-ir/src/ir.rs
+++ b/crates/onnx-ir/src/ir.rs
@@ -128,6 +128,15 @@ impl Default for ArgType {
     }
 }
 
+impl ArgType {
+    pub fn is_scalar(&self) -> bool {
+        matches!(self, Self::Scalar(_))
+    }
+    pub fn is_tensor(&self) -> bool {
+        matches!(self, Self::Tensor(_))
+    }
+}
+
 impl Argument {
     pub fn new(name: String) -> Self {
         Self {
diff --git a/crates/onnx-ir/src/proto_conversion.rs b/crates/onnx-ir/src/proto_conversion.rs
index 43adb76e4..45901f6f4 100644
--- a/crates/onnx-ir/src/proto_conversion.rs
+++ b/crates/onnx-ir/src/proto_conversion.rs
@@ -242,20 +242,26 @@ impl TryFrom<ValueInfoProto> for Argument {
             }
         };
 
-        let tensor_type = TensorType {
-            dim: tensor_proto.shape.dim.len(),
-            elem_type,
-            shape: Some(
-                tensor_proto
-                    .shape
-                    .dim
-                    .iter()
-                    .map(|x| x.dim_value() as Dim)
-                    .collect(),
-            ),
-        };
+        let ty = if tensor_proto.shape.dim.is_empty() {
+            // tensor_proto describes a scalar
+            ArgType::Scalar(elem_type)
+        } else {
+            // tensor_proto describes a tensor
+            let tensor_type = TensorType {
+                dim: tensor_proto.shape.dim.len(),
+                elem_type,
+                shape: Some(
+                    tensor_proto
+                        .shape
+                        .dim
+                        .iter()
+                        .map(|x| x.dim_value() as Dim)
+                        .collect(),
+                ),
+            };
 
-        let ty = ArgType::Tensor(tensor_type);
+            ArgType::Tensor(tensor_type)
+        };
 
         Ok(Argument {
             ty,
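Closing note (not part of the patch): the rule implemented by set_broadcasting_output_shape can be sketched in isolation. The helper below is illustrative only - it matches dimensions from the right, lets a size of 1 broadcast against anything, and rejects conflicting non-1 sizes, mirroring the panics above.

// Standalone sketch of ONNX-style multidirectional broadcasting for two static shapes.
fn broadcast_shape(a: &[usize], b: &[usize]) -> Result<Vec<usize>, String> {
    let mut out = Vec::new();
    let mut ai = a.iter().rev();
    let mut bi = b.iter().rev();
    loop {
        match (ai.next(), bi.next()) {
            // Both shapes exhausted: done.
            (None, None) => break,
            // One shape is shorter; the other dimension passes through unchanged.
            (Some(&x), None) | (None, Some(&x)) => out.push(x),
            // Equal sizes, or the right-hand size is 1: keep the left-hand size.
            (Some(&x), Some(&y)) if x == y || y == 1 => out.push(x),
            // The left-hand size is 1: keep the right-hand size.
            (Some(&x), Some(&y)) if x == 1 => out.push(y),
            // Conflicting non-1 sizes cannot be broadcast.
            (Some(&x), Some(&y)) => return Err(format!("cannot broadcast {x} with {y}")),
        }
    }
    out.reverse(); // dimensions were collected from the right
    Ok(out)
}

fn main() {
    // [2, 3, 4] broadcast with [3, 1] -> [2, 3, 4]
    assert_eq!(broadcast_shape(&[2, 3, 4], &[3, 1]), Ok(vec![2, 3, 4]));
    // Conflicting sizes fail, matching the panic in set_broadcasting_output_shape.
    assert!(broadcast_shape(&[2, 3], &[4, 3]).is_err());
}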