Rm duplicate pad_across_processes (#24780)

Rm duplicate
This commit is contained in:
Zach Mueller 2023-07-12 11:47:21 -04:00 committed by GitHub
parent cfc8a05305
commit f1732e1374
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 0 additions and 35 deletions

View File

@ -3257,41 +3257,6 @@ class Trainer:
tensors = distributed_concat(tensors)
return tensors
# Copied from Accelerate.
def _pad_across_processes(self, tensor, pad_index=-100):
"""
Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so
they can safely be gathered.
"""
if isinstance(tensor, (list, tuple)):
return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor)
elif isinstance(tensor, dict):
return type(tensor)({k: self._pad_across_processes(v, pad_index=pad_index) for k, v in tensor.items()})
elif not isinstance(tensor, torch.Tensor):
raise TypeError(
f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
)
if len(tensor.shape) < 2:
return tensor
# Gather all sizes
size = torch.tensor(tensor.shape, device=tensor.device)[None]
sizes = self._nested_gather(size).cpu()
max_size = max(s[1] for s in sizes)
# When extracting XLA graphs for compilation, max_size is 0,
# so use inequality to avoid errors.
if tensor.shape[1] >= max_size:
return tensor
# Then pad to the maximum size
old_size = tensor.shape
new_size = list(old_size)
new_size[1] = max_size
new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index
new_tensor[:, : old_size[1]] = tensor
return new_tensor
def prediction_step(
self,
model: nn.Module,