Several fixes and improvements

Lysandre 2019-10-30 20:25:32 +00:00 committed by Lysandre Debut
parent ee20201d33
commit 1e5b31c388
3 changed files with 19 additions and 19 deletions


@@ -7,6 +7,7 @@ import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers.configuration_albert import AlbertConfig
from transformers.modeling_bert import BertEmbeddings, BertModel, BertSelfAttention, prune_linear_layer, gelu_new
+from transformers.modeling_utils import PreTrainedModel
from .file_utils import add_start_docstrings
logger = logging.getLogger(__name__)
@@ -37,18 +38,17 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
-print(name)
for name, array in zip(names, arrays):
-print(name)
-og = name
+original_name = name
name = name.replace("ffn_1", "ffn")
name = name.replace("ffn/intermediate/output", "ffn_output")
name = name.replace("attention_1", "attention")
name = name.replace("cls/predictions/transform", "predictions")
name = name.replace("LayerNorm_1", "attention/LayerNorm")
name = name.replace("cls/predictions", "predictions")
name = name.replace("transform/", "")
name = name.replace("LayerNorm_1", "full_layer_layer_norm")
name = name.replace("LayerNorm", "attention/LayerNorm")
name = name.replace("inner_group_", "albert_layers/")
name = name.replace("group_", "albert_layer_groups/")
name = name.split('/')
-print(name)
pointer = model
for m_name in name:
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
@@ -78,13 +78,12 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
pointer = getattr(pointer, 'weight')
elif m_name == 'kernel':
array = np.transpose(array)
print("transposed")
try:
assert pointer.shape == array.shape
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
print("Initialize PyTorch weight {} from {}".format(name, og))
print("Initialize PyTorch weight {} from {}".format(name, original_name))
pointer.data = torch.from_numpy(array)
return model
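
To make the reordered replace() chain above easier to follow, here is a minimal standalone sketch that applies the same substitutions, in the same order, to a hypothetical checkpoint variable name (the example name is illustrative, not taken from a real checkpoint). Note that the "LayerNorm_1" rule has to run before the bare "LayerNorm" rule, which is what the reordering ensures.

name = "bert/encoder/transformer/group_0/inner_group_0/LayerNorm_1/gamma"  # hypothetical TF name

# Same substitutions and order as in the hunk above.
for old, new in [
    ("ffn_1", "ffn"),
    ("ffn/intermediate/output", "ffn_output"),
    ("attention_1", "attention"),
    ("cls/predictions", "predictions"),
    ("transform/", ""),
    ("LayerNorm_1", "full_layer_layer_norm"),
    ("LayerNorm", "attention/LayerNorm"),
    ("inner_group_", "albert_layers/"),
    ("group_", "albert_layer_groups/"),
]:
    name = name.replace(old, new)

print(name)
# bert/encoder/transformer/albert_layer_groups/0/albert_layers/0/full_layer_layer_norm/gamma
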
@@ -177,9 +176,9 @@ class AlbertAttention(BertSelfAttention):
b = self.dense.bias
projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b
-projected_context_layer = self.dropout(projected_context_layer)
-layernormed_context_layer = self.LayerNorm(input_ids + projected_context_layer)
-return layernormed_context_layer, projected_context_layer, reshaped_context_layer, context_layer, attention_scores, attention_probs, attention_mask
+projected_context_layer_dropout = self.dropout(projected_context_layer)
+layernormed_context_layer = self.LayerNorm(input_ids + projected_context_layer_dropout)
+return layernormed_context_layer
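
As a shape note on the einsum kept in this forward pass: with context_layer assumed to be (batch, seq_len, num_heads, head_dim) and w the dense weight viewed as (num_heads, head_dim, hidden_size), the contraction merges the heads and projects back to the hidden size in one step. A minimal sketch with illustrative sizes, not the real ALBERT configuration:

import torch

batch, seq_len, num_heads, head_dim, hidden = 2, 5, 4, 8, 32      # illustrative sizes
context_layer = torch.randn(batch, seq_len, num_heads, head_dim)  # "bfnd"
w = torch.randn(num_heads, head_dim, hidden)                      # "ndh"
b = torch.randn(hidden)

projected = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b   # (batch, seq_len, hidden)

# Equivalent to flattening the heads and applying a single linear map.
flat = context_layer.reshape(batch, seq_len, num_heads * head_dim)
reference = flat @ w.reshape(num_heads * head_dim, hidden) + b
assert torch.allclose(projected, reference, atol=1e-5)
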
class AlbertLayer(nn.Module):
@@ -187,17 +186,17 @@ class AlbertLayer(nn.Module):
super(AlbertLayer, self).__init__()
self.config = config
-self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attention = AlbertAttention(config)
self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states, attention_mask=None, head_mask=None):
-attention_output = self.attention(hidden_states, attention_mask)[0]
+attention_output = self.attention(hidden_states, attention_mask)
ffn_output = self.ffn(attention_output)
ffn_output = gelu_new(ffn_output)
ffn_output = self.ffn_output(ffn_output)
-hidden_states = self.LayerNorm(ffn_output + attention_output)
+hidden_states = self.full_layer_layer_norm(ffn_output + attention_output)
return hidden_states
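
With the layer norm renamed, the block reads as a standard expand-and-contract feed-forward with a residual connection that is normalized after the addition. A small standalone sketch with illustrative sizes, substituting torch's built-in gelu for gelu_new:

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, intermediate_size = 32, 128                        # illustrative sizes
ffn = nn.Linear(hidden_size, intermediate_size)
ffn_output = nn.Linear(intermediate_size, hidden_size)
full_layer_layer_norm = nn.LayerNorm(hidden_size)

attention_output = torch.randn(2, 5, hidden_size)               # (batch, seq_len, hidden)
x = F.gelu(ffn(attention_output))                               # hidden -> intermediate
x = ffn_output(x)                                               # intermediate -> hidden
hidden_states = full_layer_layer_norm(x + attention_output)     # residual, then LayerNorm
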
@@ -352,16 +351,17 @@ class AlbertModel(BertModel):
encoder_outputs = self.encoder(embedding_output,
extended_attention_mask,
head_mask=head_mask)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0]))
-outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
+outputs = (sequence_output, pooled_output) + encoder_outputs[1:] # add hidden_states and attentions if they are here
return outputs
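
For reference, pooled_output above is the hidden state of the first ([CLS]) token run through the model's pooler, assumed here to be BERT-style (a Linear followed by Tanh). A minimal sketch with stand-in tensors and illustrative sizes:

import torch
import torch.nn as nn

batch_size, seq_len, hidden_size = 2, 5, 32                      # illustrative sizes
sequence_output = torch.randn(batch_size, seq_len, hidden_size)

pooler = nn.Linear(hidden_size, hidden_size)                     # assumed BERT-style pooler
pooler_activation = nn.Tanh()
pooled_output = pooler_activation(pooler(sequence_output[:, 0]))

outputs = (sequence_output, pooled_output)                       # + hidden_states/attentions when returned
sequence_output, pooled_output = outputs[:2]                     # typical caller-side unpacking
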
@add_start_docstrings("Bert Model with a `language modeling` head on top.", ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
-class AlbertForMaskedLM(nn.Module):
+class AlbertForMaskedLM(PreTrainedModel):
r"""
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for computing the masked language modeling loss.
@@ -384,7 +384,7 @@ class AlbertForMaskedLM(nn.Module):
"""
def __init__(self, config):
-super(AlbertForMaskedLM, self).__init__()
+super(AlbertForMaskedLM, self).__init__(config)
self.config = config
self.bert = AlbertModel(config)
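
Switching the base class to PreTrainedModel is what makes the __init__(config) change necessary: unlike nn.Module, PreTrainedModel takes the config as a required argument and keeps it on the instance. A minimal sketch of the pattern with toy classes (ToyConfig and ToyModel are illustrative names, not part of the library):

import torch.nn as nn
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel

class ToyConfig(PretrainedConfig):
    def __init__(self, hidden_size=16, **kwargs):
        super(ToyConfig, self).__init__(**kwargs)
        self.hidden_size = hidden_size

class ToyModel(PreTrainedModel):
    config_class = ToyConfig

    def __init__(self, config):
        # Forwarding config lets PreTrainedModel store it as self.config;
        # calling super().__init__() without it fails.
        super(ToyModel, self).__init__(config)
        self.proj = nn.Linear(config.hidden_size, config.hidden_size)

model = ToyModel(ToyConfig())
print(model.config.hidden_size)  # 16
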


@@ -8,7 +8,7 @@ from shutil import copyfile
logger = logging.getLogger(__name__)
-VOCAB_FILES_NAMES = {'vocab_file': '30k-clean.model'}
+VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'}
SPIECE_UNDERLINE = u'▁'
class AlbertTokenizer(PreTrainedTokenizer):
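
Renaming the vocabulary file to spiece.model matches the filename used by the library's other sentencepiece-based tokenizers, and the constant is what the tokenizer uses to locate (and re-create) its vocabulary file inside a model directory. A hedged sketch of that pattern (the helper name below is illustrative, not the actual AlbertTokenizer method):

import os
from shutil import copyfile

VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'}

def copy_vocab_to(vocab_file, save_directory):
    # Sketch: saving copies the sentencepiece model into the target directory
    # under the canonical name, so loading from that directory can resolve
    # 'vocab_file' -> 'spiece.model' again.
    out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
    copyfile(vocab_file, out_vocab_file)
    return (out_vocab_file,)
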