diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index f612a4e87d..e0b4fea0e6 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -359,6 +359,14 @@ def is_torch_fp16_available_on_device(device):
     try:
         x = torch.zeros(2, 2, dtype=torch.float16).to(device)
         _ = x @ x
+
+        # For now, let's be strict about this check: also verify that `LayerNorm` works in float16 on the
+        # device, because many models use this layer.
+        batch, sentence_length, embedding_dim = 3, 4, 5
+        embedding = torch.randn(batch, sentence_length, embedding_dim, dtype=torch.float16, device=device)
+        layer_norm = torch.nn.LayerNorm(embedding_dim, dtype=torch.float16, device=device)
+        _ = layer_norm(embedding)
+
     except:  # noqa: E722
         # TODO: more precise exception matching, if possible.
         # most backends should return `RuntimeError` however this is not guaranteed.
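
For reference, a minimal usage sketch of the strengthened check (not part of the diff; it assumes the helper remains importable from `transformers.utils.import_utils` and that `torch` is installed, and the device string is only an example):

```python
# Minimal usage sketch: with this change, is_torch_fp16_available_on_device returns True only
# if both a float16 matmul and a float16 LayerNorm run without raising on the target device.
from transformers.utils.import_utils import is_torch_fp16_available_on_device

device = "cuda:0"  # hypothetical target; any torch device string can be probed here
if is_torch_fp16_available_on_device(device):
    print(f"fp16 matmul and LayerNorm both work on {device}")
else:
    print(f"fp16 is not fully usable on {device}; consider falling back to float32")
```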