Reshape bug fix (#1684)

* Revert 1c639c8393

1c639c8393?diff=unified&w=0

* Refactor by @laggui

* Refactor unsqueeze
This commit is contained in:
Dilshod Tadjibaev 2024-04-24 19:31:53 -05:00 committed by GitHub
parent 886a1de235
commit a1bd14c5ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 29 additions and 14 deletions

View File

@ -195,13 +195,18 @@ fn concat_update_outputs(node: &mut Node) {
node.outputs[0].ty = ArgType::Tensor(tensor.clone());
}
fn reshape_update_outputs(node: &mut Node) {
let shape = match node.inputs.get(1) {
Some(input) => match &input.value {
Some(Data::Int64s(shape)) => Some(shape.clone()),
_ => panic!("Reshape: invalid input types"),
},
None => node.attrs.get("shape").cloned().map(|v| v.into_i64s()),
let shape = if node.inputs.len() == 2 {
match &node.inputs[1].value {
Some(value) => match value {
Data::Int64s(shape) => Some(shape.clone()),
_ => panic!("Reshape: invalid input types"),
},
None => None,
}
} else {
node.attrs.get("shape").cloned().map(|v| v.into_i64s())
};
let output = match &node.outputs[0].ty {
@ -252,24 +257,34 @@ fn reduce_mean_update_outputs(node: &mut Node) {
/// Update the output tensor dimension based on the "axes" attribute or the second input
fn unsqueeze_update_output(node: &mut Node) {
let axes = match node.inputs.get(1) {
Some(input) => match &input.value {
Some(Data::Int64s(axes)) => Some(axes.clone()),
_ => panic!("Unsqueeze: invalid input types"),
},
None => node.attrs.get("axes").cloned().map(|v| v.into_i64s()),
let axes = if node.inputs.len() == 2 {
match &node.inputs[1].value {
Some(value) => match value {
Data::Int64s(axes) => Some(axes.clone()),
_ => panic!("Unsqueeze: invalid input types"),
},
None => None,
}
} else {
node.attrs.get("axes").cloned().map(|v| v.into_i64s())
};
// need output way up here to avoid borrowing issues
let input = match &node.inputs[0].ty {
ArgType::Tensor(tensor) => tensor.clone(),
ty => panic!("Unsqueeze: invalid output type ({ty:?})"),
_ => panic!("Unsqueeze: invalid output types"),
};
let output = match &node.outputs[0].ty {
ArgType::Tensor(tensor) => tensor.clone(),
_ => panic!("Unsqueeze: invalid output types"),
};
if let Some(axes) = axes {
node.outputs[0].ty = ArgType::Tensor(TensorType {
dim: input.dim + axes.len(),
shape: None, // shape is calculated at runtime
..input
..output
});
}
}