Fix exception in prediction loop occurring for certain batch sizes (#12350)

* fix distributed_concat for scalar outputs

* Update README.md

* fixed typo (#12356)

* simplify fix with terser syntax

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Trigger CI

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: michal pitr <21157924+MichalPitr@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
jglaser 2021-06-25 10:55:15 -04:00 committed by GitHub
parent d4ce31e839
commit 55bb4c06f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 0 deletions

View File

@ -155,6 +155,7 @@ def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int]
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
dist.all_gather(output_tensors, tensor)
output_tensors = [t if len(t.shape) > 0 else t[None] for t in output_tensors]
concat = torch.cat(output_tensors, dim=0)
# truncate the dummy elements added by SequentialDistributedSampler