Fix exception in prediction loop occurring for certain batch sizes (#12350)
* fix distributed_concat for scalar outputs
* Update README.md
* fixed typo (#12356)
* simplify fix with terser syntax
* Trigger CI

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: michal pitr <21157924+MichalPitr@users.noreply.github.com>
parent d4ce31e839
commit 55bb4c06f7
@@ -155,6 +155,7 @@ def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int]
         return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
     output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
     dist.all_gather(output_tensors, tensor)
+    output_tensors = [t if len(t.shape) > 0 else t[None] for t in output_tensors]
     concat = torch.cat(output_tensors, dim=0)

     # truncate the dummy elements added by SequentialDistributedSampler
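The added line addresses a `torch.cat` limitation: when the prediction loop gathers a scalar output (e.g. a 0-dim loss tensor), `dist.all_gather` fills the output list with 0-dim tensors, and `torch.cat` raises a RuntimeError on zero-dimensional inputs. Indexing with `None` inserts a leading dimension so every tensor is at least 1-D. A minimal sketch of the promotion step, run without a distributed process group (the `gathered` list stands in for `all_gather` output):

```python
import torch

def promote_scalars(tensors):
    # 0-dim (scalar) tensors cannot be concatenated; give each one a
    # leading dimension of size 1 via None-indexing, leave others as-is.
    return [t if len(t.shape) > 0 else t[None] for t in tensors]

# Simulate what all_gather would return for a scalar metric on 2 ranks.
gathered = [torch.tensor(0.5), torch.tensor(0.7)]
concat = torch.cat(promote_scalars(gathered), dim=0)
print(concat.shape)  # torch.Size([2])
```

Without the promotion, `torch.cat(gathered, dim=0)` fails with "zero-dimensional tensor cannot be concatenated", which is the exception this commit fixes for batch sizes that produce scalar outputs.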