mirror of https://github.com/tracel-ai/burn.git
Reshape bug fix (#1684)
* Revert1c639c8393
1c639c8393
?diff=unified&w=0 * Refactor by @laggui * Refactor unsqueeze
This commit is contained in:
parent
886a1de235
commit
a1bd14c5ae
|
@ -195,13 +195,18 @@ fn concat_update_outputs(node: &mut Node) {
|
|||
|
||||
node.outputs[0].ty = ArgType::Tensor(tensor.clone());
|
||||
}
|
||||
|
||||
fn reshape_update_outputs(node: &mut Node) {
|
||||
let shape = match node.inputs.get(1) {
|
||||
Some(input) => match &input.value {
|
||||
Some(Data::Int64s(shape)) => Some(shape.clone()),
|
||||
_ => panic!("Reshape: invalid input types"),
|
||||
},
|
||||
None => node.attrs.get("shape").cloned().map(|v| v.into_i64s()),
|
||||
let shape = if node.inputs.len() == 2 {
|
||||
match &node.inputs[1].value {
|
||||
Some(value) => match value {
|
||||
Data::Int64s(shape) => Some(shape.clone()),
|
||||
_ => panic!("Reshape: invalid input types"),
|
||||
},
|
||||
None => None,
|
||||
}
|
||||
} else {
|
||||
node.attrs.get("shape").cloned().map(|v| v.into_i64s())
|
||||
};
|
||||
|
||||
let output = match &node.outputs[0].ty {
|
||||
|
@ -252,24 +257,34 @@ fn reduce_mean_update_outputs(node: &mut Node) {
|
|||
|
||||
/// Update the output tensor dimension based on the "axes" attribute or the second input
|
||||
fn unsqueeze_update_output(node: &mut Node) {
|
||||
let axes = match node.inputs.get(1) {
|
||||
Some(input) => match &input.value {
|
||||
Some(Data::Int64s(axes)) => Some(axes.clone()),
|
||||
_ => panic!("Unsqueeze: invalid input types"),
|
||||
},
|
||||
None => node.attrs.get("axes").cloned().map(|v| v.into_i64s()),
|
||||
let axes = if node.inputs.len() == 2 {
|
||||
match &node.inputs[1].value {
|
||||
Some(value) => match value {
|
||||
Data::Int64s(axes) => Some(axes.clone()),
|
||||
_ => panic!("Unsqueeze: invalid input types"),
|
||||
},
|
||||
None => None,
|
||||
}
|
||||
} else {
|
||||
node.attrs.get("axes").cloned().map(|v| v.into_i64s())
|
||||
};
|
||||
|
||||
// need output way up here to avoid borrowing issues
|
||||
let input = match &node.inputs[0].ty {
|
||||
ArgType::Tensor(tensor) => tensor.clone(),
|
||||
ty => panic!("Unsqueeze: invalid output type ({ty:?})"),
|
||||
_ => panic!("Unsqueeze: invalid output types"),
|
||||
};
|
||||
|
||||
let output = match &node.outputs[0].ty {
|
||||
ArgType::Tensor(tensor) => tensor.clone(),
|
||||
_ => panic!("Unsqueeze: invalid output types"),
|
||||
};
|
||||
|
||||
if let Some(axes) = axes {
|
||||
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
||||
dim: input.dim + axes.len(),
|
||||
shape: None, // shape is calculated at runtime
|
||||
..input
|
||||
..output
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue