TorchScript flag in config; Tied weights when not running TorchScript; tuple concatenation clean-up.

This commit is contained in:
LysandreJik 2019-07-03 16:21:17 -04:00
parent 4703148f0c
commit b43b130f35
6 changed files with 44 additions and 24 deletions

View File

@ -46,6 +46,7 @@ class PretrainedConfig(object):
self.num_labels = kwargs.pop('num_labels', 2)
self.output_attentions = kwargs.pop('output_attentions', False)
self.output_hidden_states = kwargs.pop('output_hidden_states', False)
self.torchscript = kwargs.pop('torchscript', False)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):

View File

@ -428,23 +428,23 @@ class BertEncoder(nn.Module):
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
all_hidden_states += (hidden_states,)
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
hidden_states = layer_outputs[0]
if self.output_attentions:
all_attentions += (layer_outputs[1],)
all_attentions = all_attentions + (layer_outputs[1],)
# Add last layer
if self.output_hidden_states:
all_hidden_states += (hidden_states,)
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_hidden_states:
outputs += (all_hidden_states,)
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
outputs += (all_attentions,)
outputs = outputs + (all_attentions,)
return outputs # outputs, (hidden states), (attentions)
@ -484,13 +484,19 @@ class BertLMPredictionHead(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
super(BertLMPredictionHead, self).__init__()
self.transform = BertPredictionHeadTransform(config)
self.torchscript = config.torchscript
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
bert_model_embedding_weights.size(0),
bias=False)
self.decoder.weight = nn.Parameter(bert_model_embedding_weights.clone())
if self.torchscript:
self.decoder.weight = nn.Parameter(bert_model_embedding_weights.clone())
else:
self.decoder.weight = bert_model_embedding_weights
self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
def forward(self, hidden_states):

View File

@ -322,6 +322,7 @@ class GPT2LMHead(nn.Module):
self.n_embd = config.n_embd
self.vocab_size = config.vocab_size
self.predict_special_tokens = config.predict_special_tokens
self.torchscript = config.torchscript
embed_shape = model_embeddings_weights.shape
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
self.set_embeddings_weights(model_embeddings_weights)
@ -329,7 +330,10 @@ class GPT2LMHead(nn.Module):
def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
self.predict_special_tokens = predict_special_tokens
# Export to TorchScript can't handle parameter sharing so we are cloning them.
self.decoder.weight = nn.Parameter(model_embeddings_weights.clone()) # Tied weights
if self.torchscript:
self.decoder.weight = nn.Parameter(model_embeddings_weights.clone())
else:
self.decoder.weight = model_embeddings_weights # Tied weights
def forward(self, hidden_state):
lm_logits = self.decoder(hidden_state)
@ -563,11 +567,11 @@ class GPT2Model(GPT2PreTrainedModel):
all_hidden_states = ()
for i, (block, layer_past) in enumerate(zip(self.h, past)):
if self.output_hidden_states:
all_hidden_states += (hidden_states.view(*output_shape),)
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
outputs = block(hidden_states, layer_past, head_mask[i])
hidden_states, present = outputs[:2]
presents += (present,)
presents = presents + (present,)
if self.output_attentions:
all_attentions.append(outputs[2])
@ -577,16 +581,16 @@ class GPT2Model(GPT2PreTrainedModel):
hidden_states = hidden_states.view(*output_shape)
# Add last hidden state
if self.output_hidden_states:
all_hidden_states += (hidden_states,)
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states, presents)
if self.output_hidden_states:
outputs += (all_hidden_states,)
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
# let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
outputs += (all_attentions,)
outputs = outputs + (all_attentions,)
return outputs # last hidden state, presents, (all hidden_states), (attentions)

View File

@ -348,14 +348,18 @@ class OpenAIGPTLMHead(nn.Module):
self.n_embd = config.n_embd
self.vocab_size = config.vocab_size
self.predict_special_tokens = config.predict_special_tokens
self.torchscript = config.torchscript
embed_shape = model_embeddings_weights.shape
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
self.set_embeddings_weights(model_embeddings_weights)
def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
self.predict_special_tokens = predict_special_tokens
embed_shape = model_embeddings_weights.shape
self.decoder.weight = nn.Parameter(model_embeddings_weights.clone()) # Tied weights
if self.torchscript:
self.decoder.weight = nn.Parameter(model_embeddings_weights.clone())
else:
self.decoder.weight = model_embeddings_weights # Tied weights
def forward(self, hidden_state):
lm_logits = self.decoder(hidden_state)
@ -583,22 +587,22 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
all_hidden_states = ()
for i, block in enumerate(self.h):
if self.output_hidden_states:
all_hidden_states += (hidden_states.view(*output_shape),)
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
outputs = block(hidden_states, head_mask[i])
hidden_states = outputs[0]
if self.output_attentions:
all_attentions += (outputs[1],)
all_attentions = all_attentions + (outputs[1],)
# Add last layer
if self.output_hidden_states:
all_hidden_states += (hidden_states.view(*output_shape),)
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
outputs = (hidden_states.view(*output_shape),)
if self.output_hidden_states:
outputs += (all_hidden_states,)
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
outputs += (all_attentions,)
outputs = outputs + (all_attentions,)
return outputs # last hidden state, (all hidden states), (all attentions)

View File

@ -530,7 +530,7 @@ class XLNetRelativeAttention(nn.Module):
outputs = (output_h, output_g)
if self.output_attentions:
outputs += (attn_prob,)
outputs = outputs + (attn_prob,)
return outputs
class XLNetFeedForward(nn.Module):
@ -878,7 +878,7 @@ class XLNetModel(XLNetPreTrainedModel):
hidden_states = []
for i, layer_module in enumerate(self.layer):
# cache new mems
new_mems += (self.cache_mem(output_h, mems[i]),)
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
if self.output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
@ -902,10 +902,10 @@ class XLNetModel(XLNetPreTrainedModel):
hidden_states = tuple(h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs)
else:
hidden_states = tuple(hs.permute(1, 0, 2).contiguous() for hs in hidden_states)
outputs += (hidden_states,)
outputs = outputs + (hidden_states,)
if self.output_attentions:
attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
outputs += (attentions,)
outputs = outputs + (attentions,)
return outputs # outputs, new_mems, (hidden_states), (attentions)
@ -975,6 +975,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
super(XLNetLMHeadModel, self).__init__(config)
self.attn_type = config.attn_type
self.same_length = config.same_length
self.torchscript = config.torchscript
self.transformer = XLNetModel(config)
self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
@ -987,7 +988,10 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
def tie_weights(self):
""" Make sure we are sharing the embeddings
"""
self.lm_loss.weight = nn.Parameter(self.transformer.word_embedding.weight.clone())
if self.torchscript:
self.lm_loss.weight = nn.Parameter(self.transformer.word_embedding.weight.clone())
else:
self.lm_loss.weight = self.transformer.word_embedding.weight
def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None,

View File

@ -41,6 +41,7 @@ def _create_and_check_torchscript_output_hidden_state(tester, model_classes, con
def _create_and_check_torchscript(tester, model_classes, config, inputs_dict):
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.torchscript = True
for model_class in model_classes:
model = model_class(config=configs_no_init)
model.eval()