TorchScript flag in config; Tied weights when not running TorchScript; tuple concatenation clean-up.
This commit is contained in:
parent
4703148f0c
commit
b43b130f35
|
@ -46,6 +46,7 @@ class PretrainedConfig(object):
|
|||
self.num_labels = kwargs.pop('num_labels', 2)
|
||||
self.output_attentions = kwargs.pop('output_attentions', False)
|
||||
self.output_hidden_states = kwargs.pop('output_hidden_states', False)
|
||||
self.torchscript = kwargs.pop('torchscript', False)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||
|
|
|
@ -428,23 +428,23 @@ class BertEncoder(nn.Module):
|
|||
all_attentions = ()
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if self.output_attentions:
|
||||
all_attentions += (layer_outputs[1],)
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
# Add last layer
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if self.output_hidden_states:
|
||||
outputs += (all_hidden_states,)
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if self.output_attentions:
|
||||
outputs += (all_attentions,)
|
||||
outputs = outputs + (all_attentions,)
|
||||
return outputs # outputs, (hidden states), (attentions)
|
||||
|
||||
|
||||
|
@ -484,13 +484,19 @@ class BertLMPredictionHead(nn.Module):
|
|||
def __init__(self, config, bert_model_embedding_weights):
|
||||
super(BertLMPredictionHead, self).__init__()
|
||||
self.transform = BertPredictionHeadTransform(config)
|
||||
self.torchscript = config.torchscript
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
|
||||
bert_model_embedding_weights.size(0),
|
||||
bias=False)
|
||||
self.decoder.weight = nn.Parameter(bert_model_embedding_weights.clone())
|
||||
|
||||
if self.torchscript:
|
||||
self.decoder.weight = nn.Parameter(bert_model_embedding_weights.clone())
|
||||
else:
|
||||
self.decoder.weight = bert_model_embedding_weights
|
||||
|
||||
self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
|
|
|
@ -322,6 +322,7 @@ class GPT2LMHead(nn.Module):
|
|||
self.n_embd = config.n_embd
|
||||
self.vocab_size = config.vocab_size
|
||||
self.predict_special_tokens = config.predict_special_tokens
|
||||
self.torchscript = config.torchscript
|
||||
embed_shape = model_embeddings_weights.shape
|
||||
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
|
||||
self.set_embeddings_weights(model_embeddings_weights)
|
||||
|
@ -329,7 +330,10 @@ class GPT2LMHead(nn.Module):
|
|||
def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
|
||||
self.predict_special_tokens = predict_special_tokens
|
||||
# Export to TorchScript can't handle parameter sharing so we are cloning them.
|
||||
self.decoder.weight = nn.Parameter(model_embeddings_weights.clone()) # Tied weights
|
||||
if self.torchscript:
|
||||
self.decoder.weight = nn.Parameter(model_embeddings_weights.clone())
|
||||
else:
|
||||
self.decoder.weight = model_embeddings_weights # Tied weights
|
||||
|
||||
def forward(self, hidden_state):
|
||||
lm_logits = self.decoder(hidden_state)
|
||||
|
@ -563,11 +567,11 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||
all_hidden_states = ()
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past)):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states += (hidden_states.view(*output_shape),)
|
||||
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
|
||||
|
||||
outputs = block(hidden_states, layer_past, head_mask[i])
|
||||
hidden_states, present = outputs[:2]
|
||||
presents += (present,)
|
||||
presents = presents + (present,)
|
||||
|
||||
if self.output_attentions:
|
||||
all_attentions.append(outputs[2])
|
||||
|
@ -577,16 +581,16 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||
hidden_states = hidden_states.view(*output_shape)
|
||||
# Add last hidden state
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states, presents)
|
||||
if self.output_hidden_states:
|
||||
outputs += (all_hidden_states,)
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if self.output_attentions:
|
||||
# let the number of heads free (-1) so we can extract attention even after head pruning
|
||||
attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
|
||||
all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
|
||||
outputs += (all_attentions,)
|
||||
outputs = outputs + (all_attentions,)
|
||||
return outputs # last hidden state, presents, (all hidden_states), (attentions)
|
||||
|
||||
|
||||
|
|
|
@ -348,14 +348,18 @@ class OpenAIGPTLMHead(nn.Module):
|
|||
self.n_embd = config.n_embd
|
||||
self.vocab_size = config.vocab_size
|
||||
self.predict_special_tokens = config.predict_special_tokens
|
||||
self.torchscript = config.torchscript
|
||||
embed_shape = model_embeddings_weights.shape
|
||||
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
|
||||
self.set_embeddings_weights(model_embeddings_weights)
|
||||
|
||||
def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
|
||||
self.predict_special_tokens = predict_special_tokens
|
||||
embed_shape = model_embeddings_weights.shape
|
||||
self.decoder.weight = nn.Parameter(model_embeddings_weights.clone()) # Tied weights
|
||||
|
||||
if self.torchscript:
|
||||
self.decoder.weight = nn.Parameter(model_embeddings_weights.clone())
|
||||
else:
|
||||
self.decoder.weight = model_embeddings_weights # Tied weights
|
||||
|
||||
def forward(self, hidden_state):
|
||||
lm_logits = self.decoder(hidden_state)
|
||||
|
@ -583,22 +587,22 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
|||
all_hidden_states = ()
|
||||
for i, block in enumerate(self.h):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states += (hidden_states.view(*output_shape),)
|
||||
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
|
||||
|
||||
outputs = block(hidden_states, head_mask[i])
|
||||
hidden_states = outputs[0]
|
||||
if self.output_attentions:
|
||||
all_attentions += (outputs[1],)
|
||||
all_attentions = all_attentions + (outputs[1],)
|
||||
|
||||
# Add last layer
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states += (hidden_states.view(*output_shape),)
|
||||
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
|
||||
|
||||
outputs = (hidden_states.view(*output_shape),)
|
||||
if self.output_hidden_states:
|
||||
outputs += (all_hidden_states,)
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if self.output_attentions:
|
||||
outputs += (all_attentions,)
|
||||
outputs = outputs + (all_attentions,)
|
||||
return outputs # last hidden state, (all hidden states), (all attentions)
|
||||
|
||||
|
||||
|
|
|
@ -530,7 +530,7 @@ class XLNetRelativeAttention(nn.Module):
|
|||
|
||||
outputs = (output_h, output_g)
|
||||
if self.output_attentions:
|
||||
outputs += (attn_prob,)
|
||||
outputs = outputs + (attn_prob,)
|
||||
return outputs
|
||||
|
||||
class XLNetFeedForward(nn.Module):
|
||||
|
@ -878,7 +878,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
|||
hidden_states = []
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
# cache new mems
|
||||
new_mems += (self.cache_mem(output_h, mems[i]),)
|
||||
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
|
||||
if self.output_hidden_states:
|
||||
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
|
||||
|
||||
|
@ -902,10 +902,10 @@ class XLNetModel(XLNetPreTrainedModel):
|
|||
hidden_states = tuple(h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs)
|
||||
else:
|
||||
hidden_states = tuple(hs.permute(1, 0, 2).contiguous() for hs in hidden_states)
|
||||
outputs += (hidden_states,)
|
||||
outputs = outputs + (hidden_states,)
|
||||
if self.output_attentions:
|
||||
attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
|
||||
outputs += (attentions,)
|
||||
outputs = outputs + (attentions,)
|
||||
|
||||
return outputs # outputs, new_mems, (hidden_states), (attentions)
|
||||
|
||||
|
@ -975,6 +975,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
|||
super(XLNetLMHeadModel, self).__init__(config)
|
||||
self.attn_type = config.attn_type
|
||||
self.same_length = config.same_length
|
||||
self.torchscript = config.torchscript
|
||||
|
||||
self.transformer = XLNetModel(config)
|
||||
self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
|
||||
|
@ -987,7 +988,10 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
|||
def tie_weights(self):
|
||||
""" Make sure we are sharing the embeddings
|
||||
"""
|
||||
self.lm_loss.weight = nn.Parameter(self.transformer.word_embedding.weight.clone())
|
||||
if self.torchscript:
|
||||
self.lm_loss.weight = nn.Parameter(self.transformer.word_embedding.weight.clone())
|
||||
else:
|
||||
self.lm_loss.weight = self.transformer.word_embedding.weight
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
|
||||
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
|
||||
|
|
|
@ -41,6 +41,7 @@ def _create_and_check_torchscript_output_hidden_state(tester, model_classes, con
|
|||
|
||||
def _create_and_check_torchscript(tester, model_classes, config, inputs_dict):
|
||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||
configs_no_init.torchscript = True
|
||||
for model_class in model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
model.eval()
|
||||
|
|
Loading…
Reference in New Issue