mirror of https://github.com/tracel-ai/burn.git
Fix reshape bug (support for opset version 1) (#1667)
* Make reshape op version 1 * Refactor per PR feedback
This commit is contained in:
parent
29fa2ee76c
commit
1718da5210
|
@ -194,30 +194,27 @@ fn concat_update_outputs(node: &mut Node) {
|
|||
|
||||
node.outputs[0].ty = ArgType::Tensor(tensor.clone());
|
||||
}
|
||||
|
||||
fn reshape_update_outputs(node: &mut Node) {
|
||||
assert_eq!(node.inputs.len(), 2);
|
||||
|
||||
let shape = if let Some(Data::Int64s(ref shape)) = node.inputs[1].value {
|
||||
shape
|
||||
} else {
|
||||
panic!("Reshape: int64s shape is expected per ONNX spec");
|
||||
let shape = match node.inputs.get(1) {
|
||||
Some(input) => match &input.value {
|
||||
Some(Data::Int64s(shape)) => Some(shape.clone()),
|
||||
_ => panic!("Reshape: invalid input types"),
|
||||
},
|
||||
None => node.attrs.get("shape").cloned().map(|v| v.into_i64s()),
|
||||
};
|
||||
|
||||
// The output dimension is the same as the shape length
|
||||
let dim = shape.len();
|
||||
let elem_type = match node.inputs[0].ty.clone() {
|
||||
ArgType::Tensor(tensor) => tensor.elem_type,
|
||||
_ => panic!("Reshape: invalid input type"),
|
||||
let output = match &node.outputs[0].ty {
|
||||
ArgType::Tensor(tensor) => tensor.clone(),
|
||||
_ => panic!("Reshape: invalid output types"),
|
||||
};
|
||||
|
||||
let shape = shape.iter().map(|&dim| dim as usize).collect();
|
||||
|
||||
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
||||
elem_type,
|
||||
dim,
|
||||
shape: Some(shape),
|
||||
});
|
||||
if let Some(shape) = shape {
|
||||
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
||||
dim: shape.len(),
|
||||
shape: None, // shape is calculated at runtime
|
||||
..output
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn reduce_mean_update_outputs(node: &mut Node) {
|
||||
|
@ -254,40 +251,24 @@ fn reduce_mean_update_outputs(node: &mut Node) {
|
|||
|
||||
/// Update the output tensor dimension based on the "axes" attribute or the second input
|
||||
fn unsqueeze_update_output(node: &mut Node) {
|
||||
let axes = if node.inputs.len() == 2 {
|
||||
// get the values while making sure the types are correct
|
||||
match &node.inputs[1].value {
|
||||
Some(value) => match value {
|
||||
Data::Int64s(axes) => Some(axes.clone()),
|
||||
_ => panic!("Unsqueeze: invalid input types"),
|
||||
},
|
||||
None => None,
|
||||
}
|
||||
} else {
|
||||
node.attrs
|
||||
.iter()
|
||||
.find_map(|(key, value)| match key.as_str() {
|
||||
"axes" => Some(value.clone().into_i64s()),
|
||||
_ => None,
|
||||
})
|
||||
let axes = match node.inputs.get(1) {
|
||||
Some(input) => match &input.value {
|
||||
Some(Data::Int64s(axes)) => Some(axes.clone()),
|
||||
_ => panic!("Unsqueeze: invalid input types"),
|
||||
},
|
||||
None => node.attrs.get("axes").cloned().map(|v| v.into_i64s()),
|
||||
};
|
||||
|
||||
// need output way up here to avoid borrowing issues
|
||||
let input = match &node.inputs[0].ty {
|
||||
ArgType::Tensor(tensor) => tensor.clone(),
|
||||
_ => panic!("Unsqueeze: invalid output types"),
|
||||
ty => panic!("Unsqueeze: invalid output type ({ty:?})"),
|
||||
};
|
||||
|
||||
let output = match &node.outputs[0].ty {
|
||||
ArgType::Tensor(tensor) => tensor.clone(),
|
||||
_ => panic!("Unsqueeze: invalid output types"),
|
||||
};
|
||||
|
||||
if axes.is_some() {
|
||||
if let Some(axes) = axes {
|
||||
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
||||
dim: input.dim + axes.unwrap().len(),
|
||||
dim: input.dim + axes.len(),
|
||||
shape: None, // shape is calculated at runtime
|
||||
..output
|
||||
..input
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue