parent
cfc8a05305
commit
f1732e1374
|
@ -3257,41 +3257,6 @@ class Trainer:
|
|||
tensors = distributed_concat(tensors)
|
||||
return tensors
|
||||
|
||||
# Copied from Accelerate.
def _pad_across_processes(self, tensor, pad_index=-100):
    """
    Recursively pad the tensors inside a nested list/tuple/dict so that, across all processes, every tensor
    shares the same size along dimension 1 and can therefore be gathered safely.

    Args:
        tensor: a `torch.Tensor` or an arbitrarily nested list/tuple/dict of tensors.
        pad_index: the value used to fill the padded region (default `-100`, the usual ignore-index for labels).

    Returns:
        The input structure with every tensor padded along dim 1 to the largest dim-1 size seen on any process.

    Raises:
        TypeError: if a leaf value is not a `torch.Tensor`.
    """
    # Containers: recurse into each element/value, preserving the container type.
    if isinstance(tensor, (list, tuple)):
        return type(tensor)(self._pad_across_processes(item, pad_index=pad_index) for item in tensor)
    if isinstance(tensor, dict):
        return type(tensor)(
            {key: self._pad_across_processes(value, pad_index=pad_index) for key, value in tensor.items()}
        )
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(
            f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
        )

    # No dimension 1 to pad on 0-D / 1-D tensors.
    if tensor.dim() < 2:
        return tensor

    # Gather every process's shape so all processes agree on the largest dim-1 extent.
    local_shape = torch.tensor(tensor.shape, device=tensor.device)[None]
    all_shapes = self._nested_gather(local_shape).cpu()
    longest = max(shape[1] for shape in all_shapes)

    # When extracting XLA graphs for compilation, max_size is 0,
    # so use inequality to avoid errors.
    if tensor.shape[1] >= longest:
        return tensor

    # Build a pad_index-filled tensor of the target shape (new_zeros + pad_index keeps the
    # original dtype-promotion behavior), then copy the real data into its leading slice.
    padded_shape = list(tensor.shape)
    padded_shape[1] = longest
    padded = tensor.new_zeros(tuple(padded_shape)) + pad_index
    padded[:, : tensor.shape[1]] = tensor
    return padded
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
|
|
Loading…
Reference in New Issue