Ensure tensors are at least 1d for pad and concat (#17179)
* Ensure tensors are at least 1d for pad and concat * Compatibility * Fix * Fix * Add test * Retrigger CI * Consistency with master * Retrigger CI
This commit is contained in:
parent
c76afa511c
commit
47412c7d43
|
@ -55,8 +55,22 @@ except ImportError:
|
|||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def atleast_1d(tensor_or_array: Union[torch.Tensor, np.ndarray]):
|
||||
if isinstance(tensor_or_array, torch.Tensor):
|
||||
if hasattr(torch, "atleast_1d"):
|
||||
tensor_or_array = torch.atleast_1d(tensor_or_array)
|
||||
elif tensor_or_array.ndim < 1:
|
||||
tensor_or_array = tensor_or_array[None]
|
||||
else:
|
||||
tensor_or_array = np.atleast_1d(tensor_or_array)
|
||||
return tensor_or_array
|
||||
|
||||
|
||||
def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100):
|
||||
"""Concatenates `tensor1` and `tensor2` on first axis, applying padding on the second if necessary."""
|
||||
tensor1 = atleast_1d(tensor1)
|
||||
tensor2 = atleast_1d(tensor2)
|
||||
|
||||
if len(tensor1.shape) == 1 or tensor1.shape[1] == tensor2.shape[1]:
|
||||
return torch.cat((tensor1, tensor2), dim=0)
|
||||
|
||||
|
@ -72,6 +86,9 @@ def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100):
|
|||
|
||||
def numpy_pad_and_concatenate(array1, array2, padding_index=-100):
|
||||
"""Concatenates `array1` and `array2` on first axis, applying padding on the second if necessary."""
|
||||
array1 = atleast_1d(array1)
|
||||
array2 = atleast_1d(array2)
|
||||
|
||||
if len(array1.shape) == 1 or array1.shape[1] == array2.shape[1]:
|
||||
return np.concatenate((array1, array2), axis=0)
|
||||
|
||||
|
@ -149,8 +166,7 @@ def nested_xla_mesh_reduce(tensors, name):
|
|||
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
|
||||
if tensors.ndim == 0:
|
||||
tensors = tensors[None]
|
||||
tensors = atleast_1d(tensors)
|
||||
return xm.mesh_reduce(name, tensors, torch.cat)
|
||||
else:
|
||||
raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
|
||||
|
@ -160,8 +176,7 @@ def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) ->
|
|||
try:
|
||||
if isinstance(tensor, (tuple, list)):
|
||||
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
|
||||
if len(tensor.shape) <= 0:
|
||||
tensor = tensor[None]
|
||||
tensor = atleast_1d(tensor)
|
||||
output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
|
||||
dist.all_gather(output_tensors, tensor)
|
||||
concat = torch.cat(output_tensors, dim=0)
|
||||
|
@ -1031,7 +1046,7 @@ if is_sagemaker_mp_enabled():
|
|||
f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
|
||||
)
|
||||
all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
|
||||
all_tensors = [t if len(t.shape) > 0 else t[None] for t in all_tensors]
|
||||
all_tensors = [atleast_1d(t) for t in all_tensors]
|
||||
return torch.cat([t.cpu() for t in all_tensors], dim=0)
|
||||
|
||||
def smp_nested_concat(tensor):
|
||||
|
|
|
@ -41,6 +41,8 @@ if is_torch_available():
|
|||
SequentialDistributedSampler,
|
||||
ShardSampler,
|
||||
get_parameter_names,
|
||||
numpy_pad_and_concatenate,
|
||||
torch_pad_and_concatenate,
|
||||
)
|
||||
|
||||
class TstLayer(nn.Module):
|
||||
|
@ -459,6 +461,18 @@ class TrainerUtilsTest(unittest.TestCase):
|
|||
mock_training_loop_function()
|
||||
self.assertEqual("CUDA out of memory", cm.args[0])
|
||||
|
||||
def test_pad_and_concatenate_with_1d(self):
|
||||
"""Tests whether pad_and_concatenate works with scalars."""
|
||||
array1 = 1.0
|
||||
array2 = 2.0
|
||||
result = numpy_pad_and_concatenate(array1, array2)
|
||||
self.assertTrue(np.array_equal(np.array([1.0, 2.0]), result))
|
||||
|
||||
tensor1 = torch.tensor(1.0)
|
||||
tensor2 = torch.tensor(2.0)
|
||||
result = torch_pad_and_concatenate(tensor1, tensor2)
|
||||
self.assertTrue(torch.equal(result, torch.Tensor([1.0, 2.0])))
|
||||
|
||||
def test_remove_columns_collator(self):
|
||||
class MockLogger:
|
||||
def __init__(self) -> None:
|
||||
|
|
Loading…
Reference in New Issue