fix tests

This commit is contained in:
thomwolf 2019-06-26 10:02:45 +02:00
parent 092dacfd62
commit 93e9971c54
6 changed files with 38 additions and 149 deletions

View File

@ -930,12 +930,12 @@ all_hidden_states = lower_hidden_states + [hidden_states]
`TransfoXLLMHeadModel` includes the `TransfoXLModel` Transformer followed by an (adaptive) softmax head with weights tied to the input embeddings.
*Inputs* are the same as the inputs of the [`TransfoXLModel`](#-12.-`TransfoXLModel`) class plus optional labels:
- `target`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the target token indices selected in the range [0, self.config.n_token[
- `labels`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the labels token indices selected in the range [0, self.config.n_token[
*Outputs* a tuple of (last_hidden_state, new_mems)
- `softmax_output`: output of the (adaptive) softmax:
- if target is None: log probabilities of tokens, shape [batch_size, sequence_length, n_tokens]
- else: Negative log likelihood of target tokens with shape [batch_size, sequence_length]
- if labels is None: log probabilities of tokens, shape [batch_size, sequence_length, n_tokens]
- else: Negative log likelihood of labels tokens with shape [batch_size, sequence_length]
- `new_mems`: list (num layers) of updated mem states at the entry of each layer each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]. Note that the first two dimensions are transposed in `mems` with regards to `input_ids`.
#### 14. `GPT2Model`

View File

@ -1025,14 +1025,14 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
`mems`: optional memomry of hidden states from previous forward passes
as a list (num layers) of hidden states at the entry of each layer
each hidden states has shape [self.config.mem_len, bsz, self.config.d_model]
Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `target`
Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `labels`
Outputs:
A tuple of (last_hidden_state, new_mems)
`last_hidden_state`: the encoded-hidden-states at the top of the model
as a torch.FloatTensor of size [batch_size, sequence_length, self.config.d_model]
`new_mems`: list (num layers) of updated mem states at the entry of each layer
each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]
Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `target`
Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `labels`
Example usage:
```python
@ -1265,7 +1265,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
mems :: optional mems from previous forwar passes (or init_mems)
list (num layers) of mem states at the entry of each layer
shape :: [self.config.mem_len, bsz, self.config.d_model]
Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `target`
Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `labels`
Returns:
tuple (last_hidden, new_mems) where:
new_mems: list (num layers) of mem states at the entry of each layer
@ -1303,23 +1303,23 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the token indices selected in the range [0, self.config.n_token[
`target`: an optional torch.LongTensor of shape [batch_size, sequence_length]
with the target token indices selected in the range [0, self.config.n_token[
`labels`: an optional torch.LongTensor of shape [batch_size, sequence_length]
with the labels token indices selected in the range [0, self.config.n_token[
`mems`: an optional memory of hidden states from previous forward passes
as a list (num layers) of hidden states at the entry of each layer
each hidden states has shape [self.config.mem_len, bsz, self.config.d_model]
Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `target`
Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `labels`
Outputs:
A tuple of (last_hidden_state, new_mems)
`softmax_output`: output of the (adaptive) softmax:
if target is None:
if labels is None:
Negative log likelihood of shape [batch_size, sequence_length]
else:
log probabilities of tokens, shape [batch_size, sequence_length, n_tokens]
`new_mems`: list (num layers) of updated mem states at the entry of each layer
each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]
Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `target`
Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `labels`
Example usage:
```python
@ -1375,16 +1375,16 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
def init_mems(self, data):
return self.transformer.init_mems(data)
def forward(self, input_ids, target=None, mems=None):
def forward(self, input_ids, labels=None, mems=None):
""" Params:
input_ids :: [bsz, len]
target :: [bsz, len]
labels :: [bsz, len]
Returns:
tuple(softmax_output, new_mems) where:
new_mems: list (num layers) of hidden states at the entry of each layer
shape :: [mem_len, bsz, self.config.d_model] :: Warning: shapes are transposed here w. regards to input_ids
softmax_output: output of the (adaptive) softmax:
if target is None:
if labels is None:
Negative log likelihood of shape :: [bsz, len]
else:
log probabilities of tokens, shape :: [bsz, len, n_tokens]
@ -1397,11 +1397,11 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
pred_hid = last_hidden[:, -tgt_len:]
if self.sample_softmax > 0 and self.training:
assert self.config.tie_weight
logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, target, pred_hid, self.sampler)
logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, labels, pred_hid, self.sampler)
softmax_output = -F.log_softmax(logit, -1)[:, :, 0]
else:
softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target)
if target is None:
softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), labels)
if labels is None:
softmax_output = softmax_output.view(bsz, tgt_len, -1)
else:
softmax_output = softmax_output.view(bsz, tgt_len)

View File

@ -89,13 +89,13 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
return logit
def forward(self, hidden, target=None, keep_order=False):
def forward(self, hidden, labels=None, keep_order=False):
'''
Params:
hidden :: [len*bsz x d_proj]
target :: [len*bsz]
labels :: [len*bsz]
Return:
if target is None:
if labels is None:
out :: [len*bsz] Negative log likelihood
else:
out :: [len*bsz x n_tokens] log probabilities of tokens over the vocabulary
@ -104,18 +104,18 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
here: https://github.com/pytorch/pytorch/blob/dbe6a7a9ff1a364a8706bf5df58a1ca96d2fd9da/torch/nn/modules/adaptive.py#L138
'''
if target is not None:
target = target.view(-1)
if hidden.size(0) != target.size(0):
raise RuntimeError('Input and target should have the same size '
if labels is not None:
labels = labels.view(-1)
if hidden.size(0) != labels.size(0):
raise RuntimeError('Input and labels should have the same size '
'in the batch dimension.')
if self.n_clusters == 0:
logit = self._compute_logit(hidden, self.out_layers[0].weight,
self.out_layers[0].bias, self.out_projs[0])
if target is not None:
if labels is not None:
out = -F.log_softmax(logit, dim=-1) \
.gather(1, target.unsqueeze(1)).squeeze(1)
.gather(1, labels.unsqueeze(1)).squeeze(1)
else:
out = F.log_softmax(logit, dim=-1)
else:
@ -144,31 +144,31 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
head_logprob = F.log_softmax(head_logit, dim=1)
if target is None:
if labels is None:
out = hidden.new_empty((head_logit.size(0), self.n_token))
else:
out = torch.zeros_like(target, dtype=hidden.dtype, device=hidden.device)
out = torch.zeros_like(labels, dtype=hidden.dtype, device=hidden.device)
offset = 0
cutoff_values = [0] + self.cutoffs
for i in range(len(cutoff_values) - 1):
l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]
if target is not None:
mask_i = (target >= l_idx) & (target < r_idx)
if labels is not None:
mask_i = (labels >= l_idx) & (labels < r_idx)
indices_i = mask_i.nonzero().squeeze()
if indices_i.numel() == 0:
continue
target_i = target.index_select(0, indices_i) - l_idx
target_i = labels.index_select(0, indices_i) - l_idx
head_logprob_i = head_logprob.index_select(0, indices_i)
hidden_i = hidden.index_select(0, indices_i)
else:
hidden_i = hidden
if i == 0:
if target is not None:
if labels is not None:
logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
else:
out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]]
@ -178,14 +178,14 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
cluster_prob_idx = self.cutoffs[0] + i - 1 # No probability for the head cluster
if target is not None:
if labels is not None:
logprob_i = head_logprob_i[:, cluster_prob_idx] \
+ tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)
else:
logprob_i = head_logprob[:, cluster_prob_idx, None] + tail_logprob_i
out[:, l_idx:r_idx] = logprob_i
if target is not None:
if labels is not None:
if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
out.index_copy_(0, indices_i, -logprob_i)
else:

View File

@ -1,111 +0,0 @@
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Utilities for PyTorch XLNet model.
"""
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
special_symbols = {
"<unk>" : 0,
"<s>" : 1,
"</s>" : 2,
"<cls>" : 3,
"<sep>" : 4,
"<pad>" : 5,
"<mask>" : 6,
"<eod>" : 7,
"<eop>" : 8,
}
VOCAB_SIZE = 32000
UNK_ID = special_symbols["<unk>"]
CLS_ID = special_symbols["<cls>"]
SEP_ID = special_symbols["<sep>"]
MASK_ID = special_symbols["<mask>"]
EOD_ID = special_symbols["<eod>"]
def permutation_mask(inputs, targets, is_masked, perm_size, seq_len):
"""
Sample a permutation of the factorization order, and create an
attention mask accordingly.
Args:
inputs: int64 Tensor in shape [seq_len], input ids.
targets: int64 Tensor in shape [seq_len], target ids.
is_masked: bool Tensor in shape [seq_len]. True means being selected
for partial prediction.
perm_size: the length of longest permutation. Could be set to be reuse_len.
Should not be larger than reuse_len or there will be data leaks.
seq_len: int, sequence length.
"""
# Generate permutation indices
index = np.arange(10)
index = np.transpose(np.reshape(index, [-1, perm_size]))
index = np.random.shuffle(index)
index = np.reshape(np.transpose(index), [-1])
# `perm_mask` and `target_mask`
# non-functional tokens
non_func_tokens = tf.logical_not(tf.logical_or(
tf.equal(inputs, SEP_ID),
tf.equal(inputs, CLS_ID)))
non_mask_tokens = tf.logical_and(tf.logical_not(is_masked), non_func_tokens)
masked_or_func_tokens = tf.logical_not(non_mask_tokens)
# Set the permutation indices of non-masked (& non-funcional) tokens to the
# smallest index (-1):
# (1) they can be seen by all other positions
# (2) they cannot see masked positions, so there won"t be information leak
smallest_index = -tf.ones([seq_len], dtype=tf.int64)
rev_index = tf.where(non_mask_tokens, smallest_index, index)
# Create `target_mask`: non-funcional and maksed tokens
# 1: use mask as input and have loss
# 0: use token (or [SEP], [CLS]) as input and do not have loss
target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
target_mask = tf.cast(target_tokens, tf.float32)
# Create `perm_mask`
# `target_tokens` cannot see themselves
self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)
# 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
# 0: can attend if i > j or j is non-masked
perm_mask = tf.logical_and(
self_rev_index[:, None] <= rev_index[None, :],
masked_or_func_tokens)
perm_mask = tf.cast(perm_mask, tf.float32)
# new target: [next token] for LM and [curr token] (self) for PLM
new_targets = tf.concat([inputs[0: 1], targets[: -1]],
axis=0)
# construct inputs_k
inputs_k = inputs
# construct inputs_q
inputs_q = target_mask
return perm_mask, new_targets, target_mask, inputs_k, inputs_q

View File

@ -129,10 +129,10 @@ class TransfoXLModelTest(unittest.TestCase):
model = TransfoXLLMHeadModel(config)
model.eval()
loss_1, mems_1a = model(input_ids_1, target=lm_labels)
loss_1, mems_1a = model(input_ids_1, labels=lm_labels)
lm_logits_1, mems_1b = model(input_ids_1)
loss_2, mems_2a = model(input_ids_2, target=lm_labels, mems=mems_1a)
loss_2, mems_2a = model(input_ids_2, labels=lm_labels, mems=mems_1a)
lm_logits_2, mems_2b = model(input_ids_2, mems=mems_1b)
outputs = {

View File

@ -138,10 +138,10 @@ class XLNetModelTest(unittest.TestCase):
model = XLNetLMHeadModel(config)
model.eval()
loss_1, mems_1a = model(input_ids_1, token_type_ids=segment_ids, target=lm_labels)
loss_1, mems_1a = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)
all_logits_1, mems_1b = model(input_ids_1, token_type_ids=segment_ids)
loss_2, mems_2a = model(input_ids_2, token_type_ids=segment_ids, target=lm_labels, mems=mems_1a)
loss_2, mems_2a = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1a)
all_logits_2, mems_2b = model(input_ids_2, token_type_ids=segment_ids, mems=mems_1b)
logits, _ = model(input_ids_q,