mirror of https://github.com/tracel-ai/burn.git
Improve ONNX import tensor shape tracking (#2213)
- Calculate result of broadcasting in dim_inference
- Keep Shape info when converting from Argument to TensorType
- Remove a few sources of Dim = 0 Tensors, create Scalars instead
- Clean up dim_inference a bit
This commit is contained in:
parent 2f4c5ac0a1
commit e8ea9e27c2
@@ -143,6 +143,10 @@ impl TensorType {
             );
         }
         let formatted_name = Self::format_name(name.as_ref());
+        assert_ne!(
+            dim, 0,
+            "Trying to create TensorType with dim = 0 - should be a Scalar instead!"
+        );
         Self {
             name: Ident::new(&formatted_name, Span::call_site()),
             dim,
@@ -151,15 +155,39 @@ impl TensorType {
         }
     }

     pub fn new_float<S: AsRef<str>>(name: S, dim: usize) -> Self {
-        Self::new(name, dim, TensorKind::Float, None)
+        Self::new_float_with_shape(name, dim, None)
     }

+    pub fn new_float_with_shape<S: AsRef<str>>(
+        name: S,
+        dim: usize,
+        shape: Option<Vec<usize>>,
+    ) -> Self {
+        Self::new(name, dim, TensorKind::Float, shape)
+    }
+
     pub fn new_int<S: AsRef<str>>(name: S, dim: usize) -> Self {
-        Self::new(name, dim, TensorKind::Int, None)
+        Self::new_int_with_shape(name, dim, None)
     }

+    pub fn new_int_with_shape<S: AsRef<str>>(
+        name: S,
+        dim: usize,
+        shape: Option<Vec<usize>>,
+    ) -> Self {
+        Self::new(name, dim, TensorKind::Int, shape)
+    }
+
     pub fn new_bool<S: AsRef<str>>(name: S, dim: usize) -> Self {
-        Self::new(name, dim, TensorKind::Bool, None)
+        Self::new_bool_with_shape(name, dim, None)
     }

+    pub fn new_bool_with_shape<S: AsRef<str>>(
+        name: S,
+        dim: usize,
+        shape: Option<Vec<usize>>,
+    ) -> Self {
+        Self::new(name, dim, TensorKind::Bool, shape)
+    }
+
     pub fn ty(&self) -> TokenStream {
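Editor's note: the pattern above threads an optional static shape through one generic constructor, with the shape-less helpers delegating to the shape-aware ones so shape info is never silently dropped. Below is a minimal standalone sketch of that pattern, assuming simplified stand-in types rather than the actual burn-import definitions (the real `TensorType` holds a proc-macro `Ident`, which is omitted here):

#[derive(Debug, Clone, PartialEq)]
enum TensorKind {
    Float,
    Int,
    Bool,
}

#[derive(Debug, Clone)]
struct TensorType {
    name: String,
    dim: usize,
    kind: TensorKind,
    shape: Option<Vec<usize>>, // static shape, when the importer knows it
}

impl TensorType {
    fn new<S: AsRef<str>>(name: S, dim: usize, kind: TensorKind, shape: Option<Vec<usize>>) -> Self {
        // Rank-0 values should be modelled as scalars, never as 0-dim tensors.
        assert_ne!(dim, 0, "Trying to create TensorType with dim = 0 - should be a Scalar instead!");
        Self { name: name.as_ref().to_string(), dim, kind, shape }
    }

    // The shape-less helper stays for callers that have no static shape...
    fn new_float<S: AsRef<str>>(name: S, dim: usize) -> Self {
        Self::new_float_with_shape(name, dim, None)
    }

    // ...and delegates to the shape-aware variant introduced by this commit.
    fn new_float_with_shape<S: AsRef<str>>(name: S, dim: usize, shape: Option<Vec<usize>>) -> Self {
        Self::new(name, dim, TensorKind::Float, shape)
    }
}

fn main() {
    let t = TensorType::new_float_with_shape("output", 2, Some(vec![4, 8]));
    assert_eq!(t.dim, 2);
    assert_eq!(t.shape, Some(vec![4, 8]));
    let _legacy = TensorType::new_float("legacy", 3); // shape stays None
}

The assert mirrors the commit's stated goal of removing Dim = 0 tensors: a rank-0 value is supposed to have become a Scalar long before this constructor runs.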
@@ -424,12 +424,7 @@ impl ParsedOnnxGraph {

     fn random_uniform_conversion(node: Node) -> RandomUniformNode {
         let output = node.outputs.first().unwrap();
-        // cannot use output.to_tensor_type() here, since it drops the shape info...
-        let output_type = if let Type::Tensor(t) = Type::from(output) {
-            t
-        } else {
-            panic!("RandomUniform output type is no Tensor.");
-        };
+        let output_type = TensorType::from(output);

         let high = node
             .attrs
@@ -451,12 +446,7 @@ impl ParsedOnnxGraph {

     fn random_normal_conversion(node: Node) -> RandomNormalNode {
         let output = node.outputs.first().unwrap();
-        // cannot use output.to_tensor_type() here, since it drops the shape info...
-        let output_type = if let Type::Tensor(t) = Type::from(output) {
-            t
-        } else {
-            panic!("RandomNormal output type is no Tensor.");
-        };
+        let output_type = TensorType::from(output);

         let mean = node
             .attrs
@@ -480,11 +470,12 @@ impl ParsedOnnxGraph {
         // Additional types needed for ConstantOfShape:
         use crate::burn::node::constant_of_shape::ConstantValue;

-        let input = node
-            .inputs
-            .first()
-            .expect("ConstantOfShape requires an input tensor");
-        let output = node.outputs.first().unwrap();
+        let input = Type::from(
+            node.inputs
+                .first()
+                .expect("ConstantOfShape requires an input tensor"),
+        );
+        let output = Type::from(node.outputs.first().unwrap());

         // The value of the output elements.Should be a one-element tensor.
         // If not specified, it defaults to a tensor of value 0 and datatype float32
@@ -504,7 +495,7 @@ impl ParsedOnnxGraph {
             })
             .unwrap_or(ConstantValue::Float32(0.0f32));

-        ConstantOfShapeNode::new(Type::from(input), Type::from(output), value)
+        ConstantOfShapeNode::new(input, output, value)
     }

     fn add_conversion(node: Node) -> BinaryNode {
@@ -1082,18 +1073,27 @@ impl ParsedOnnxGraph {
         UnaryNode::exp(input, output)
     }

-    fn expand_conversion(node: Node) -> ExpandNode {
+    fn expand_conversion(mut node: Node) -> ExpandNode {
         let input = TensorType::from(node.inputs.first().unwrap());
-        let mut output = TensorType::from(node.outputs.first().unwrap());
         let shape = expand_config(&node);
-        output.dim = match &shape {
-            ExpandShape::Static(s) => s.len(),
-            ExpandShape::Runtime(Type::Shape(s)) => s.dim,
-            _ => panic!("Invalid ExpandShape {shape:?}!"),
-        };

-        ExpandNode::new(input, output, shape)
+        // dim_inference left the dim at zero, so it needs to be filled before converting to TensorType:
+        assert_eq!(
+            node.outputs.len(),
+            1,
+            "ExpandNode must have exactly 1 output!"
+        );
+        let mut output_arg = node.outputs.pop().unwrap();
+        if let ArgType::Tensor(output_arg_tensor) = &mut output_arg.ty {
+            output_arg_tensor.dim = match &shape {
+                ExpandShape::Static(s) => s.len(),
+                ExpandShape::Runtime(Type::Shape(s)) => s.dim,
+                ExpandShape::Runtime(Type::Tensor(t)) => t.shape.as_ref().unwrap()[0],
+                _ => panic!("Invalid ExpandShape {shape:?}!"),
+            };
+        }
+
+        ExpandNode::new(input, TensorType::from(&output_arg), shape)
     }

     fn neg_conversion(node: Node) -> UnaryNode {
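Editor's note: the match above exists because dim_inference cannot know Expand's output rank up front. A hedged sketch of just the rank rule follows, using a simplified local enum instead of the real ExpandShape/Type definitions: the rank comes from a static shape's length, from a runtime Shape's element count, or from the statically tracked length of a runtime 1-D shape tensor.

#[derive(Debug)]
enum ExpandShape {
    // Shape values known at import time.
    Static(Vec<usize>),
    // A runtime Shape input; the number of elements it carries is still known.
    RuntimeShape { num_elems: usize },
    // A runtime 1-D shape tensor; its static length, if the importer tracked it.
    RuntimeTensor { static_len: Option<usize> },
}

// Output rank of an Expand-like op, derived from its shape argument.
fn expand_output_rank(shape: &ExpandShape) -> usize {
    match shape {
        ExpandShape::Static(s) => s.len(),
        ExpandShape::RuntimeShape { num_elems } => *num_elems,
        ExpandShape::RuntimeTensor { static_len: Some(len) } => *len,
        other => panic!("Invalid ExpandShape {other:?}!"),
    }
}

fn main() {
    assert_eq!(expand_output_rank(&ExpandShape::Static(vec![2, 3, 4])), 3);
    assert_eq!(expand_output_rank(&ExpandShape::RuntimeTensor { static_len: Some(4) }), 4);
}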
@@ -1236,18 +1236,21 @@ impl From<&OnnxArgument> for TensorType {
             ArgType::Tensor(OnnxTensorType {
                 elem_type: ElementType::Float16 | ElementType::Float32 | ElementType::Float64,
                 dim,
+                shape,
                 ..
-            }) => TensorType::new_float(arg.name.clone(), *dim),
+            }) => TensorType::new_float_with_shape(arg.name.clone(), *dim, shape.clone()),
             ArgType::Tensor(OnnxTensorType {
                 elem_type: ElementType::Int32 | ElementType::Int64,
                 dim,
+                shape,
                 ..
-            }) => TensorType::new_int(arg.name.clone(), *dim),
+            }) => TensorType::new_int_with_shape(arg.name.clone(), *dim, shape.clone()),
             ArgType::Tensor(OnnxTensorType {
                 elem_type: ElementType::Bool,
                 dim,
+                shape,
                 ..
-            }) => TensorType::new_bool(arg.name.clone(), *dim),
+            }) => TensorType::new_bool_with_shape(arg.name.clone(), *dim, shape.clone()),
             _ => panic!("Can't transform {:?} to tensor.", arg.ty),
         }
     }
@@ -1,6 +1,7 @@
 use core::cmp::max;
 use core::panic;

+use log::debug;
 use protobuf::Enum;

 use crate::{
@@ -12,7 +13,7 @@ use crate::{
 /// Infer the dimension of each output tensor and update them.
 pub fn dim_inference(node: &mut Node) {
     match node.node_type {
-        NodeType::Add => same_as_input(node),
+        NodeType::Add => same_as_input_broadcast(node),
         NodeType::ArgMax => argmax_update_outputs(node),
         NodeType::AveragePool1d => same_as_input(node),
         NodeType::AveragePool2d => same_as_input(node),
@@ -21,12 +22,13 @@ pub fn dim_inference(node: &mut Node) {
         NodeType::Clip => same_as_input(node),
         NodeType::Concat => concat_update_outputs(node),
         NodeType::Constant => constant_update_outputs(node),
+        NodeType::ConstantOfShape => constant_of_shape_update_output(node),
         NodeType::Conv1d => conv1d_update_outputs(node),
         NodeType::Conv2d => conv2d_update_outputs(node),
         NodeType::Cos => same_as_input(node),
-        NodeType::Div => same_as_input(node),
+        NodeType::Div => same_as_input_broadcast(node),
         NodeType::Dropout => same_as_input(node),
-        NodeType::Equal => equal_update_outputs(node),
+        NodeType::Equal => elementwise_comparsion_outputs(node),
         NodeType::Erf => same_as_input(node),
         NodeType::Exp => same_as_input(node),
         NodeType::Expand => expand_update_outputs(node),
@@ -34,26 +36,31 @@ pub fn dim_inference(node: &mut Node) {
         NodeType::Gelu => same_as_input(node),
         NodeType::Gather => gather_update_outputs(node),
         NodeType::GatherElements => same_as_input(node),
+        NodeType::Greater => elementwise_comparsion_outputs(node),
+        NodeType::GreaterOrEqual => elementwise_comparsion_outputs(node),
         NodeType::HardSigmoid => same_as_input(node),
         NodeType::GlobalAveragePool => same_as_input(node),
         NodeType::ConvTranspose2d => conv_transpose2d_update_outputs(node),
         NodeType::LayerNormalization => same_as_input(node),
+        NodeType::LeakyRelu => same_as_input(node),
+        NodeType::Less => elementwise_comparsion_outputs(node),
+        NodeType::LessOrEqual => elementwise_comparsion_outputs(node),
         NodeType::Linear => linear_update_outputs(node),
         NodeType::Log => same_as_input(node),
         NodeType::LogSoftmax => same_as_input(node),
         NodeType::MatMul => matmul_update_outputs(node),
-        NodeType::Min => same_as_input(node),
-        NodeType::Max => same_as_input(node),
+        NodeType::Max => same_as_input_broadcast(node),
         NodeType::MaxPool1d => same_as_input(node),
         NodeType::MaxPool2d => same_as_input(node),
+        NodeType::Min => same_as_input_broadcast(node),
         NodeType::Mul => same_as_input(node),
         NodeType::Neg => same_as_input(node),
         NodeType::Not => same_as_input(node),
         NodeType::Pad => same_as_input(node),
-        NodeType::Greater => greater_update_outputs(node),
-        NodeType::GreaterOrEqual => greater_or_equal_update_outputs(node),
-        NodeType::Less => less_update_outputs(node),
-        NodeType::LessOrEqual => less_or_equal_update_outputs(node),
+        NodeType::PRelu => same_as_input_broadcast(node),
+        NodeType::Pow => same_as_input_broadcast(node),
+        NodeType::RandomNormal => random_update_output(node),
+        NodeType::RandomUniform => random_update_output(node),
         NodeType::Range => range_update_outputs(node),
         NodeType::Reciprocal => same_as_input(node),
         NodeType::ReduceMax => reduce_max_update_outputs(node),
@@ -70,20 +77,14 @@ pub fn dim_inference(node: &mut Node) {
         NodeType::Sin => same_as_input(node),
         NodeType::Slice => same_as_input(node),
         NodeType::Softmax => same_as_input(node),
+        NodeType::Squeeze => squeeze_update_output(node),
         NodeType::Sqrt => same_as_input(node),
-        NodeType::Sub => sub_update_outputs(node),
-        NodeType::Sum => same_as_input(node),
+        NodeType::Sub => same_as_input_broadcast(node),
+        NodeType::Sum => same_as_input_broadcast(node),
         NodeType::Tanh => same_as_input(node),
         NodeType::Transpose => same_as_input(node),
         NodeType::Unsqueeze => unsqueeze_update_output(node),
-        NodeType::Pow => same_as_input(node),
-        NodeType::LeakyRelu => same_as_input(node),
-        NodeType::PRelu => same_as_input(node),
         NodeType::Where => where_update_outputs(node),
-        NodeType::Squeeze => squeeze_update_output(node),
-        NodeType::RandomUniform => random_update_output(node),
-        NodeType::RandomNormal => random_update_output(node),
-        NodeType::ConstantOfShape => constant_of_shape_update_output(node),
         // Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated.
         _ => temporary_pass_through_stub(node),
     }
@@ -108,6 +109,9 @@ fn constant_update_outputs(node: &mut Node) {
     node.outputs[0].ty = match matched_value {
         Some(value) => match &value {
             // The value is stored in an attribute
+            AttributeValue::Tensor(tensor) if tensor.dim == 0 => {
+                ArgType::Scalar(tensor.elem_type.clone())
+            }
             AttributeValue::Tensor(tensor) => ArgType::Tensor(TensorType {
                 elem_type: tensor.elem_type.clone(),
                 dim: tensor.dim,
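Editor's note: the new guard arm encodes a rule that recurs throughout this commit, rank-0 constants become scalars instead of zero-dimensional tensors. A hedged standalone sketch, with hypothetical stand-in types for the onnx-ir ones:

#[derive(Debug, Clone, PartialEq)]
enum ElementType {
    Float32,
    Int64,
}

#[derive(Debug, PartialEq)]
enum ArgType {
    Scalar(ElementType),
    Tensor { elem_type: ElementType, dim: usize },
}

// A rank-0 constant is imported as a Scalar, not as a zero-dimensional Tensor.
fn constant_output_type(elem_type: ElementType, dim: usize) -> ArgType {
    if dim == 0 {
        ArgType::Scalar(elem_type)
    } else {
        ArgType::Tensor { elem_type, dim }
    }
}

fn main() {
    assert_eq!(
        constant_output_type(ElementType::Float32, 0),
        ArgType::Scalar(ElementType::Float32)
    );
    assert_eq!(
        constant_output_type(ElementType::Int64, 2),
        ArgType::Tensor { elem_type: ElementType::Int64, dim: 2 }
    );
}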
@@ -319,54 +323,6 @@ fn reshape_update_outputs(node: &mut Node) {
     }
 }

-fn greater_update_outputs(node: &mut Node) {
-    match &node.inputs[0].ty {
-        ArgType::Tensor(tensor) => {
-            node.outputs[0].ty = ArgType::Tensor(TensorType {
-                elem_type: ElementType::Bool,
-                ..tensor.clone()
-            });
-        }
-        _ => panic!("Only tensor input is valid"),
-    }
-}
-
-fn less_update_outputs(node: &mut Node) {
-    match &node.inputs[0].ty {
-        ArgType::Tensor(tensor) => {
-            node.outputs[0].ty = ArgType::Tensor(TensorType {
-                elem_type: ElementType::Bool,
-                ..tensor.clone()
-            });
-        }
-        _ => panic!("Only tensor input is valid"),
-    }
-}
-
-fn greater_or_equal_update_outputs(node: &mut Node) {
-    match &node.inputs[0].ty {
-        ArgType::Tensor(tensor) => {
-            node.outputs[0].ty = ArgType::Tensor(TensorType {
-                elem_type: ElementType::Bool,
-                ..tensor.clone()
-            });
-        }
-        _ => panic!("Only tensor input is valid"),
-    }
-}
-
-fn less_or_equal_update_outputs(node: &mut Node) {
-    match &node.inputs[0].ty {
-        ArgType::Tensor(tensor) => {
-            node.outputs[0].ty = ArgType::Tensor(TensorType {
-                elem_type: ElementType::Bool,
-                ..tensor.clone()
-            });
-        }
-        _ => panic!("Only tensor input is valid"),
-    }
-}
-
 fn reduce_mean_update_outputs(node: &mut Node) {
     if node.inputs.len() != 1 {
         panic!("Mean: multiple inputs are not supported");
@@ -455,18 +411,27 @@ fn squeeze_update_output(node: &mut Node) {
     });
 }

-fn sub_update_outputs(node: &mut Node) {
-    node.outputs[0].ty = match (node.inputs[0].ty.clone(), node.inputs[1].ty.clone()) {
-        (ArgType::Scalar(_lhs), ArgType::Scalar(rhs)) => ArgType::Scalar(rhs),
-        (ArgType::Scalar(_lhs), ArgType::Tensor(rhs)) => ArgType::Tensor(rhs),
-        (ArgType::Tensor(lhs), ArgType::Scalar(_rhs)) => ArgType::Tensor(lhs),
-        // Support broadcasting for lhs/rhs
-        (ArgType::Tensor(lhs), ArgType::Tensor(rhs)) if lhs.dim > rhs.dim => ArgType::Tensor(lhs),
-        (ArgType::Tensor(lhs), ArgType::Tensor(rhs)) if lhs.dim <= rhs.dim => ArgType::Tensor(rhs),
-        _ => {
-            panic!("Only tensor-scalar inputs are valid.");
-        }
-    };
-}
+/// Updates the output type for operations that take more than one Tensor/Scalar of the a type
+/// and returns the same type while supporting broadcasting
+///
+/// This is mostly the elementwise math operations, i.e. add, sub, mul, max etc.
+fn same_as_input_broadcast(node: &mut Node) {
+    if node.inputs.iter().all(|input| input.ty.is_scalar()) {
+        // If all inputs are scalar, the output is a scalar as well
+        node.outputs[0].ty = node.inputs[0].ty.clone();
+    } else {
+        // else, if any input is a Tensor, use it's datatype,
+        // or the input consists solely from Scalars and Shapes,
+        // which should result in another Shape
+        node.outputs[0].ty = node
+            .inputs
+            .iter()
+            .find(|input| input.ty.is_tensor())
+            .map(|input| input.ty.clone())
+            .unwrap_or_else(|| ArgType::Shape(0)); //Shape dim will be set by broadcast calculation
+
+        set_broadcasting_output_shape(node);
+    }
+}

 /// Update the output tensor dimension based on the "axes" attribute or the second input
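Editor's note: same_as_input_broadcast collapses several near-identical per-op updaters into one rule. A minimal sketch of that type-selection logic over a simplified ArgType (an assumed stand-in, not the onnx-ir enum):

#[derive(Debug, Clone, PartialEq)]
enum ArgType {
    Scalar,
    Tensor(usize), // rank
    Shape(usize),  // number of elements
}

// Output type for an elementwise op with broadcasting: all-scalar inputs give
// a scalar; any tensor input lends its type; a mix of scalars and shapes gives
// a Shape whose dim the broadcast calculation fills in afterwards.
fn broadcast_output_type(inputs: &[ArgType]) -> ArgType {
    if inputs.iter().all(|t| matches!(t, ArgType::Scalar)) {
        ArgType::Scalar
    } else {
        inputs
            .iter()
            .find(|t| matches!(t, ArgType::Tensor(_)))
            .cloned()
            .unwrap_or(ArgType::Shape(0)) // dim set later by the broadcast pass
    }
}

fn main() {
    assert_eq!(broadcast_output_type(&[ArgType::Scalar, ArgType::Scalar]), ArgType::Scalar);
    assert_eq!(broadcast_output_type(&[ArgType::Scalar, ArgType::Tensor(3)]), ArgType::Tensor(3));
    assert_eq!(broadcast_output_type(&[ArgType::Scalar, ArgType::Shape(4)]), ArgType::Shape(0));
}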
@@ -519,21 +484,34 @@ fn temporary_pass_through_stub(node: &mut Node) {
     node.outputs[0].ty = node.inputs[0].ty.clone();
 }

-fn equal_update_outputs(node: &mut Node) {
-    let input1_type = node.inputs[0].ty.clone();
+/// Sets the output for binary operators resulting in a boolean output,
+/// i.e., comparison operators like Equal, Greater, Less, etc.
+///
+/// Support for broadcasting is assumed
+fn elementwise_comparsion_outputs(node: &mut Node) {
+    let input1_type = &node.inputs[0].ty;
+    let input2_type = &node.inputs[1].ty;

-    match input1_type {
-        ArgType::Tensor(tensor) => {
-            // if the input is a tensor, the output is a tensor of bool
+    match (input1_type, input2_type) {
+        (ArgType::Tensor(tensor), _) | (_, ArgType::Tensor(tensor)) => {
+            // if one of the inputs is a tensor, the output is a tensor of bool
+            assert_ne!(
+                tensor.dim, 0,
+                "Got a rank 0 Tensor, that should have been a Scalar!"
+            );
             node.outputs[0].ty = ArgType::Tensor(TensorType {
                 elem_type: ElementType::Bool,
-                ..tensor
+                ..tensor.clone()
             });
+            set_broadcasting_output_shape(node);
         }
-        ArgType::Scalar(_) => {
+        (ArgType::Scalar(_), ArgType::Scalar(_)) => {
             // if both inputs are scalars, the result is a scalar bool
             node.outputs[0].ty = ArgType::Scalar(ElementType::Bool);
         }
-        _ => panic!("Only tensor input is valid"),
+        _ => {
+            panic!("Invalid input types for comparison op: {input1_type:?}, {input2_type:?}")
+        }
     }
 }
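Editor's note: the comparison rule differs from the arithmetic one only in the element type of the result. A hedged sketch under simplified stand-in types (input element types omitted, since comparison outputs are always boolean):

#[derive(Debug, Clone, PartialEq)]
enum Ty {
    BoolScalar,
    Scalar,                 // non-bool scalar input
    Tensor { rank: usize }, // element type omitted; output tensors are bool
}

// Comparison ops: any tensor operand makes the output a bool tensor of that
// rank (the broadcast pass then fixes up rank and shape); scalar vs scalar
// yields a bool scalar.
fn comparison_output(lhs: &Ty, rhs: &Ty) -> Ty {
    match (lhs, rhs) {
        (Ty::Tensor { rank }, _) | (_, Ty::Tensor { rank }) => Ty::Tensor { rank: *rank },
        _ => Ty::BoolScalar,
    }
}

fn main() {
    assert_eq!(comparison_output(&Ty::Scalar, &Ty::Tensor { rank: 3 }), Ty::Tensor { rank: 3 });
    assert_eq!(comparison_output(&Ty::Scalar, &Ty::Scalar), Ty::BoolScalar);
}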
@@ -788,11 +766,7 @@ fn reduce_sum_update_outputs(node: &mut Node) {
 }

 fn where_update_outputs(node: &mut Node) {
-    match (
-        node.inputs[0].ty.clone(),
-        node.inputs[1].ty.clone(),
-        node.inputs[2].ty.clone(),
-    ) {
+    match (&node.inputs[0].ty, &node.inputs[1].ty, &node.inputs[2].ty) {
         (ArgType::Tensor(condition), ArgType::Tensor(x), ArgType::Tensor(y)) => {
             // With broadcasting support, output dim has to be computed based on the inputs
             node.outputs[0].ty = ArgType::Tensor(TensorType {
@@ -800,6 +774,7 @@ fn where_update_outputs(node: &mut Node) {
                 dim: max(condition.dim, max(x.dim, y.dim)),
                 ..Default::default()
             });
+            set_broadcasting_output_shape(node);
         }
         _ => panic!("Only tensor input is valid"),
     }
@@ -841,3 +816,68 @@ fn gather_update_outputs(node: &mut Node) {
         ty => panic!("Only tensor/shape input is valid but received: {:?}", ty),
     }
 }
+
+/// If all input shapes are known,
+/// calculates the rank and shape of the output tensor for Operators supporting
+/// [broadcasting](https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md)
+fn set_broadcasting_output_shape(node: &mut Node) {
+    let mut reverse_out_shape: Vec<usize> = vec![1];
+    for (idx, input_type) in node.inputs.iter().enumerate() {
+        match &input_type.ty {
+            ArgType::Tensor(t) => {
+                if let Some(shape) = &t.shape {
+                    for (rev_idx, dimension) in shape.iter().rev().enumerate() {
+                        if let Some(current_out_dim) = reverse_out_shape.get_mut(rev_idx) {
+                            if *dimension == 1 {
+                                // dimension already has a value, this tensor can be broadcast
+                                continue;
+                            }
+                            if current_out_dim != dimension && *current_out_dim != 1 {
+                                panic!("Invalid shape for broadcasting - the dimension from the {rev_idx}. to last position has conflicting values {current_out_dim} and {dimension} from different inputs");
+                            }
+                            *current_out_dim = *dimension;
+                        } else {
+                            reverse_out_shape.push(*dimension);
+                        }
+                    }
+                } else {
+                    debug!("Input {idx} has no known shape, cannot predict broadcast result shape");
+                    return;
+                }
+            }
+            ArgType::Scalar(_) => {
+                // reverse_out_shape already starts with [1]
+            }
+            ArgType::Shape(s) => {
+                // Shape is treated like a 1-D Tensor
+                let current_out_dim = &mut reverse_out_shape[0];
+                if *current_out_dim != 1 && *current_out_dim != *s {
+                    panic!("Invalid shape for broadcasting - the last position has conflicting values {current_out_dim} and {} from different inputs", s);
+                }
+                *current_out_dim = *s;
+            }
+        }
+    }
+
+    // If we get to this point without returning, reverse_out_shape will be final
+    let mut out_shape = reverse_out_shape;
+    out_shape.reverse(); //unreverse it
+
+    match &mut node.outputs[0].ty {
+        ArgType::Tensor(t) => {
+            t.dim = out_shape.len();
+            t.shape = Some(out_shape);
+        }
+        ArgType::Scalar(_) => {
+            if out_shape.len() > 1 || out_shape[0] > 1 {
+                panic!("Output is a Scalar, but broadcasting results in tensor shape {out_shape:?}")
+            }
+        }
+        ArgType::Shape(s) => {
+            if out_shape.len() > 1 {
+                panic!("Output is a Shape, but broadcasting results in higher-rank tensor shape {out_shape:?}")
+            }
+            *s = out_shape[0];
+        }
+    }
+}
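Editor's note: the heart of the new shape tracking is the reverse-merge over known input shapes. Below is a standalone sketch over plain shape vectors (the Node/ArgType plumbing stripped away), implementing ONNX multidirectional broadcasting as described in the doc linked above:

// Compute the multidirectional-broadcast result of several known shapes,
// merging from the trailing dimension backwards (ONNX broadcasting rules):
// size-1 dimensions broadcast, conflicting non-1 sizes are an error.
fn broadcast_shape(shapes: &[Vec<usize>]) -> Vec<usize> {
    let mut reverse_out: Vec<usize> = vec![1];
    for shape in shapes {
        for (rev_idx, &dim) in shape.iter().rev().enumerate() {
            if let Some(out_dim) = reverse_out.get_mut(rev_idx) {
                if dim == 1 {
                    continue; // this input broadcasts along this dimension
                }
                if *out_dim != dim && *out_dim != 1 {
                    panic!("conflicting dimensions {out_dim} and {dim}");
                }
                *out_dim = dim;
            } else {
                reverse_out.push(dim); // this input has the highest rank so far
            }
        }
    }
    reverse_out.reverse();
    reverse_out
}

fn main() {
    assert_eq!(broadcast_shape(&[vec![2, 3, 1], vec![3, 4]]), vec![2, 3, 4]);
    assert_eq!(broadcast_shape(&[vec![5, 1], vec![1, 6]]), vec![5, 6]);
}

Merging from the trailing dimension backwards lets inputs of different rank line up on their last axes, which is why the function keeps the shape reversed until the very end.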
@@ -128,6 +128,15 @@ impl Default for ArgType {
     }
 }

+impl ArgType {
+    pub fn is_scalar(&self) -> bool {
+        matches!(self, Self::Scalar(_))
+    }
+    pub fn is_tensor(&self) -> bool {
+        matches!(self, Self::Tensor(_))
+    }
+}
+
 impl Argument {
     pub fn new(name: String) -> Self {
         Self {
@@ -242,6 +242,11 @@ impl TryFrom<ValueInfoProto> for Argument {
         }
     };

+        let ty = if tensor_proto.shape.dim.is_empty() {
+            // tensor_proto describes a scalar
+            ArgType::Scalar(elem_type)
+        } else {
+            // tensor_proto describes a tensor
         let tensor_type = TensorType {
             dim: tensor_proto.shape.dim.len(),
             elem_type,
@@ -255,7 +260,8 @@ impl TryFrom<ValueInfoProto> for Argument {
             ),
         };

-        let ty = ArgType::Tensor(tensor_type);
+            ArgType::Tensor(tensor_type)
+        };

         Ok(Argument {
             ty,
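Editor's note: a sketch of the decision this hunk introduces, with plain types standing in for the protobuf structures: a value-info whose shape has no dimensions is imported as a scalar rather than a 0-dim tensor.

#[derive(Debug, PartialEq)]
enum ArgType {
    Scalar,
    Tensor { dim: usize, shape: Vec<usize> },
}

// An empty dimension list in the proto shape means "scalar", not "0-dim tensor".
fn arg_type_from_proto_dims(dims: &[usize]) -> ArgType {
    if dims.is_empty() {
        ArgType::Scalar
    } else {
        ArgType::Tensor { dim: dims.len(), shape: dims.to_vec() }
    }
}

fn main() {
    assert_eq!(arg_type_from_proto_dims(&[]), ArgType::Scalar);
    assert_eq!(
        arg_type_from_proto_dims(&[2, 3]),
        ArgType::Tensor { dim: 2, shape: vec![2, 3] }
    );
}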