Fix indices dim check in gather_update_outputs (#2149)

This commit is contained in:
Guillaume Lagrange 2024-08-12 09:20:25 -04:00 committed by GitHub
parent 12caca7909
commit 0eec293e28
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 7 additions and 6 deletions

View File

@ -810,19 +810,20 @@ fn gather_update_outputs(node: &mut Node) {
panic!("Gather requires two inputs: data and indices");
}
let indices_tensor = match &node.inputs[1].ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor indices is valid"),
let indices_dim = match &node.inputs[1].ty {
ArgType::Tensor(tensor) => tensor.dim,
ArgType::Scalar(_) => 0,
_ => panic!("Only tensor indices is valid, got {:?}", node.inputs[1].ty),
};
if indices_tensor.dim > 1 {
if indices_dim > 1 {
panic!("Gather: indices tensor rank above 1 not supported")
}
match &node.inputs[0].ty {
ArgType::Tensor(input_tensor) => {
// Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input
let output_rank = indices_tensor.dim + input_tensor.dim - 1;
let output_rank = indices_dim + input_tensor.dim - 1;
node.outputs[0].ty = ArgType::Tensor(TensorType {
elem_type: input_tensor.elem_type.clone(),
@ -833,7 +834,7 @@ fn gather_update_outputs(node: &mut Node) {
ArgType::Shape(_dim) => {
let shape_dim = 1;
// Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input
let output_rank = indices_tensor.dim + shape_dim - 1;
let output_rank = indices_dim + shape_dim - 1;
node.outputs[0].ty = ArgType::Tensor(TensorType {
elem_type: ElementType::Int64,