Fix reshape bug (support for opset version 1) (#1667)

* Make reshape op version 1

* Refactor per PR feedback
This commit is contained in:
Dilshod Tadjibaev 2024-04-22 17:52:25 -05:00 committed by GitHub
parent 29fa2ee76c
commit 1718da5210
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 26 additions and 45 deletions

View File

@ -194,30 +194,27 @@ fn concat_update_outputs(node: &mut Node) {
node.outputs[0].ty = ArgType::Tensor(tensor.clone());
}
fn reshape_update_outputs(node: &mut Node) {
assert_eq!(node.inputs.len(), 2);
let shape = if let Some(Data::Int64s(ref shape)) = node.inputs[1].value {
shape
} else {
panic!("Reshape: int64s shape is expected per ONNX spec");
let shape = match node.inputs.get(1) {
Some(input) => match &input.value {
Some(Data::Int64s(shape)) => Some(shape.clone()),
_ => panic!("Reshape: invalid input types"),
},
None => node.attrs.get("shape").cloned().map(|v| v.into_i64s()),
};
// The output dimension is the same as the shape length
let dim = shape.len();
let elem_type = match node.inputs[0].ty.clone() {
ArgType::Tensor(tensor) => tensor.elem_type,
_ => panic!("Reshape: invalid input type"),
let output = match &node.outputs[0].ty {
ArgType::Tensor(tensor) => tensor.clone(),
_ => panic!("Reshape: invalid output types"),
};
let shape = shape.iter().map(|&dim| dim as usize).collect();
node.outputs[0].ty = ArgType::Tensor(TensorType {
elem_type,
dim,
shape: Some(shape),
});
if let Some(shape) = shape {
node.outputs[0].ty = ArgType::Tensor(TensorType {
dim: shape.len(),
shape: None, // shape is calculated at runtime
..output
});
}
}
fn reduce_mean_update_outputs(node: &mut Node) {
@ -254,40 +251,24 @@ fn reduce_mean_update_outputs(node: &mut Node) {
/// Update the output tensor dimension based on the "axes" attribute or the second input
fn unsqueeze_update_output(node: &mut Node) {
let axes = if node.inputs.len() == 2 {
// get the values while making sure the types are correct
match &node.inputs[1].value {
Some(value) => match value {
Data::Int64s(axes) => Some(axes.clone()),
_ => panic!("Unsqueeze: invalid input types"),
},
None => None,
}
} else {
node.attrs
.iter()
.find_map(|(key, value)| match key.as_str() {
"axes" => Some(value.clone().into_i64s()),
_ => None,
})
let axes = match node.inputs.get(1) {
Some(input) => match &input.value {
Some(Data::Int64s(axes)) => Some(axes.clone()),
_ => panic!("Unsqueeze: invalid input types"),
},
None => node.attrs.get("axes").cloned().map(|v| v.into_i64s()),
};
// need output way up here to avoid borrowing issues
let input = match &node.inputs[0].ty {
ArgType::Tensor(tensor) => tensor.clone(),
_ => panic!("Unsqueeze: invalid output types"),
ty => panic!("Unsqueeze: invalid output type ({ty:?})"),
};
let output = match &node.outputs[0].ty {
ArgType::Tensor(tensor) => tensor.clone(),
_ => panic!("Unsqueeze: invalid output types"),
};
if axes.is_some() {
if let Some(axes) = axes {
node.outputs[0].ty = ArgType::Tensor(TensorType {
dim: input.dim + axes.unwrap().len(),
dim: input.dim + axes.len(),
shape: None, // shape is calculated at runtime
..output
..input
});
}
}