add fallback path for apex used in modeling.py

This commit is contained in:
Deyu Fu 2018-12-12 15:05:45 -08:00
parent c8ea286048
commit 3b0a14b761
1 changed files with 22 additions and 27 deletions

View File

@ -31,10 +31,6 @@ import shutil
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
try:
from apex.normalization.fused_layer_norm import FusedLayerNorm
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this.")
from .file_utils import cached_path
@ -157,22 +153,24 @@ class BertConfig(object):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
try:
from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except ImportError:
print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
class BertLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(BertLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
class BertLayerNorm(nn.Module):
def __init__(self, config, variance_epsilon=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(BertLayerNorm, self).__init__()
self.gamma = nn.Parameter(torch.ones(config.hidden_size))
self.beta = nn.Parameter(torch.zeros(config.hidden_size))
self.variance_epsilon = variance_epsilon
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.gamma * x + self.beta
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
@ -185,7 +183,7 @@ class BertEmbeddings(nn.Module):
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12)
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids, token_type_ids=None):
@ -260,7 +258,7 @@ class BertSelfOutput(nn.Module):
def __init__(self, config):
super(BertSelfOutput, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12)
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
@ -299,7 +297,7 @@ class BertOutput(nn.Module):
def __init__(self, config):
super(BertOutput, self).__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12)
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
@ -361,7 +359,7 @@ class BertPredictionHeadTransform(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.transform_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12)
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
@ -443,12 +441,9 @@ class PreTrainedBertModel(nn.Module):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, FusedLayerNorm):
elif isinstance(module, BertLayerNorm):
module.bias.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, BertLayerNorm):
module.beta.data.normal_(mean=0.0, std=self.config.initializer_range)
module.gamma.data.normal_(mean=0.0, std=self.config.initializer_range)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()