From ae333d04b29a25be1a70eaccd6260c294c243c5b Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Wed, 30 Dec 2020 01:09:51 -0800
Subject: [PATCH] torch.cuda.is_available() is redundant as apex handles that
 internally (#9350)

---
 src/transformers/models/bart/modeling_bart.py | 11 +++++------
 src/transformers/models/fsmt/modeling_fsmt.py | 12 +++++-------
 .../models/prophetnet/modeling_prophetnet.py  | 11 +++++------
 3 files changed, 15 insertions(+), 19 deletions(-)

diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index f631736cc4..7f4af885d5 100644
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -110,13 +110,12 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
 
 
 def BartLayerNorm(normalized_shape: torch.Size, eps: float = 1e-5, elementwise_affine: bool = True):
-    if torch.cuda.is_available():
-        try:
-            from apex.normalization import FusedLayerNorm
+    try:
+        from apex.normalization import FusedLayerNorm
 
-            return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
-        except ImportError:
-            pass
+        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
+    except ImportError:
+        pass
 
     return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
 
diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py
index 682d6af006..0cd07ed6e6 100644
--- a/src/transformers/models/fsmt/modeling_fsmt.py
+++ b/src/transformers/models/fsmt/modeling_fsmt.py
@@ -265,14 +265,12 @@ FSMT_INPUTS_DOCSTRING = r"""
 
 
 have_fused_layer_norm = False
-if torch.cuda.is_available():
-    try:
-        from apex.normalization import FusedLayerNorm
-
-        have_fused_layer_norm = True
-    except ImportError:
-        pass
+try:
+    from apex.normalization import FusedLayerNorm
+    have_fused_layer_norm = True
+except ImportError:
+    pass
 
 LayerNorm = FusedLayerNorm if have_fused_layer_norm else torch.nn.LayerNorm
 
 
diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py
index ffa669d64b..17db02b5b2 100644
--- a/src/transformers/models/prophetnet/modeling_prophetnet.py
+++ b/src/transformers/models/prophetnet/modeling_prophetnet.py
@@ -511,13 +511,12 @@ class ProphetNetDecoderLMOutput(ModelOutput):
 
 
 def ProphetNetLayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
-    if torch.cuda.is_available():
-        try:
-            from apex.normalization import FusedLayerNorm
+    try:
+        from apex.normalization import FusedLayerNorm
 
-            return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
-        except ImportError:
-            pass
+        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
+    except ImportError:
+        pass
 
     return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
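
Note: all three call sites now share the same import-and-fallback shape. Below is a
minimal standalone sketch of that pattern; the helper name `select_layer_norm` is
illustrative and not part of this patch, and the rationale in the comment is taken
from the commit subject (apex handles CUDA availability internally, so an explicit
torch.cuda.is_available() guard around the import adds nothing):

    import torch


    def select_layer_norm(normalized_shape, eps=1e-5, elementwise_affine=True):
        # Prefer apex's fused LayerNorm when apex is installed. Per this
        # patch's rationale, apex handles CUDA availability internally,
        # so no torch.cuda.is_available() check is needed here.
        try:
            from apex.normalization import FusedLayerNorm

            return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
        except ImportError:
            # apex is not installed: fall back to the native PyTorch op.
            return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)


    # Illustrative usage: behaves like torch.nn.LayerNorm either way.
    ln = select_layer_norm(1024)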