Improve ONNX import tensor shape tracking (#2213)

- Calculate the result of broadcasting in dim_inference
- Keep Shape info when converting from Argument to TensorType
- Remove a few sources of dim = 0 Tensors; create Scalars instead
- Clean up dim_inference a bit
Author: Adrian Müller, 2024-08-29 20:06:30 +02:00 (committed by GitHub)
commit e8ea9e27c2, parent 2f4c5ac0a1
5 changed files with 223 additions and 137 deletions

View File

@@ -143,6 +143,10 @@ impl TensorType {
             );
         }
         let formatted_name = Self::format_name(name.as_ref());
+        assert_ne!(
+            dim, 0,
+            "Trying to create TensorType with dim = 0 - should be a Scalar instead!"
+        );
         Self {
             name: Ident::new(&formatted_name, Span::call_site()),
             dim,
@@ -151,15 +155,39 @@ impl TensorType {
         }
     }

     pub fn new_float<S: AsRef<str>>(name: S, dim: usize) -> Self {
-        Self::new(name, dim, TensorKind::Float, None)
+        Self::new_float_with_shape(name, dim, None)
     }

+    pub fn new_float_with_shape<S: AsRef<str>>(
+        name: S,
+        dim: usize,
+        shape: Option<Vec<usize>>,
+    ) -> Self {
+        Self::new(name, dim, TensorKind::Float, shape)
+    }
+
     pub fn new_int<S: AsRef<str>>(name: S, dim: usize) -> Self {
-        Self::new(name, dim, TensorKind::Int, None)
+        Self::new_int_with_shape(name, dim, None)
     }

+    pub fn new_int_with_shape<S: AsRef<str>>(
+        name: S,
+        dim: usize,
+        shape: Option<Vec<usize>>,
+    ) -> Self {
+        Self::new(name, dim, TensorKind::Int, shape)
+    }
+
     pub fn new_bool<S: AsRef<str>>(name: S, dim: usize) -> Self {
-        Self::new(name, dim, TensorKind::Bool, None)
+        Self::new_bool_with_shape(name, dim, None)
     }

+    pub fn new_bool_with_shape<S: AsRef<str>>(
+        name: S,
+        dim: usize,
+        shape: Option<Vec<usize>>,
+    ) -> Self {
+        Self::new(name, dim, TensorKind::Bool, shape)
+    }
+
     pub fn ty(&self) -> TokenStream {
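For illustration, here is a self-contained sketch of the constructor surface this hunk creates, with simplified stand-in types (this mock is not burn-import's real TensorType; fields and signatures are assumed for the example):

    // Simplified stand-in for burn-import's TensorType (illustrative only).
    #[derive(Debug)]
    struct TensorType {
        name: String,
        dim: usize,
        shape: Option<Vec<usize>>,
    }

    impl TensorType {
        fn new(name: &str, dim: usize, shape: Option<Vec<usize>>) -> Self {
            // Mirrors the new guard: rank-0 values should be Scalars, not Tensors.
            assert_ne!(dim, 0, "dim = 0 should be a Scalar instead!");
            Self { name: name.into(), dim, shape }
        }

        // The old entry point forwards with shape = None...
        fn new_float(name: &str, dim: usize) -> Self {
            Self::new_float_with_shape(name, dim, None)
        }

        // ...while the new one lets callers keep static shape info.
        fn new_float_with_shape(name: &str, dim: usize, shape: Option<Vec<usize>>) -> Self {
            Self::new(name, dim, shape)
        }
    }

    fn main() {
        let t = TensorType::new_float_with_shape("output", 3, Some(vec![1, 3, 224]));
        println!("{} rank {} shape {:?}", t.name, t.dim, t.shape);
        let _ = TensorType::new_float("hidden", 2); // shape unknown at import time
    }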

View File

@@ -424,12 +424,7 @@ impl ParsedOnnxGraph {
     fn random_uniform_conversion(node: Node) -> RandomUniformNode {
         let output = node.outputs.first().unwrap();
-        // cannot use output.to_tensor_type() here, since it drops the shape info...
-        let output_type = if let Type::Tensor(t) = Type::from(output) {
-            t
-        } else {
-            panic!("RandomUniform output type is no Tensor.");
-        };
+        let output_type = TensorType::from(output);

         let high = node
             .attrs
@@ -451,12 +446,7 @@ impl ParsedOnnxGraph {
     fn random_normal_conversion(node: Node) -> RandomNormalNode {
         let output = node.outputs.first().unwrap();
-        // cannot use output.to_tensor_type() here, since it drops the shape info...
-        let output_type = if let Type::Tensor(t) = Type::from(output) {
-            t
-        } else {
-            panic!("RandomNormal output type is no Tensor.");
-        };
+        let output_type = TensorType::from(output);

         let mean = node
             .attrs
@@ -480,11 +470,12 @@ impl ParsedOnnxGraph {
         // Additional types needed for ConstantOfShape:
         use crate::burn::node::constant_of_shape::ConstantValue;

-        let input = node
-            .inputs
-            .first()
-            .expect("ConstantOfShape requires an input tensor");
-        let output = node.outputs.first().unwrap();
+        let input = Type::from(
+            node.inputs
+                .first()
+                .expect("ConstantOfShape requires an input tensor"),
+        );
+        let output = Type::from(node.outputs.first().unwrap());

         // The value of the output elements. Should be a one-element tensor.
         // If not specified, it defaults to a tensor of value 0 and datatype float32.
@@ -504,7 +495,7 @@ impl ParsedOnnxGraph {
             })
             .unwrap_or(ConstantValue::Float32(0.0f32));

-        ConstantOfShapeNode::new(Type::from(input), Type::from(output), value)
+        ConstantOfShapeNode::new(input, output, value)
     }

     fn add_conversion(node: Node) -> BinaryNode {
@@ -1082,18 +1073,27 @@ impl ParsedOnnxGraph {
         UnaryNode::exp(input, output)
     }

-    fn expand_conversion(node: Node) -> ExpandNode {
+    fn expand_conversion(mut node: Node) -> ExpandNode {
         let input = TensorType::from(node.inputs.first().unwrap());
-        let mut output = TensorType::from(node.outputs.first().unwrap());
         let shape = expand_config(&node);

-        output.dim = match &shape {
+        // dim_inference left the dim at zero, so it needs to be filled before converting to TensorType:
+        assert_eq!(
+            node.outputs.len(),
+            1,
+            "ExpandNode must have exactly 1 output!"
+        );
+        let mut output_arg = node.outputs.pop().unwrap();
+        if let ArgType::Tensor(output_arg_tensor) = &mut output_arg.ty {
+            output_arg_tensor.dim = match &shape {
                 ExpandShape::Static(s) => s.len(),
                 ExpandShape::Runtime(Type::Shape(s)) => s.dim,
+                ExpandShape::Runtime(Type::Tensor(t)) => t.shape.as_ref().unwrap()[0],
                 _ => panic!("Invalid ExpandShape {shape:?}!"),
             };
+        }

-        ExpandNode::new(input, output, shape)
+        ExpandNode::new(input, TensorType::from(&output_arg), shape)
     }

     fn neg_conversion(node: Node) -> UnaryNode {
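A hedged sketch of the rank fix-up above, with simplified stand-ins (the ExpandShape variants here are assumptions for illustration, not the crate's exact types):

    // Toy model: Expand's output rank comes from the shape config, since
    // dim_inference leaves it at 0 when the target shape is only known later.
    #[derive(Debug)]
    enum ExpandShape {
        Static(Vec<usize>), // target shape known at import time
        RuntimeRank(usize), // only the rank is known (from a Shape or 1-D tensor)
    }

    fn expand_output_rank(shape: &ExpandShape) -> usize {
        match shape {
            ExpandShape::Static(s) => s.len(),
            ExpandShape::RuntimeRank(rank) => *rank,
        }
    }

    fn main() {
        assert_eq!(expand_output_rank(&ExpandShape::Static(vec![2, 3, 4])), 3);
        assert_eq!(expand_output_rank(&ExpandShape::RuntimeRank(2)), 2);
    }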
@@ -1236,18 +1236,21 @@ impl From<&OnnxArgument> for TensorType {
             ArgType::Tensor(OnnxTensorType {
                 elem_type: ElementType::Float16 | ElementType::Float32 | ElementType::Float64,
                 dim,
+                shape,
                 ..
-            }) => TensorType::new_float(arg.name.clone(), *dim),
+            }) => TensorType::new_float_with_shape(arg.name.clone(), *dim, shape.clone()),
             ArgType::Tensor(OnnxTensorType {
                 elem_type: ElementType::Int32 | ElementType::Int64,
                 dim,
+                shape,
                 ..
-            }) => TensorType::new_int(arg.name.clone(), *dim),
+            }) => TensorType::new_int_with_shape(arg.name.clone(), *dim, shape.clone()),
             ArgType::Tensor(OnnxTensorType {
                 elem_type: ElementType::Bool,
                 dim,
+                shape,
                 ..
-            }) => TensorType::new_bool(arg.name.clone(), *dim),
+            }) => TensorType::new_bool_with_shape(arg.name.clone(), *dim, shape.clone()),
             _ => panic!("Can't transform {:?} to tensor.", arg.ty),
         }
     }

View File

@ -1,6 +1,7 @@
use core::cmp::max;
use core::panic;
use log::debug;
use protobuf::Enum;
use crate::{
@@ -12,7 +13,7 @@ use crate::{
 /// Infer the dimension of each output tensor and update them.
 pub fn dim_inference(node: &mut Node) {
     match node.node_type {
-        NodeType::Add => same_as_input(node),
+        NodeType::Add => same_as_input_broadcast(node),
         NodeType::ArgMax => argmax_update_outputs(node),
         NodeType::AveragePool1d => same_as_input(node),
         NodeType::AveragePool2d => same_as_input(node),
@@ -21,12 +22,13 @@ pub fn dim_inference(node: &mut Node) {
         NodeType::Clip => same_as_input(node),
         NodeType::Concat => concat_update_outputs(node),
         NodeType::Constant => constant_update_outputs(node),
+        NodeType::ConstantOfShape => constant_of_shape_update_output(node),
         NodeType::Conv1d => conv1d_update_outputs(node),
         NodeType::Conv2d => conv2d_update_outputs(node),
         NodeType::Cos => same_as_input(node),
-        NodeType::Div => same_as_input(node),
+        NodeType::Div => same_as_input_broadcast(node),
         NodeType::Dropout => same_as_input(node),
-        NodeType::Equal => equal_update_outputs(node),
+        NodeType::Equal => elementwise_comparsion_outputs(node),
         NodeType::Erf => same_as_input(node),
         NodeType::Exp => same_as_input(node),
         NodeType::Expand => expand_update_outputs(node),
@@ -34,26 +36,31 @@ pub fn dim_inference(node: &mut Node) {
         NodeType::Gelu => same_as_input(node),
         NodeType::Gather => gather_update_outputs(node),
         NodeType::GatherElements => same_as_input(node),
+        NodeType::Greater => elementwise_comparsion_outputs(node),
+        NodeType::GreaterOrEqual => elementwise_comparsion_outputs(node),
         NodeType::HardSigmoid => same_as_input(node),
         NodeType::GlobalAveragePool => same_as_input(node),
         NodeType::ConvTranspose2d => conv_transpose2d_update_outputs(node),
         NodeType::LayerNormalization => same_as_input(node),
+        NodeType::LeakyRelu => same_as_input(node),
+        NodeType::Less => elementwise_comparsion_outputs(node),
+        NodeType::LessOrEqual => elementwise_comparsion_outputs(node),
         NodeType::Linear => linear_update_outputs(node),
         NodeType::Log => same_as_input(node),
         NodeType::LogSoftmax => same_as_input(node),
         NodeType::MatMul => matmul_update_outputs(node),
-        NodeType::Min => same_as_input(node),
-        NodeType::Max => same_as_input(node),
+        NodeType::Max => same_as_input_broadcast(node),
         NodeType::MaxPool1d => same_as_input(node),
         NodeType::MaxPool2d => same_as_input(node),
+        NodeType::Min => same_as_input_broadcast(node),
         NodeType::Mul => same_as_input(node),
         NodeType::Neg => same_as_input(node),
         NodeType::Not => same_as_input(node),
         NodeType::Pad => same_as_input(node),
-        NodeType::Greater => greater_update_outputs(node),
-        NodeType::GreaterOrEqual => greater_or_equal_update_outputs(node),
-        NodeType::Less => less_update_outputs(node),
-        NodeType::LessOrEqual => less_or_equal_update_outputs(node),
+        NodeType::PRelu => same_as_input_broadcast(node),
+        NodeType::Pow => same_as_input_broadcast(node),
+        NodeType::RandomNormal => random_update_output(node),
+        NodeType::RandomUniform => random_update_output(node),
         NodeType::Range => range_update_outputs(node),
         NodeType::Reciprocal => same_as_input(node),
         NodeType::ReduceMax => reduce_max_update_outputs(node),
@@ -70,20 +77,14 @@ pub fn dim_inference(node: &mut Node) {
         NodeType::Sin => same_as_input(node),
         NodeType::Slice => same_as_input(node),
         NodeType::Softmax => same_as_input(node),
+        NodeType::Squeeze => squeeze_update_output(node),
         NodeType::Sqrt => same_as_input(node),
-        NodeType::Sub => sub_update_outputs(node),
-        NodeType::Sum => same_as_input(node),
+        NodeType::Sub => same_as_input_broadcast(node),
+        NodeType::Sum => same_as_input_broadcast(node),
         NodeType::Tanh => same_as_input(node),
         NodeType::Transpose => same_as_input(node),
         NodeType::Unsqueeze => unsqueeze_update_output(node),
-        NodeType::Pow => same_as_input(node),
-        NodeType::LeakyRelu => same_as_input(node),
-        NodeType::PRelu => same_as_input(node),
         NodeType::Where => where_update_outputs(node),
-        NodeType::Squeeze => squeeze_update_output(node),
-        NodeType::RandomUniform => random_update_output(node),
-        NodeType::RandomNormal => random_update_output(node),
-        NodeType::ConstantOfShape => constant_of_shape_update_output(node),
         // Intentionally leaving outputs unchanged but issuing a warning so the IR file can still be generated.
         _ => temporary_pass_through_stub(node),
     }
@@ -108,6 +109,9 @@ fn constant_update_outputs(node: &mut Node) {
     node.outputs[0].ty = match matched_value {
         Some(value) => match &value {
             // The value is stored in an attribute
+            AttributeValue::Tensor(tensor) if tensor.dim == 0 => {
+                ArgType::Scalar(tensor.elem_type.clone())
+            }
             AttributeValue::Tensor(tensor) => ArgType::Tensor(TensorType {
                 elem_type: tensor.elem_type.clone(),
                 dim: tensor.dim,
@@ -319,54 +323,6 @@ fn reshape_update_outputs(node: &mut Node) {
     }
 }

-fn greater_update_outputs(node: &mut Node) {
-    match &node.inputs[0].ty {
-        ArgType::Tensor(tensor) => {
-            node.outputs[0].ty = ArgType::Tensor(TensorType {
-                elem_type: ElementType::Bool,
-                ..tensor.clone()
-            });
-        }
-        _ => panic!("Only tensor input is valid"),
-    }
-}
-
-fn less_update_outputs(node: &mut Node) {
-    match &node.inputs[0].ty {
-        ArgType::Tensor(tensor) => {
-            node.outputs[0].ty = ArgType::Tensor(TensorType {
-                elem_type: ElementType::Bool,
-                ..tensor.clone()
-            });
-        }
-        _ => panic!("Only tensor input is valid"),
-    }
-}
-
-fn greater_or_equal_update_outputs(node: &mut Node) {
-    match &node.inputs[0].ty {
-        ArgType::Tensor(tensor) => {
-            node.outputs[0].ty = ArgType::Tensor(TensorType {
-                elem_type: ElementType::Bool,
-                ..tensor.clone()
-            });
-        }
-        _ => panic!("Only tensor input is valid"),
-    }
-}
-
-fn less_or_equal_update_outputs(node: &mut Node) {
-    match &node.inputs[0].ty {
-        ArgType::Tensor(tensor) => {
-            node.outputs[0].ty = ArgType::Tensor(TensorType {
-                elem_type: ElementType::Bool,
-                ..tensor.clone()
-            });
-        }
-        _ => panic!("Only tensor input is valid"),
-    }
-}
-
 fn reduce_mean_update_outputs(node: &mut Node) {
     if node.inputs.len() != 1 {
         panic!("Mean: multiple inputs are not supported");
@@ -455,18 +411,27 @@ fn squeeze_update_output(node: &mut Node) {
     });
 }

-fn sub_update_outputs(node: &mut Node) {
-    node.outputs[0].ty = match (node.inputs[0].ty.clone(), node.inputs[1].ty.clone()) {
-        (ArgType::Scalar(_lhs), ArgType::Scalar(rhs)) => ArgType::Scalar(rhs),
-        (ArgType::Scalar(_lhs), ArgType::Tensor(rhs)) => ArgType::Tensor(rhs),
-        (ArgType::Tensor(lhs), ArgType::Scalar(_rhs)) => ArgType::Tensor(lhs),
-        // Support broadcasting for lhs/rhs
-        (ArgType::Tensor(lhs), ArgType::Tensor(rhs)) if lhs.dim > rhs.dim => ArgType::Tensor(lhs),
-        (ArgType::Tensor(lhs), ArgType::Tensor(rhs)) if lhs.dim <= rhs.dim => ArgType::Tensor(rhs),
-        _ => {
-            panic!("Only tensor-scalar inputs are valid.");
-        }
-    };
-}
+/// Updates the output type for operations that take more than one Tensor/Scalar of the same type
+/// and return that same type, while supporting broadcasting.
+///
+/// This covers most of the elementwise math operations, i.e. add, sub, mul, max, etc.
+fn same_as_input_broadcast(node: &mut Node) {
+    if node.inputs.iter().all(|input| input.ty.is_scalar()) {
+        // If all inputs are scalar, the output is a scalar as well
+        node.outputs[0].ty = node.inputs[0].ty.clone();
+    } else {
+        // Otherwise, if any input is a Tensor, use its datatype;
+        // or the inputs consist solely of Scalars and Shapes,
+        // which should result in another Shape
+        node.outputs[0].ty = node
+            .inputs
+            .iter()
+            .find(|input| input.ty.is_tensor())
+            .map(|input| input.ty.clone())
+            .unwrap_or_else(|| ArgType::Shape(0)); // Shape dim will be set by the broadcast calculation
+        set_broadcasting_output_shape(node);
+    }
+}

 /// Update the output tensor dimension based on the "axes" attribute or the second input
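As a self-contained illustration of the rule above (the Ty enum is a simplified stand-in, not onnx-ir's ArgType):

    // Toy model of the type-propagation rule in same_as_input_broadcast.
    #[derive(Clone, Debug, PartialEq)]
    enum Ty {
        Scalar,
        Shape(usize),
        Tensor { dim: usize },
    }

    // All-scalar inputs stay scalar; otherwise the first tensor input wins;
    // failing that (only Scalars and Shapes), the result is a Shape.
    fn propagate(inputs: &[Ty]) -> Ty {
        if inputs.iter().all(|t| matches!(t, Ty::Scalar)) {
            Ty::Scalar
        } else {
            inputs
                .iter()
                .find(|t| matches!(t, Ty::Tensor { .. }))
                .cloned()
                .unwrap_or(Ty::Shape(0)) // dim filled in later by the broadcast pass
        }
    }

    fn main() {
        assert_eq!(propagate(&[Ty::Scalar, Ty::Scalar]), Ty::Scalar);
        assert_eq!(propagate(&[Ty::Scalar, Ty::Tensor { dim: 2 }]), Ty::Tensor { dim: 2 });
        assert_eq!(propagate(&[Ty::Shape(3), Ty::Scalar]), Ty::Shape(0));
    }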
@@ -519,21 +484,34 @@ fn temporary_pass_through_stub(node: &mut Node) {
     node.outputs[0].ty = node.inputs[0].ty.clone();
 }

-fn equal_update_outputs(node: &mut Node) {
-    let input1_type = node.inputs[0].ty.clone();
+/// Sets the output for binary operators resulting in a boolean output,
+/// i.e., comparison operators like Equal, Greater, Less, etc.
+///
+/// Support for broadcasting is assumed.
+fn elementwise_comparsion_outputs(node: &mut Node) {
+    let input1_type = &node.inputs[0].ty;
+    let input2_type = &node.inputs[1].ty;

-    match input1_type {
-        ArgType::Tensor(tensor) => {
-            // if the input is a tensor, the output is a tensor of bool
+    match (input1_type, input2_type) {
+        (ArgType::Tensor(tensor), _) | (_, ArgType::Tensor(tensor)) => {
+            // if one of the inputs is a tensor, the output is a tensor of bool
+            assert_ne!(
+                tensor.dim, 0,
+                "Got a rank 0 Tensor that should have been a Scalar!"
+            );
             node.outputs[0].ty = ArgType::Tensor(TensorType {
                 elem_type: ElementType::Bool,
-                ..tensor
+                ..tensor.clone()
             });
+            set_broadcasting_output_shape(node);
         }
-        ArgType::Scalar(_) => {
+        (ArgType::Scalar(_), ArgType::Scalar(_)) => {
+            // if both inputs are scalars, the result is a scalar bool
             node.outputs[0].ty = ArgType::Scalar(ElementType::Bool);
         }
-        _ => panic!("Only tensor input is valid"),
+        _ => {
+            panic!("Invalid input types for comparison op: {input1_type:?}, {input2_type:?}")
+        }
     }
 }
@@ -788,11 +766,7 @@ fn reduce_sum_update_outputs(node: &mut Node) {
 }

 fn where_update_outputs(node: &mut Node) {
-    match (
-        node.inputs[0].ty.clone(),
-        node.inputs[1].ty.clone(),
-        node.inputs[2].ty.clone(),
-    ) {
+    match (&node.inputs[0].ty, &node.inputs[1].ty, &node.inputs[2].ty) {
         (ArgType::Tensor(condition), ArgType::Tensor(x), ArgType::Tensor(y)) => {
             // With broadcasting support, output dim has to be computed based on the inputs
             node.outputs[0].ty = ArgType::Tensor(TensorType {
@@ -800,6 +774,7 @@ fn where_update_outputs(node: &mut Node) {
                 dim: max(condition.dim, max(x.dim, y.dim)),
                 ..Default::default()
             });
+            set_broadcasting_output_shape(node);
         }
         _ => panic!("Only tensor input is valid"),
     }
@@ -841,3 +816,68 @@ fn gather_update_outputs(node: &mut Node) {
         ty => panic!("Only tensor/shape input is valid but received: {:?}", ty),
     }
 }
+
+/// If all input shapes are known,
+/// calculates the rank and shape of the output tensor for operators supporting
+/// [broadcasting](https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md)
+fn set_broadcasting_output_shape(node: &mut Node) {
+    let mut reverse_out_shape: Vec<usize> = vec![1];
+    for (idx, input_type) in node.inputs.iter().enumerate() {
+        match &input_type.ty {
+            ArgType::Tensor(t) => {
+                if let Some(shape) = &t.shape {
+                    for (rev_idx, dimension) in shape.iter().rev().enumerate() {
+                        if let Some(current_out_dim) = reverse_out_shape.get_mut(rev_idx) {
+                            if *dimension == 1 {
+                                // this input dimension is 1, so it broadcasts to whatever the output dimension already is
+                                continue;
+                            }
+                            if current_out_dim != dimension && *current_out_dim != 1 {
+                                panic!("Invalid shape for broadcasting - the {rev_idx}th-to-last dimension has conflicting values {current_out_dim} and {dimension} from different inputs");
+                            }
+                            *current_out_dim = *dimension;
+                        } else {
+                            reverse_out_shape.push(*dimension);
+                        }
+                    }
+                } else {
+                    debug!("Input {idx} has no known shape, cannot predict broadcast result shape");
+                    return;
+                }
+            }
+            ArgType::Scalar(_) => {
+                // reverse_out_shape already starts with [1]
+            }
+            ArgType::Shape(s) => {
+                // Shape is treated like a 1-D Tensor
+                let current_out_dim = &mut reverse_out_shape[0];
+                if *current_out_dim != 1 && *current_out_dim != *s {
+                    panic!("Invalid shape for broadcasting - the last position has conflicting values {current_out_dim} and {s} from different inputs");
+                }
+                *current_out_dim = *s;
+            }
+        }
+    }
+
+    // If we get to this point without returning, reverse_out_shape is final
+    let mut out_shape = reverse_out_shape;
+    out_shape.reverse(); // un-reverse it
+
+    match &mut node.outputs[0].ty {
+        ArgType::Tensor(t) => {
+            t.dim = out_shape.len();
+            t.shape = Some(out_shape);
+        }
+        ArgType::Scalar(_) => {
+            if out_shape.len() > 1 || out_shape[0] > 1 {
+                panic!("Output is a Scalar, but broadcasting results in tensor shape {out_shape:?}")
+            }
+        }
+        ArgType::Shape(s) => {
+            if out_shape.len() > 1 {
+                panic!("Output is a Shape, but broadcasting results in higher-rank tensor shape {out_shape:?}")
+            }
+            *s = out_shape[0];
+        }
+    }
+}
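To make the rule concrete, here is a minimal standalone sketch of the same multidirectional broadcasting computation, independent of the crate's types (the function name and the shapes in main are made up for illustration):

    // Walk the shapes right-to-left: a dim of 1 broadcasts, equal dims pass
    // through, anything else is a conflict; missing leading dims act like 1.
    fn broadcast_shape(shapes: &[Vec<usize>]) -> Result<Vec<usize>, String> {
        let rank = shapes.iter().map(Vec::len).max().unwrap_or(0);
        let mut out = vec![1usize; rank];
        for shape in shapes {
            for (i, &dim) in shape.iter().rev().enumerate() {
                let slot = &mut out[rank - 1 - i];
                match (*slot, dim) {
                    (_, 1) => {}           // this input broadcasts here
                    (1, d) => *slot = d,   // output dim was still unset
                    (s, d) if s == d => {} // dims agree
                    (s, d) => return Err(format!("conflicting dims {s} and {d}")),
                }
            }
        }
        Ok(out)
    }

    fn main() {
        // The classic NumPy-style example: [8, 1, 6, 1] x [7, 1, 5] -> [8, 7, 6, 5]
        assert_eq!(
            broadcast_shape(&[vec![8, 1, 6, 1], vec![7, 1, 5]]).unwrap(),
            vec![8, 7, 6, 5]
        );
        // Conflicting trailing dims: 3 vs 4
        assert!(broadcast_shape(&[vec![2, 3], vec![4]]).is_err());
    }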

View File

@@ -128,6 +128,15 @@ impl Default for ArgType {
     }
 }

+impl ArgType {
+    pub fn is_scalar(&self) -> bool {
+        matches!(self, Self::Scalar(_))
+    }
+    pub fn is_tensor(&self) -> bool {
+        matches!(self, Self::Tensor(_))
+    }
+}
+
 impl Argument {
     pub fn new(name: String) -> Self {
         Self {

View File

@@ -242,6 +242,11 @@ impl TryFrom<ValueInfoProto> for Argument {
         }
     };

+    let ty = if tensor_proto.shape.dim.is_empty() {
+        // tensor_proto describes a scalar
+        ArgType::Scalar(elem_type)
+    } else {
+        // tensor_proto describes a tensor
         let tensor_type = TensorType {
             dim: tensor_proto.shape.dim.len(),
             elem_type,
@@ -255,7 +260,8 @@ impl TryFrom<ValueInfoProto> for Argument {
         ),
     };

-    let ty = ArgType::Tensor(tensor_type);
+        ArgType::Tensor(tensor_type)
+    };

     Ok(Argument {
         ty,
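The net effect, as a hedged standalone sketch (mock types, not the crate's real ValueInfoProto handling):

    // An ONNX value with an empty dims list now imports as a Scalar
    // instead of a rank-0 Tensor (illustrative mock of the rule above).
    #[derive(Debug)]
    enum ArgType {
        Scalar(&'static str),
        Tensor { elem: &'static str, dim: usize },
    }

    fn arg_type_from_dims(elem: &'static str, dims: &[i64]) -> ArgType {
        if dims.is_empty() {
            ArgType::Scalar(elem)
        } else {
            ArgType::Tensor { elem, dim: dims.len() }
        }
    }

    fn main() {
        println!("{:?}", arg_type_from_dims("float32", &[]));        // Scalar("float32")
        println!("{:?}", arg_type_from_dims("float32", &[1, 3, 4])); // Tensor { elem: "float32", dim: 3 }
    }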