From 6b51b73a5f8411332f90d1c60e4b8f88de0fe3db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20M=C3=BCller?= Date: Tue, 3 Sep 2024 17:17:18 +0200 Subject: [PATCH] Fix ONNX where op for scalar inputs (#2218) * Fix ONNX where op dim_inference for scalar inputs * Rewrite ONNX Where codegen to support scalars * ONNX Where: Add tests for all_scalar inputs --------- Co-authored-by: Guillaume Lagrange --- crates/burn-import/onnx-tests/build.rs | 4 + .../tests/mask_where/mask_where.onnx | 22 +- .../onnx-tests/tests/mask_where/mask_where.py | 42 ++- .../mask_where/mask_where_all_scalar.onnx | Bin 0 -> 190 bytes .../mask_where/mask_where_broadcast.onnx | 22 ++ .../tests/mask_where/mask_where_scalar_x.onnx | Bin 0 -> 214 bytes .../tests/mask_where/mask_where_scalar_y.onnx | Bin 0 -> 214 bytes .../burn-import/onnx-tests/tests/test_onnx.rs | 74 +++- crates/burn-import/src/burn/node/gather.rs | 10 +- .../burn-import/src/burn/node/mask_where.rs | 317 +++++++++++++++--- crates/burn-import/src/burn/ty.rs | 54 +++ crates/burn-import/src/onnx/to_burn.rs | 8 +- crates/onnx-ir/src/dim_inference.rs | 36 +- crates/onnx-ir/src/ir.rs | 20 +- 14 files changed, 497 insertions(+), 112 deletions(-) create mode 100644 crates/burn-import/onnx-tests/tests/mask_where/mask_where_all_scalar.onnx create mode 100644 crates/burn-import/onnx-tests/tests/mask_where/mask_where_broadcast.onnx create mode 100644 crates/burn-import/onnx-tests/tests/mask_where/mask_where_scalar_x.onnx create mode 100644 crates/burn-import/onnx-tests/tests/mask_where/mask_where_scalar_y.onnx diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index c130969d8..a7360e012 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -56,6 +56,10 @@ fn main() { .input("tests/log/log.onnx") .input("tests/log_softmax/log_softmax.onnx") .input("tests/mask_where/mask_where.onnx") + .input("tests/mask_where/mask_where_broadcast.onnx") + .input("tests/mask_where/mask_where_scalar_x.onnx") + .input("tests/mask_where/mask_where_scalar_y.onnx") + .input("tests/mask_where/mask_where_all_scalar.onnx") .input("tests/matmul/matmul.onnx") .input("tests/max/max.onnx") .input("tests/maxpool1d/maxpool1d.onnx") diff --git a/crates/burn-import/onnx-tests/tests/mask_where/mask_where.onnx b/crates/burn-import/onnx-tests/tests/mask_where/mask_where.onnx index 0f0fb289a..9f22ab681 100644 --- a/crates/burn-import/onnx-tests/tests/mask_where/mask_where.onnx +++ b/crates/burn-import/onnx-tests/tests/mask_where/mask_where.onnx @@ -1,12 +1,8 @@ -pytorch2.1.2:× +pytorch2.3.0:Å ? onnx::Where_0 onnx::Where_1 - onnx::Where_25/Where"Where -A - onnx::Where_0 - onnx::Where_3 - onnx::Where_46/Where_1"Where + onnx::Where_23/Where"Where main_graphZ onnx::Where_0   @@ -19,20 +15,8 @@ main_graphZ onnx::Where_2   -Z - onnx::Where_3 - - -Z - onnx::Where_4 - - b -5 -  - -b -6 +3   B \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/mask_where/mask_where.py b/crates/burn-import/onnx-tests/tests/mask_where/mask_where.py index 5e7f1d9b6..048dc1056 100644 --- a/crates/burn-import/onnx-tests/tests/mask_where/mask_where.py +++ b/crates/burn-import/onnx-tests/tests/mask_where/mask_where.py @@ -1,6 +1,10 @@ #!/usr/bin/env python3 -# used to generate model: onnx-tests/tests/mask_where/mask_where.onnx +# used to generate models: +# mask_where.onnx +# mask_where_broadcast.onnx +# mask_where_scalar_x.onnx +# mask_where_scalar_y.onnx import torch import torch.nn as nn @@ -10,23 +14,17 @@ class Model(nn.Module): def __init__(self): super(Model, self).__init__() - def forward(self, condition, x1, y1, x2, y2): - return torch.where(condition, x1, y1), torch.where(condition, x2, y2) + def forward(self, condition, x, y): + return torch.where(condition, x, y) -def main(): - # Set random seed for reproducibility - torch.manual_seed(0) - +def create_model(name: str, device: torch.device, mask: torch.Tensor, x: torch.Tensor, y: torch.Tensor): + print(f"--- {name} ---") # Export to onnx model = Model() model.eval() - device = torch.device("cpu") - onnx_name = "mask_where.onnx" - x = torch.ones(2, 2, device=device) - y = torch.zeros(2, 2, device=device) - mask = torch.tensor([[True, False], [False, True]], device=device) - test_input = (mask, x, y, x[0], y[0]) + onnx_name = f"{name}.onnx" + test_input = (mask, x, y) torch.onnx.export(model, (test_input), onnx_name, verbose=False, opset_version=16) @@ -37,6 +35,24 @@ def main(): output = model.forward(*test_input) print(f"Test output data: {output}") +def main(): + # Set random seed for reproducibility + torch.manual_seed(0) + device = torch.device("cpu") + + mask = torch.tensor([[True, False], [False, True]], device=device) + x = torch.ones(2, 2, device=device) + y = torch.zeros(2, 2, device=device) + mask_scalar = torch.tensor(True, device=device) + x_scalar = torch.tensor(1., device=device) + y_scalar = torch.tensor(0., device=device) + create_model("mask_where", device, mask, x, y) + create_model("mask_where_broadcast", device, mask, x[0], y[0]) + create_model("mask_where_scalar_x", device, mask, x_scalar, y) + create_model("mask_where_scalar_y", device, mask, x, y_scalar) + create_model("mask_where_all_scalar", device, mask_scalar, x_scalar, y_scalar) + + if __name__ == "__main__": main() \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/mask_where/mask_where_all_scalar.onnx b/crates/burn-import/onnx-tests/tests/mask_where/mask_where_all_scalar.onnx new file mode 100644 index 0000000000000000000000000000000000000000..bcaebd84bf3c5637821df58c8e472242d8d6ed5c GIT binary patch literal 190 zcmd = mask_where::Model::new(&device); - let x1 = Tensor::ones([2, 2], &device); - let y1 = Tensor::zeros([2, 2], &device); - let x2 = Tensor::ones([2], &device); - let y2 = Tensor::zeros([2], &device); + let x = Tensor::ones([2, 2], &device); + let y = Tensor::zeros([2, 2], &device); let mask = Tensor::from_bool([[true, false], [false, true]].into(), &device); - let (output, output_broadcasted) = model.forward(mask, x1, y1, x2, y2); + let output = model.forward(mask, x, y); let expected = TensorData::from([[1f32, 0.0], [0.0, 1.0]]); output.to_data().assert_eq(&expected, true); - output_broadcasted.to_data().assert_eq(&expected, true); + } + + #[test] + fn mask_where_broadcast() { + let device = Default::default(); + let model: mask_where_broadcast::Model = mask_where_broadcast::Model::new(&device); + + let x = Tensor::ones([2], &device); + let y = Tensor::zeros([2], &device); + let mask = Tensor::from_bool([[true, false], [false, true]].into(), &device); + + let output = model.forward(mask, x, y); + let expected = TensorData::from([[1f32, 0.0], [0.0, 1.0]]); + + output.to_data().assert_eq(&expected, true); + } + + #[test] + fn mask_where_scalar_x() { + let device = Default::default(); + let model: mask_where_scalar_x::Model = mask_where_scalar_x::Model::new(&device); + + let x = 1.0f32; + let y = Tensor::zeros([2, 2], &device); + let mask = Tensor::from_bool([[true, false], [false, true]].into(), &device); + + let output = model.forward(mask, x, y); + let expected = TensorData::from([[1f32, 0.0], [0.0, 1.0]]); + + output.to_data().assert_eq(&expected, true); + } + + #[test] + fn mask_where_scalar_y() { + let device = Default::default(); + let model: mask_where_scalar_y::Model = mask_where_scalar_y::Model::new(&device); + + let x = Tensor::ones([2, 2], &device); + let y = 0.0f32; + let mask = Tensor::from_bool([[true, false], [false, true]].into(), &device); + + let output = model.forward(mask, x, y); + let expected = TensorData::from([[1f32, 0.0], [0.0, 1.0]]); + + output.to_data().assert_eq(&expected, true); + } + + #[test] + fn mask_where_all_scalar() { + let device = Default::default(); + let model: mask_where_all_scalar::Model = + mask_where_all_scalar::Model::new(&device); + + let x = 1.0f32; + let y = 0.0f32; + let mask = true; + + let output = model.forward(mask, x, y); + let expected = 1.0f32; + + assert_eq!(output, expected); } #[test] diff --git a/crates/burn-import/src/burn/node/gather.rs b/crates/burn-import/src/burn/node/gather.rs index a9b66e1d0..4004be971 100644 --- a/crates/burn-import/src/burn/node/gather.rs +++ b/crates/burn-import/src/burn/node/gather.rs @@ -35,13 +35,7 @@ impl NodeCodegen for GatherNode { let input = match &self.input { Type::Tensor(in_tensor) => scope.tensor_use_owned(in_tensor, node_position), - Type::Shape(in_shape) => { - let in_shape_name = &in_shape.name; - // To copy just the values from the shape value without moving it - // (which could lead to ownership problems if the same Shape is used multiple times) - // borrow the array as a slice and use that to create the Tensor: - quote! { Tensor::::from_data(&#in_shape_name as &[_], &*self.device) } - } + Type::Shape(in_shape) => in_shape.to_tensor(), _ => panic!("Gather needs Scalar or Shape input, got {:?}!", self.input), }; @@ -281,7 +275,7 @@ mod tests { let indices = tensor1; let tensor2 = Tensor::select( - Tensor::::from_data(&shape1 as &[_], &*self.device), + Tensor::::from_data(&shape1 as &[_], &*self.device), 0, indices, ); diff --git a/crates/burn-import/src/burn/node/mask_where.rs b/crates/burn-import/src/burn/node/mask_where.rs index 19b99dc0e..cccc25d52 100644 --- a/crates/burn-import/src/burn/node/mask_where.rs +++ b/crates/burn-import/src/burn/node/mask_where.rs @@ -1,68 +1,64 @@ -use core::cmp::max; - use super::{Node, NodeCodegen}; -use crate::burn::{BurnImports, TensorType, ToTokens, Type}; +use crate::burn::{BurnImports, ScalarType, ToTokens, Type}; use burn::record::PrecisionSettings; +use proc_macro2::TokenStream; use quote::quote; #[derive(Debug, Clone, new)] pub struct WhereNode { /// Bool tensor. When True (nonzero), yield X, otherwise yield Y. - pub condition: TensorType, + pub condition: Type, /// Values selected at indices where condition is True. - pub x: TensorType, + pub x: Type, /// Values selected at indices where condition is False. - pub y: TensorType, - pub output: TensorType, + pub y: Type, + pub output: Type, } impl NodeCodegen for WhereNode { fn output_types(&self) -> Vec { - vec![Type::Tensor(self.output.clone())] + vec![self.output.clone()] } fn input_types(&self) -> Vec { - vec![ - Type::Tensor(self.condition.clone()), - Type::Tensor(self.x.clone()), - Type::Tensor(self.y.clone()), - ] + vec![self.condition.clone(), self.x.clone(), self.y.clone()] } - fn forward( - &self, - scope: &mut crate::burn::Scope, - node_position: usize, - ) -> proc_macro2::TokenStream { - let mut mask = scope.tensor_use_owned(&self.condition, node_position); - let mut x = scope.tensor_use_owned(&self.x, node_position); - let mut y = scope.tensor_use_owned(&self.y, node_position); - let output = &self.output.name; + fn forward(&self, scope: &mut crate::burn::Scope, node_position: usize) -> TokenStream { + match &self.output { + Type::Tensor(out) => { + let cond = Self::input_as_tensor(&self.condition, out.dim, scope, node_position); + let y = Self::input_as_tensor(&self.y, out.dim, scope, node_position); + let out_id = &out.name; - // x, y and condition need to be broadcastable - let broadcasted_dim = max(max(self.x.dim, self.y.dim), self.condition.dim); - let unsqueeze_dims = broadcasted_dim.to_tokens(); - - if self.condition.dim < broadcasted_dim { - mask = quote! { #mask.unsqueeze::<#unsqueeze_dims>()}; - } - - if self.x.dim < broadcasted_dim { - x = quote! { #x.unsqueeze::<#unsqueeze_dims>()}; - } - - if self.y.dim < broadcasted_dim { - y = quote! { #y.unsqueeze::<#unsqueeze_dims>()}; - } - - quote! { - let #output = #y.mask_where(#mask, #x); + if let Type::Scalar(x) = &self.x { + let x = &x.name; + quote! { + let #out_id = #y.mask_fill(#cond, #x); + } + } else { + let x = Self::input_as_tensor(&self.x, out.dim, scope, node_position); + quote! { + let #out_id = #y.mask_where(#cond, #x); + } + } + } + Type::Scalar(out) => { + // Scalar out means all inputs are scalars as well: + let cond = self.condition.as_scalar(); + let x = self.x.as_scalar(); + let y = self.y.as_scalar(); + Self::forward_scalar(out, cond, x, y) + } + other => panic!("Where cannot handle {other:?}"), } } fn register_imports(&self, imports: &mut BurnImports) { - imports.register("burn::tensor::Bool"); + if matches!(&self.output, Type::Tensor(_)) { + imports.register("burn::tensor::Bool"); + } } fn into_node(self) -> super::Node { @@ -70,6 +66,50 @@ impl NodeCodegen for WhereNode { } } +impl WhereNode { + fn forward_scalar( + out: &ScalarType, + cond: &ScalarType, + x: &ScalarType, + y: &ScalarType, + ) -> TokenStream { + let out_name = &out.name; + let out_type = out.ty(); + let cond_name = &cond.name; + let x_name = &x.name; + let y_name = &y.name; + + quote! { + let #out_name : #out_type = if #cond_name { + #x_name + } + else { + #y_name + }; + } + } + + fn input_as_tensor( + input: &Type, + broadcast_rank: usize, + scope: &mut crate::burn::Scope, + node_position: usize, + ) -> TokenStream { + let (tensor, rank) = match input { + Type::Tensor(t) => (scope.tensor_use_owned(t, node_position), t.dim), + Type::Scalar(s) => (s.to_full_tensor(&vec![1; broadcast_rank]), broadcast_rank), + Type::Shape(s) => (s.to_tensor(), 1), + _ => panic!("Where op: {input:?} input not implemented"), + }; + if rank < broadcast_rank { + let broadcast_rank_tokens = broadcast_rank.to_tokens(); + quote! { #tensor.unsqueeze::<#broadcast_rank_tokens>()} + } else { + tensor + } + } +} + #[cfg(test)] mod tests { @@ -79,7 +119,7 @@ mod tests { use crate::burn::{ graph::BurnGraph, node::{mask_where::WhereNode, test::assert_tokens}, - TensorType, + ScalarKind, TensorType, }; #[test] @@ -87,10 +127,10 @@ mod tests { let mut graph = BurnGraph::::default(); graph.register(WhereNode::new( - TensorType::new_bool("tensor1", 2), - TensorType::new_float("tensor2", 2), - TensorType::new_float("tensor3", 2), - TensorType::new_float("tensor4", 2), + Type::Tensor(TensorType::new_bool("tensor1", 2)), + Type::Tensor(TensorType::new_float("tensor2", 2)), + Type::Tensor(TensorType::new_float("tensor3", 2)), + Type::Tensor(TensorType::new_float("tensor4", 2)), )); graph.register_input_output( @@ -146,10 +186,10 @@ mod tests { let mut graph = BurnGraph::::default(); graph.register(WhereNode::new( - TensorType::new_bool("tensor1", 4), - TensorType::new_float("tensor2", 2), - TensorType::new_float("tensor3", 3), - TensorType::new_float("tensor4", 4), + Type::Tensor(TensorType::new_bool("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 2)), + Type::Tensor(TensorType::new_float("tensor3", 3)), + Type::Tensor(TensorType::new_float("tensor4", 4)), )); graph.register_input_output( @@ -201,4 +241,181 @@ mod tests { assert_tokens(graph.codegen(), expected); } + + #[test] + fn test_codegen_where_scalar_x() { + let mut graph = BurnGraph::::default(); + + graph.register(WhereNode::new( + Type::Tensor(TensorType::new_bool("tensor1", 2)), + Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float64)), + Type::Tensor(TensorType::new_float("tensor3", 2)), + Type::Tensor(TensorType::new_float("tensor4", 2)), + )); + + graph.register_input_output( + vec![ + "tensor1".to_string(), + "scalar2".to_string(), + "tensor3".to_string(), + ], + vec!["tensor4".to_string()], + ); + + let expected = quote! { + use burn::tensor::Bool; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward( + &self, + tensor1: Tensor, + scalar2: f64, + tensor3: Tensor + ) -> Tensor { + let tensor4 = tensor3.mask_fill(tensor1, scalar2); + + tensor4 + } + } + }; + + assert_tokens(graph.codegen(), expected); + } + + #[test] + fn test_codegen_where_scalar_y() { + let mut graph = BurnGraph::::default(); + + graph.register(WhereNode::new( + Type::Tensor(TensorType::new_bool("tensor1", 2)), + Type::Tensor(TensorType::new_float("tensor2", 2)), + Type::Scalar(ScalarType::new("scalar3", ScalarKind::Float64)), + Type::Tensor(TensorType::new_float("tensor4", 2)), + )); + + graph.register_input_output( + vec![ + "tensor1".to_string(), + "tensor2".to_string(), + "scalar3".to_string(), + ], + vec!["tensor4".to_string()], + ); + + let expected = quote! { + use burn::tensor::Bool; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward( + &self, + tensor1: Tensor, + tensor2: Tensor, + scalar3: f64 + ) -> Tensor { + let tensor4 = Tensor::::full([1, 1], scalar3, &*self.device) + .mask_where(tensor1, tensor2); + + tensor4 + } + } + }; + + assert_tokens(graph.codegen(), expected); + } + + #[test] + fn test_codegen_where_all_scalar() { + let mut graph = BurnGraph::::default(); + + graph.register(WhereNode::new( + Type::Scalar(ScalarType::new("scalar1", ScalarKind::Bool)), + Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float64)), + Type::Scalar(ScalarType::new("scalar3", ScalarKind::Float64)), + Type::Scalar(ScalarType::new("scalar4", ScalarKind::Float64)), + )); + + graph.register_input_output( + vec![ + "scalar1".to_string(), + "scalar2".to_string(), + "scalar3".to_string(), + ], + vec!["scalar4".to_string()], + ); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward( + &self, + scalar1: bool, + scalar2: f64, + scalar3: f64 + ) -> f64 { + let scalar4: f64 = if scalar1 { scalar2 } else { scalar3 }; + + scalar4 + } + } + }; + + assert_tokens(graph.codegen(), expected); + } } diff --git a/crates/burn-import/src/burn/ty.rs b/crates/burn-import/src/burn/ty.rs index 2612b38e2..1b459ef67 100644 --- a/crates/burn-import/src/burn/ty.rs +++ b/crates/burn-import/src/burn/ty.rs @@ -79,6 +79,27 @@ impl Type { Type::Other(other) => other.ty(), } } + pub fn as_tensor(&self) -> &TensorType { + if let Self::Tensor(t) = self { + t + } else { + panic!("Called Type::as_tensor on {self:?}!"); + } + } + pub fn as_scalar(&self) -> &ScalarType { + if let Self::Scalar(s) = self { + s + } else { + panic!("Called Type::as_scalar on {self:?}!"); + } + } + pub fn as_shape(&self) -> &ShapeType { + if let Self::Shape(s) = self { + s + } else { + panic!("Called Type::as_shape on {self:?}!"); + } + } } impl ScalarType { @@ -100,6 +121,28 @@ impl ScalarType { ScalarKind::Bool => quote! { bool }, } } + + /// Helper for Ops that need to process a Scalar as a tensor on device + /// + /// Uploads the Scalar to the device as a full tensor using the given shape definition + pub fn to_full_tensor(&self, shape: &[usize]) -> TokenStream { + let name = &self.name; + let shape_tokens = shape + .iter() + .map(ToTokens::to_tokens) + .map(|s| quote! {#s, }) + .collect::(); + let rank = shape.len(); + let rank_tokens = rank.to_tokens(); + let tensor_kind = match self.kind { + ScalarKind::Int32 | ScalarKind::Int64 => quote! { burn::tensor::Int }, + ScalarKind::Float32 | ScalarKind::Float64 => quote! { burn::tensor::Float }, + ScalarKind::Bool => quote! { burn::tensor::Bool }, + }; + quote! { + Tensor::::full([#shape_tokens], #name, &*self.device) + } + } } impl ShapeType { @@ -116,6 +159,17 @@ impl ShapeType { let dim = self.dim.to_tokens(); quote! { [usize; #dim] } } + + /// Helper for Ops that need to process a shape as a tensor on device + /// + /// Uploads the Shape to the device as a rank 1 Int tensor + pub fn to_tensor(&self) -> TokenStream { + let shape_name = &self.name; + // To copy just the values from the shape value without moving it + // (which could lead to ownership problems if the same Shape is used multiple times) + // borrow the array as a slice and use that to create the Tensor: + quote! { Tensor::::from_data(&#shape_name as &[_], &*self.device) } + } } impl TensorType { diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index a77211476..5c4f34078 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -747,10 +747,10 @@ impl ParsedOnnxGraph { } fn where_conversion(node: Node) -> WhereNode { - let condition = TensorType::from(node.inputs.first().unwrap()); - let x = TensorType::from(node.inputs.get(1).unwrap()); - let y = TensorType::from(node.inputs.get(2).unwrap()); - let output = TensorType::from(node.outputs.first().unwrap()); + let condition = Type::from(node.inputs.first().unwrap()); + let x = Type::from(node.inputs.get(1).unwrap()); + let y = Type::from(node.inputs.get(2).unwrap()); + let output = Type::from(node.outputs.first().unwrap()); WhereNode::new(condition, x, y, output) } diff --git a/crates/onnx-ir/src/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs index ada2516c0..6275b36ed 100644 --- a/crates/onnx-ir/src/dim_inference.rs +++ b/crates/onnx-ir/src/dim_inference.rs @@ -766,17 +766,31 @@ fn reduce_sum_update_outputs(node: &mut Node) { } fn where_update_outputs(node: &mut Node) { - match (&node.inputs[0].ty, &node.inputs[1].ty, &node.inputs[2].ty) { - (ArgType::Tensor(condition), ArgType::Tensor(x), ArgType::Tensor(y)) => { - // With broadcasting support, output dim has to be computed based on the inputs - node.outputs[0].ty = ArgType::Tensor(TensorType { - elem_type: x.elem_type.clone(), - dim: max(condition.dim, max(x.dim, y.dim)), - ..Default::default() - }); - set_broadcasting_output_shape(node); - } - _ => panic!("Only tensor input is valid"), + let condition = &node.inputs[0].ty; + let x = &node.inputs[1].ty; + let y = &node.inputs[2].ty; + let elem_type = x.elem_type().clone(); + assert_eq!( + *condition.elem_type(), + ElementType::Bool, + "Where condition must be boolean!" + ); + assert_eq!( + elem_type, + *y.elem_type(), + "Where x and y have different element types!" + ); + + let output_rank = max(condition.rank(), max(x.rank(), y.rank())); + if output_rank == 0 { + node.outputs[0].ty = ArgType::Scalar(elem_type); + } else { + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type, + dim: output_rank, + ..Default::default() + }); + set_broadcasting_output_shape(node); } } diff --git a/crates/onnx-ir/src/ir.rs b/crates/onnx-ir/src/ir.rs index b0abed195..937730e28 100644 --- a/crates/onnx-ir/src/ir.rs +++ b/crates/onnx-ir/src/ir.rs @@ -92,7 +92,7 @@ pub enum AttributeValue { pub type Attributes = HashMap; /// The type of an element. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum ElementType { Float32, Float64, @@ -135,6 +135,24 @@ impl ArgType { pub fn is_tensor(&self) -> bool { matches!(self, Self::Tensor(_)) } + + /// returns the rank (dimension) of the Arg + pub fn rank(&self) -> usize { + match self { + ArgType::Scalar(_) => 0, + ArgType::Shape(_) => 1, + ArgType::Tensor(t) => t.dim, + } + } + + /// returns the element type of the Arg + pub fn elem_type(&self) -> &ElementType { + match self { + ArgType::Scalar(s) => s, + ArgType::Shape(_) => panic!("ArgType::Shape has no ElementType"), + ArgType::Tensor(t) => &t.elem_type, + } + } } impl Argument {