mirror of https://github.com/tracel-ai/burn.git
Improve ONNX import tensor shape tracking (#2213)
- Calculate result of broadcasting in dim_inference - keep Shape info when converting from Argument to TensorType - Remove a few sources of Dim = 0 Tensors, create Scalars instead - Clean up dim_inference a bit
This commit is contained in:
parent
2f4c5ac0a1
commit
e8ea9e27c2
|
@ -143,6 +143,10 @@ impl TensorType {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
let formatted_name = Self::format_name(name.as_ref());
|
let formatted_name = Self::format_name(name.as_ref());
|
||||||
|
assert_ne!(
|
||||||
|
dim, 0,
|
||||||
|
"Trying to create TensorType with dim = 0 - should be a Scalar instead!"
|
||||||
|
);
|
||||||
Self {
|
Self {
|
||||||
name: Ident::new(&formatted_name, Span::call_site()),
|
name: Ident::new(&formatted_name, Span::call_site()),
|
||||||
dim,
|
dim,
|
||||||
|
@ -151,15 +155,39 @@ impl TensorType {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pub fn new_float<S: AsRef<str>>(name: S, dim: usize) -> Self {
|
pub fn new_float<S: AsRef<str>>(name: S, dim: usize) -> Self {
|
||||||
Self::new(name, dim, TensorKind::Float, None)
|
Self::new_float_with_shape(name, dim, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_float_with_shape<S: AsRef<str>>(
|
||||||
|
name: S,
|
||||||
|
dim: usize,
|
||||||
|
shape: Option<Vec<usize>>,
|
||||||
|
) -> Self {
|
||||||
|
Self::new(name, dim, TensorKind::Float, shape)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_int<S: AsRef<str>>(name: S, dim: usize) -> Self {
|
pub fn new_int<S: AsRef<str>>(name: S, dim: usize) -> Self {
|
||||||
Self::new(name, dim, TensorKind::Int, None)
|
Self::new_int_with_shape(name, dim, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_int_with_shape<S: AsRef<str>>(
|
||||||
|
name: S,
|
||||||
|
dim: usize,
|
||||||
|
shape: Option<Vec<usize>>,
|
||||||
|
) -> Self {
|
||||||
|
Self::new(name, dim, TensorKind::Int, shape)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_bool<S: AsRef<str>>(name: S, dim: usize) -> Self {
|
pub fn new_bool<S: AsRef<str>>(name: S, dim: usize) -> Self {
|
||||||
Self::new(name, dim, TensorKind::Bool, None)
|
Self::new_bool_with_shape(name, dim, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_bool_with_shape<S: AsRef<str>>(
|
||||||
|
name: S,
|
||||||
|
dim: usize,
|
||||||
|
shape: Option<Vec<usize>>,
|
||||||
|
) -> Self {
|
||||||
|
Self::new(name, dim, TensorKind::Bool, shape)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn ty(&self) -> TokenStream {
|
pub fn ty(&self) -> TokenStream {
|
||||||
|
|
|
@ -424,12 +424,7 @@ impl ParsedOnnxGraph {
|
||||||
|
|
||||||
fn random_uniform_conversion(node: Node) -> RandomUniformNode {
|
fn random_uniform_conversion(node: Node) -> RandomUniformNode {
|
||||||
let output = node.outputs.first().unwrap();
|
let output = node.outputs.first().unwrap();
|
||||||
// cannot use output.to_tensor_type() here, since it drops the shape info...
|
let output_type = TensorType::from(output);
|
||||||
let output_type = if let Type::Tensor(t) = Type::from(output) {
|
|
||||||
t
|
|
||||||
} else {
|
|
||||||
panic!("RandomUniform output type is no Tensor.");
|
|
||||||
};
|
|
||||||
|
|
||||||
let high = node
|
let high = node
|
||||||
.attrs
|
.attrs
|
||||||
|
@ -451,12 +446,7 @@ impl ParsedOnnxGraph {
|
||||||
|
|
||||||
fn random_normal_conversion(node: Node) -> RandomNormalNode {
|
fn random_normal_conversion(node: Node) -> RandomNormalNode {
|
||||||
let output = node.outputs.first().unwrap();
|
let output = node.outputs.first().unwrap();
|
||||||
// cannot use output.to_tensor_type() here, since it drops the shape info...
|
let output_type = TensorType::from(output);
|
||||||
let output_type = if let Type::Tensor(t) = Type::from(output) {
|
|
||||||
t
|
|
||||||
} else {
|
|
||||||
panic!("RandomNormal output type is no Tensor.");
|
|
||||||
};
|
|
||||||
|
|
||||||
let mean = node
|
let mean = node
|
||||||
.attrs
|
.attrs
|
||||||
|
@ -480,11 +470,12 @@ impl ParsedOnnxGraph {
|
||||||
// Additional types needed for ConstantOfShape:
|
// Additional types needed for ConstantOfShape:
|
||||||
use crate::burn::node::constant_of_shape::ConstantValue;
|
use crate::burn::node::constant_of_shape::ConstantValue;
|
||||||
|
|
||||||
let input = node
|
let input = Type::from(
|
||||||
.inputs
|
node.inputs
|
||||||
.first()
|
.first()
|
||||||
.expect("ConstantOfShape requires an input tensor");
|
.expect("ConstantOfShape requires an input tensor"),
|
||||||
let output = node.outputs.first().unwrap();
|
);
|
||||||
|
let output = Type::from(node.outputs.first().unwrap());
|
||||||
|
|
||||||
// The value of the output elements.Should be a one-element tensor.
|
// The value of the output elements.Should be a one-element tensor.
|
||||||
// If not specified, it defaults to a tensor of value 0 and datatype float32
|
// If not specified, it defaults to a tensor of value 0 and datatype float32
|
||||||
|
@ -504,7 +495,7 @@ impl ParsedOnnxGraph {
|
||||||
})
|
})
|
||||||
.unwrap_or(ConstantValue::Float32(0.0f32));
|
.unwrap_or(ConstantValue::Float32(0.0f32));
|
||||||
|
|
||||||
ConstantOfShapeNode::new(Type::from(input), Type::from(output), value)
|
ConstantOfShapeNode::new(input, output, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_conversion(node: Node) -> BinaryNode {
|
fn add_conversion(node: Node) -> BinaryNode {
|
||||||
|
@ -1082,18 +1073,27 @@ impl ParsedOnnxGraph {
|
||||||
UnaryNode::exp(input, output)
|
UnaryNode::exp(input, output)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn expand_conversion(node: Node) -> ExpandNode {
|
fn expand_conversion(mut node: Node) -> ExpandNode {
|
||||||
let input = TensorType::from(node.inputs.first().unwrap());
|
let input = TensorType::from(node.inputs.first().unwrap());
|
||||||
let mut output = TensorType::from(node.outputs.first().unwrap());
|
|
||||||
let shape = expand_config(&node);
|
let shape = expand_config(&node);
|
||||||
output.dim = match &shape {
|
|
||||||
ExpandShape::Static(s) => s.len(),
|
|
||||||
ExpandShape::Runtime(Type::Shape(s)) => s.dim,
|
|
||||||
ExpandShape::Runtime(Type::Tensor(t)) => t.shape.as_ref().unwrap()[0],
|
|
||||||
_ => panic!("Invalid ExpandShape {shape:?}!"),
|
|
||||||
};
|
|
||||||
|
|
||||||
ExpandNode::new(input, output, shape)
|
// dim_inference left the dim at zero, so it needs to be filled before converting to TensorType:
|
||||||
|
assert_eq!(
|
||||||
|
node.outputs.len(),
|
||||||
|
1,
|
||||||
|
"ExpandNode must have exactly 1 output!"
|
||||||
|
);
|
||||||
|
let mut output_arg = node.outputs.pop().unwrap();
|
||||||
|
if let ArgType::Tensor(output_arg_tensor) = &mut output_arg.ty {
|
||||||
|
output_arg_tensor.dim = match &shape {
|
||||||
|
ExpandShape::Static(s) => s.len(),
|
||||||
|
ExpandShape::Runtime(Type::Shape(s)) => s.dim,
|
||||||
|
ExpandShape::Runtime(Type::Tensor(t)) => t.shape.as_ref().unwrap()[0],
|
||||||
|
_ => panic!("Invalid ExpandShape {shape:?}!"),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
ExpandNode::new(input, TensorType::from(&output_arg), shape)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn neg_conversion(node: Node) -> UnaryNode {
|
fn neg_conversion(node: Node) -> UnaryNode {
|
||||||
|
@ -1236,18 +1236,21 @@ impl From<&OnnxArgument> for TensorType {
|
||||||
ArgType::Tensor(OnnxTensorType {
|
ArgType::Tensor(OnnxTensorType {
|
||||||
elem_type: ElementType::Float16 | ElementType::Float32 | ElementType::Float64,
|
elem_type: ElementType::Float16 | ElementType::Float32 | ElementType::Float64,
|
||||||
dim,
|
dim,
|
||||||
|
shape,
|
||||||
..
|
..
|
||||||
}) => TensorType::new_float(arg.name.clone(), *dim),
|
}) => TensorType::new_float_with_shape(arg.name.clone(), *dim, shape.clone()),
|
||||||
ArgType::Tensor(OnnxTensorType {
|
ArgType::Tensor(OnnxTensorType {
|
||||||
elem_type: ElementType::Int32 | ElementType::Int64,
|
elem_type: ElementType::Int32 | ElementType::Int64,
|
||||||
dim,
|
dim,
|
||||||
|
shape,
|
||||||
..
|
..
|
||||||
}) => TensorType::new_int(arg.name.clone(), *dim),
|
}) => TensorType::new_int_with_shape(arg.name.clone(), *dim, shape.clone()),
|
||||||
ArgType::Tensor(OnnxTensorType {
|
ArgType::Tensor(OnnxTensorType {
|
||||||
elem_type: ElementType::Bool,
|
elem_type: ElementType::Bool,
|
||||||
dim,
|
dim,
|
||||||
|
shape,
|
||||||
..
|
..
|
||||||
}) => TensorType::new_bool(arg.name.clone(), *dim),
|
}) => TensorType::new_bool_with_shape(arg.name.clone(), *dim, shape.clone()),
|
||||||
_ => panic!("Can't transform {:?} to tensor.", arg.ty),
|
_ => panic!("Can't transform {:?} to tensor.", arg.ty),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
use core::cmp::max;
|
use core::cmp::max;
|
||||||
use core::panic;
|
use core::panic;
|
||||||
|
|
||||||
|
use log::debug;
|
||||||
use protobuf::Enum;
|
use protobuf::Enum;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -12,7 +13,7 @@ use crate::{
|
||||||
/// Infer the dimension of each output tensor and update them.
|
/// Infer the dimension of each output tensor and update them.
|
||||||
pub fn dim_inference(node: &mut Node) {
|
pub fn dim_inference(node: &mut Node) {
|
||||||
match node.node_type {
|
match node.node_type {
|
||||||
NodeType::Add => same_as_input(node),
|
NodeType::Add => same_as_input_broadcast(node),
|
||||||
NodeType::ArgMax => argmax_update_outputs(node),
|
NodeType::ArgMax => argmax_update_outputs(node),
|
||||||
NodeType::AveragePool1d => same_as_input(node),
|
NodeType::AveragePool1d => same_as_input(node),
|
||||||
NodeType::AveragePool2d => same_as_input(node),
|
NodeType::AveragePool2d => same_as_input(node),
|
||||||
|
@ -21,12 +22,13 @@ pub fn dim_inference(node: &mut Node) {
|
||||||
NodeType::Clip => same_as_input(node),
|
NodeType::Clip => same_as_input(node),
|
||||||
NodeType::Concat => concat_update_outputs(node),
|
NodeType::Concat => concat_update_outputs(node),
|
||||||
NodeType::Constant => constant_update_outputs(node),
|
NodeType::Constant => constant_update_outputs(node),
|
||||||
|
NodeType::ConstantOfShape => constant_of_shape_update_output(node),
|
||||||
NodeType::Conv1d => conv1d_update_outputs(node),
|
NodeType::Conv1d => conv1d_update_outputs(node),
|
||||||
NodeType::Conv2d => conv2d_update_outputs(node),
|
NodeType::Conv2d => conv2d_update_outputs(node),
|
||||||
NodeType::Cos => same_as_input(node),
|
NodeType::Cos => same_as_input(node),
|
||||||
NodeType::Div => same_as_input(node),
|
NodeType::Div => same_as_input_broadcast(node),
|
||||||
NodeType::Dropout => same_as_input(node),
|
NodeType::Dropout => same_as_input(node),
|
||||||
NodeType::Equal => equal_update_outputs(node),
|
NodeType::Equal => elementwise_comparsion_outputs(node),
|
||||||
NodeType::Erf => same_as_input(node),
|
NodeType::Erf => same_as_input(node),
|
||||||
NodeType::Exp => same_as_input(node),
|
NodeType::Exp => same_as_input(node),
|
||||||
NodeType::Expand => expand_update_outputs(node),
|
NodeType::Expand => expand_update_outputs(node),
|
||||||
|
@ -34,26 +36,31 @@ pub fn dim_inference(node: &mut Node) {
|
||||||
NodeType::Gelu => same_as_input(node),
|
NodeType::Gelu => same_as_input(node),
|
||||||
NodeType::Gather => gather_update_outputs(node),
|
NodeType::Gather => gather_update_outputs(node),
|
||||||
NodeType::GatherElements => same_as_input(node),
|
NodeType::GatherElements => same_as_input(node),
|
||||||
|
NodeType::Greater => elementwise_comparsion_outputs(node),
|
||||||
|
NodeType::GreaterOrEqual => elementwise_comparsion_outputs(node),
|
||||||
NodeType::HardSigmoid => same_as_input(node),
|
NodeType::HardSigmoid => same_as_input(node),
|
||||||
NodeType::GlobalAveragePool => same_as_input(node),
|
NodeType::GlobalAveragePool => same_as_input(node),
|
||||||
NodeType::ConvTranspose2d => conv_transpose2d_update_outputs(node),
|
NodeType::ConvTranspose2d => conv_transpose2d_update_outputs(node),
|
||||||
NodeType::LayerNormalization => same_as_input(node),
|
NodeType::LayerNormalization => same_as_input(node),
|
||||||
|
NodeType::LeakyRelu => same_as_input(node),
|
||||||
|
NodeType::Less => elementwise_comparsion_outputs(node),
|
||||||
|
NodeType::LessOrEqual => elementwise_comparsion_outputs(node),
|
||||||
NodeType::Linear => linear_update_outputs(node),
|
NodeType::Linear => linear_update_outputs(node),
|
||||||
NodeType::Log => same_as_input(node),
|
NodeType::Log => same_as_input(node),
|
||||||
NodeType::LogSoftmax => same_as_input(node),
|
NodeType::LogSoftmax => same_as_input(node),
|
||||||
NodeType::MatMul => matmul_update_outputs(node),
|
NodeType::MatMul => matmul_update_outputs(node),
|
||||||
NodeType::Min => same_as_input(node),
|
NodeType::Max => same_as_input_broadcast(node),
|
||||||
NodeType::Max => same_as_input(node),
|
|
||||||
NodeType::MaxPool1d => same_as_input(node),
|
NodeType::MaxPool1d => same_as_input(node),
|
||||||
NodeType::MaxPool2d => same_as_input(node),
|
NodeType::MaxPool2d => same_as_input(node),
|
||||||
|
NodeType::Min => same_as_input_broadcast(node),
|
||||||
NodeType::Mul => same_as_input(node),
|
NodeType::Mul => same_as_input(node),
|
||||||
NodeType::Neg => same_as_input(node),
|
NodeType::Neg => same_as_input(node),
|
||||||
NodeType::Not => same_as_input(node),
|
NodeType::Not => same_as_input(node),
|
||||||
NodeType::Pad => same_as_input(node),
|
NodeType::Pad => same_as_input(node),
|
||||||
NodeType::Greater => greater_update_outputs(node),
|
NodeType::PRelu => same_as_input_broadcast(node),
|
||||||
NodeType::GreaterOrEqual => greater_or_equal_update_outputs(node),
|
NodeType::Pow => same_as_input_broadcast(node),
|
||||||
NodeType::Less => less_update_outputs(node),
|
NodeType::RandomNormal => random_update_output(node),
|
||||||
NodeType::LessOrEqual => less_or_equal_update_outputs(node),
|
NodeType::RandomUniform => random_update_output(node),
|
||||||
NodeType::Range => range_update_outputs(node),
|
NodeType::Range => range_update_outputs(node),
|
||||||
NodeType::Reciprocal => same_as_input(node),
|
NodeType::Reciprocal => same_as_input(node),
|
||||||
NodeType::ReduceMax => reduce_max_update_outputs(node),
|
NodeType::ReduceMax => reduce_max_update_outputs(node),
|
||||||
|
@ -70,20 +77,14 @@ pub fn dim_inference(node: &mut Node) {
|
||||||
NodeType::Sin => same_as_input(node),
|
NodeType::Sin => same_as_input(node),
|
||||||
NodeType::Slice => same_as_input(node),
|
NodeType::Slice => same_as_input(node),
|
||||||
NodeType::Softmax => same_as_input(node),
|
NodeType::Softmax => same_as_input(node),
|
||||||
|
NodeType::Squeeze => squeeze_update_output(node),
|
||||||
NodeType::Sqrt => same_as_input(node),
|
NodeType::Sqrt => same_as_input(node),
|
||||||
NodeType::Sub => sub_update_outputs(node),
|
NodeType::Sub => same_as_input_broadcast(node),
|
||||||
NodeType::Sum => same_as_input(node),
|
NodeType::Sum => same_as_input_broadcast(node),
|
||||||
NodeType::Tanh => same_as_input(node),
|
NodeType::Tanh => same_as_input(node),
|
||||||
NodeType::Transpose => same_as_input(node),
|
NodeType::Transpose => same_as_input(node),
|
||||||
NodeType::Unsqueeze => unsqueeze_update_output(node),
|
NodeType::Unsqueeze => unsqueeze_update_output(node),
|
||||||
NodeType::Pow => same_as_input(node),
|
|
||||||
NodeType::LeakyRelu => same_as_input(node),
|
|
||||||
NodeType::PRelu => same_as_input(node),
|
|
||||||
NodeType::Where => where_update_outputs(node),
|
NodeType::Where => where_update_outputs(node),
|
||||||
NodeType::Squeeze => squeeze_update_output(node),
|
|
||||||
NodeType::RandomUniform => random_update_output(node),
|
|
||||||
NodeType::RandomNormal => random_update_output(node),
|
|
||||||
NodeType::ConstantOfShape => constant_of_shape_update_output(node),
|
|
||||||
// Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated.
|
// Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated.
|
||||||
_ => temporary_pass_through_stub(node),
|
_ => temporary_pass_through_stub(node),
|
||||||
}
|
}
|
||||||
|
@ -108,6 +109,9 @@ fn constant_update_outputs(node: &mut Node) {
|
||||||
node.outputs[0].ty = match matched_value {
|
node.outputs[0].ty = match matched_value {
|
||||||
Some(value) => match &value {
|
Some(value) => match &value {
|
||||||
// The value is stored in an attribute
|
// The value is stored in an attribute
|
||||||
|
AttributeValue::Tensor(tensor) if tensor.dim == 0 => {
|
||||||
|
ArgType::Scalar(tensor.elem_type.clone())
|
||||||
|
}
|
||||||
AttributeValue::Tensor(tensor) => ArgType::Tensor(TensorType {
|
AttributeValue::Tensor(tensor) => ArgType::Tensor(TensorType {
|
||||||
elem_type: tensor.elem_type.clone(),
|
elem_type: tensor.elem_type.clone(),
|
||||||
dim: tensor.dim,
|
dim: tensor.dim,
|
||||||
|
@ -319,54 +323,6 @@ fn reshape_update_outputs(node: &mut Node) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn greater_update_outputs(node: &mut Node) {
|
|
||||||
match &node.inputs[0].ty {
|
|
||||||
ArgType::Tensor(tensor) => {
|
|
||||||
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
|
||||||
elem_type: ElementType::Bool,
|
|
||||||
..tensor.clone()
|
|
||||||
});
|
|
||||||
}
|
|
||||||
_ => panic!("Only tensor input is valid"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn less_update_outputs(node: &mut Node) {
|
|
||||||
match &node.inputs[0].ty {
|
|
||||||
ArgType::Tensor(tensor) => {
|
|
||||||
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
|
||||||
elem_type: ElementType::Bool,
|
|
||||||
..tensor.clone()
|
|
||||||
});
|
|
||||||
}
|
|
||||||
_ => panic!("Only tensor input is valid"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn greater_or_equal_update_outputs(node: &mut Node) {
|
|
||||||
match &node.inputs[0].ty {
|
|
||||||
ArgType::Tensor(tensor) => {
|
|
||||||
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
|
||||||
elem_type: ElementType::Bool,
|
|
||||||
..tensor.clone()
|
|
||||||
});
|
|
||||||
}
|
|
||||||
_ => panic!("Only tensor input is valid"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn less_or_equal_update_outputs(node: &mut Node) {
|
|
||||||
match &node.inputs[0].ty {
|
|
||||||
ArgType::Tensor(tensor) => {
|
|
||||||
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
|
||||||
elem_type: ElementType::Bool,
|
|
||||||
..tensor.clone()
|
|
||||||
});
|
|
||||||
}
|
|
||||||
_ => panic!("Only tensor input is valid"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn reduce_mean_update_outputs(node: &mut Node) {
|
fn reduce_mean_update_outputs(node: &mut Node) {
|
||||||
if node.inputs.len() != 1 {
|
if node.inputs.len() != 1 {
|
||||||
panic!("Mean: multiple inputs are not supported");
|
panic!("Mean: multiple inputs are not supported");
|
||||||
|
@ -455,18 +411,27 @@ fn squeeze_update_output(node: &mut Node) {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sub_update_outputs(node: &mut Node) {
|
/// Updates the output type for operations that take more than one Tensor/Scalar of the a type
|
||||||
node.outputs[0].ty = match (node.inputs[0].ty.clone(), node.inputs[1].ty.clone()) {
|
/// and returns the same type while supporting broadcasting
|
||||||
(ArgType::Scalar(_lhs), ArgType::Scalar(rhs)) => ArgType::Scalar(rhs),
|
///
|
||||||
(ArgType::Scalar(_lhs), ArgType::Tensor(rhs)) => ArgType::Tensor(rhs),
|
/// This is mostly the elementwise math operations, i.e. add, sub, mul, max etc.
|
||||||
(ArgType::Tensor(lhs), ArgType::Scalar(_rhs)) => ArgType::Tensor(lhs),
|
fn same_as_input_broadcast(node: &mut Node) {
|
||||||
// Support broadcasting for lhs/rhs
|
if node.inputs.iter().all(|input| input.ty.is_scalar()) {
|
||||||
(ArgType::Tensor(lhs), ArgType::Tensor(rhs)) if lhs.dim > rhs.dim => ArgType::Tensor(lhs),
|
// If all inputs are scalar, the output is a scalar as well
|
||||||
(ArgType::Tensor(lhs), ArgType::Tensor(rhs)) if lhs.dim <= rhs.dim => ArgType::Tensor(rhs),
|
node.outputs[0].ty = node.inputs[0].ty.clone();
|
||||||
_ => {
|
} else {
|
||||||
panic!("Only tensor-scalar inputs are valid.");
|
// else, if any input is a Tensor, use it's datatype,
|
||||||
}
|
// or the input consists solely from Scalars and Shapes,
|
||||||
};
|
// which should result in another Shape
|
||||||
|
node.outputs[0].ty = node
|
||||||
|
.inputs
|
||||||
|
.iter()
|
||||||
|
.find(|input| input.ty.is_tensor())
|
||||||
|
.map(|input| input.ty.clone())
|
||||||
|
.unwrap_or_else(|| ArgType::Shape(0)); //Shape dim will be set by broadcast calculation
|
||||||
|
|
||||||
|
set_broadcasting_output_shape(node);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Update the output tensor dimension based on the "axes" attribute or the second input
|
/// Update the output tensor dimension based on the "axes" attribute or the second input
|
||||||
|
@ -519,21 +484,34 @@ fn temporary_pass_through_stub(node: &mut Node) {
|
||||||
node.outputs[0].ty = node.inputs[0].ty.clone();
|
node.outputs[0].ty = node.inputs[0].ty.clone();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn equal_update_outputs(node: &mut Node) {
|
/// Sets the output for binary operators resulting in a boolean output,
|
||||||
let input1_type = node.inputs[0].ty.clone();
|
/// i.e., comparison operators like Equal, Greater, Less, etc.
|
||||||
|
///
|
||||||
|
/// Support for broadcasting is assumed
|
||||||
|
fn elementwise_comparsion_outputs(node: &mut Node) {
|
||||||
|
let input1_type = &node.inputs[0].ty;
|
||||||
|
let input2_type = &node.inputs[1].ty;
|
||||||
|
|
||||||
match input1_type {
|
match (input1_type, input2_type) {
|
||||||
ArgType::Tensor(tensor) => {
|
(ArgType::Tensor(tensor), _) | (_, ArgType::Tensor(tensor)) => {
|
||||||
// if the input is a tensor, the output is a tensor of bool
|
// if one of the inputs is a tensor, the output is a tensor of bool
|
||||||
|
assert_ne!(
|
||||||
|
tensor.dim, 0,
|
||||||
|
"Got a rank 0 Tensor, that should have been a Scalar!"
|
||||||
|
);
|
||||||
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
||||||
elem_type: ElementType::Bool,
|
elem_type: ElementType::Bool,
|
||||||
..tensor
|
..tensor.clone()
|
||||||
});
|
});
|
||||||
|
set_broadcasting_output_shape(node);
|
||||||
}
|
}
|
||||||
ArgType::Scalar(_) => {
|
(ArgType::Scalar(_), ArgType::Scalar(_)) => {
|
||||||
|
// if both inputs are scalars, the result is a scalar bool
|
||||||
node.outputs[0].ty = ArgType::Scalar(ElementType::Bool);
|
node.outputs[0].ty = ArgType::Scalar(ElementType::Bool);
|
||||||
}
|
}
|
||||||
_ => panic!("Only tensor input is valid"),
|
_ => {
|
||||||
|
panic!("Invalid input types for comparison op: {input1_type:?}, {input2_type:?}")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -788,11 +766,7 @@ fn reduce_sum_update_outputs(node: &mut Node) {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn where_update_outputs(node: &mut Node) {
|
fn where_update_outputs(node: &mut Node) {
|
||||||
match (
|
match (&node.inputs[0].ty, &node.inputs[1].ty, &node.inputs[2].ty) {
|
||||||
node.inputs[0].ty.clone(),
|
|
||||||
node.inputs[1].ty.clone(),
|
|
||||||
node.inputs[2].ty.clone(),
|
|
||||||
) {
|
|
||||||
(ArgType::Tensor(condition), ArgType::Tensor(x), ArgType::Tensor(y)) => {
|
(ArgType::Tensor(condition), ArgType::Tensor(x), ArgType::Tensor(y)) => {
|
||||||
// With broadcasting support, output dim has to be computed based on the inputs
|
// With broadcasting support, output dim has to be computed based on the inputs
|
||||||
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
||||||
|
@ -800,6 +774,7 @@ fn where_update_outputs(node: &mut Node) {
|
||||||
dim: max(condition.dim, max(x.dim, y.dim)),
|
dim: max(condition.dim, max(x.dim, y.dim)),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
});
|
});
|
||||||
|
set_broadcasting_output_shape(node);
|
||||||
}
|
}
|
||||||
_ => panic!("Only tensor input is valid"),
|
_ => panic!("Only tensor input is valid"),
|
||||||
}
|
}
|
||||||
|
@ -841,3 +816,68 @@ fn gather_update_outputs(node: &mut Node) {
|
||||||
ty => panic!("Only tensor/shape input is valid but received: {:?}", ty),
|
ty => panic!("Only tensor/shape input is valid but received: {:?}", ty),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// If all input shapes are known,
|
||||||
|
/// calculates the rank and shape of the output tensor for Operators supporting
|
||||||
|
/// [broadcasting](https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md)
|
||||||
|
fn set_broadcasting_output_shape(node: &mut Node) {
|
||||||
|
let mut reverse_out_shape: Vec<usize> = vec![1];
|
||||||
|
for (idx, input_type) in node.inputs.iter().enumerate() {
|
||||||
|
match &input_type.ty {
|
||||||
|
ArgType::Tensor(t) => {
|
||||||
|
if let Some(shape) = &t.shape {
|
||||||
|
for (rev_idx, dimension) in shape.iter().rev().enumerate() {
|
||||||
|
if let Some(current_out_dim) = reverse_out_shape.get_mut(rev_idx) {
|
||||||
|
if *dimension == 1 {
|
||||||
|
// dimension already has a value, this tensor can be broadcast
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if current_out_dim != dimension && *current_out_dim != 1 {
|
||||||
|
panic!("Invalid shape for broadcasting - the dimension from the {rev_idx}. to last position has conflicting values {current_out_dim} and {dimension} from different inputs");
|
||||||
|
}
|
||||||
|
*current_out_dim = *dimension;
|
||||||
|
} else {
|
||||||
|
reverse_out_shape.push(*dimension);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
debug!("Input {idx} has no known shape, cannot predict broadcast result shape");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ArgType::Scalar(_) => {
|
||||||
|
// reverse_out_shape already starts with [1]
|
||||||
|
}
|
||||||
|
ArgType::Shape(s) => {
|
||||||
|
// Shape is treated like a 1-D Tensor
|
||||||
|
let current_out_dim = &mut reverse_out_shape[0];
|
||||||
|
if *current_out_dim != 1 && *current_out_dim != *s {
|
||||||
|
panic!("Invalid shape for broadcasting - the last position has conflicting values {current_out_dim} and {} from different inputs", s);
|
||||||
|
}
|
||||||
|
*current_out_dim = *s;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we get to this point without returning, reverse_out_shape will be final
|
||||||
|
let mut out_shape = reverse_out_shape;
|
||||||
|
out_shape.reverse(); //unreverse it
|
||||||
|
|
||||||
|
match &mut node.outputs[0].ty {
|
||||||
|
ArgType::Tensor(t) => {
|
||||||
|
t.dim = out_shape.len();
|
||||||
|
t.shape = Some(out_shape);
|
||||||
|
}
|
||||||
|
ArgType::Scalar(_) => {
|
||||||
|
if out_shape.len() > 1 || out_shape[0] > 1 {
|
||||||
|
panic!("Output is a Scalar, but broadcasting results in tensor shape {out_shape:?}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ArgType::Shape(s) => {
|
||||||
|
if out_shape.len() > 1 {
|
||||||
|
panic!("Output is a Shape, but broadcasting results in higher-rank tensor shape {out_shape:?}")
|
||||||
|
}
|
||||||
|
*s = out_shape[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -128,6 +128,15 @@ impl Default for ArgType {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ArgType {
|
||||||
|
pub fn is_scalar(&self) -> bool {
|
||||||
|
matches!(self, Self::Scalar(_))
|
||||||
|
}
|
||||||
|
pub fn is_tensor(&self) -> bool {
|
||||||
|
matches!(self, Self::Tensor(_))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Argument {
|
impl Argument {
|
||||||
pub fn new(name: String) -> Self {
|
pub fn new(name: String) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
|
|
@ -242,20 +242,26 @@ impl TryFrom<ValueInfoProto> for Argument {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let tensor_type = TensorType {
|
let ty = if tensor_proto.shape.dim.is_empty() {
|
||||||
dim: tensor_proto.shape.dim.len(),
|
// tensor_proto describes a scalar
|
||||||
elem_type,
|
ArgType::Scalar(elem_type)
|
||||||
shape: Some(
|
} else {
|
||||||
tensor_proto
|
// tensor_proto describes a tensor
|
||||||
.shape
|
let tensor_type = TensorType {
|
||||||
.dim
|
dim: tensor_proto.shape.dim.len(),
|
||||||
.iter()
|
elem_type,
|
||||||
.map(|x| x.dim_value() as Dim)
|
shape: Some(
|
||||||
.collect(),
|
tensor_proto
|
||||||
),
|
.shape
|
||||||
};
|
.dim
|
||||||
|
.iter()
|
||||||
|
.map(|x| x.dim_value() as Dim)
|
||||||
|
.collect(),
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
let ty = ArgType::Tensor(tensor_type);
|
ArgType::Tensor(tensor_type)
|
||||||
|
};
|
||||||
|
|
||||||
Ok(Argument {
|
Ok(Argument {
|
||||||
ty,
|
ty,
|
||||||
|
|
Loading…
Reference in New Issue