fix tests
parent 092dacfd62
commit 93e9971c54
@@ -930,12 +930,12 @@ all_hidden_states = lower_hidden_states + [hidden_states]
 `TransfoXLLMHeadModel` includes the `TransfoXLModel` Transformer followed by an (adaptive) softmax head with weights tied to the input embeddings.

 *Inputs* are the same as the inputs of the [`TransfoXLModel`](#-12.-`TransfoXLModel`) class plus optional labels:
-- `target`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the target token indices selected in the range [0, self.config.n_token[
+- `labels`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the labels token indices selected in the range [0, self.config.n_token[

 *Outputs* a tuple of (last_hidden_state, new_mems)
 - `softmax_output`: output of the (adaptive) softmax:
-  - if target is None: log probabilities of tokens, shape [batch_size, sequence_length, n_tokens]
-  - else: Negative log likelihood of target tokens with shape [batch_size, sequence_length]
+  - if labels is None: log probabilities of tokens, shape [batch_size, sequence_length, n_tokens]
+  - else: Negative log likelihood of labels tokens with shape [batch_size, sequence_length]
 - `new_mems`: list (num layers) of updated mem states at the entry of each layer each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]. Note that the first two dimensions are transposed in `mems` with regards to `input_ids`.

 #### 14. `GPT2Model`
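After this change the LM head model takes the standard `labels` keyword instead of `target`. Below is a minimal usage sketch of the two call modes the README describes; the import path and the small config values are assumptions that roughly mirror the unit test further down in this diff, not the repository's own example code:

```python
import torch
# Assumed import path; adjust to however this package is named/installed at this commit.
from pytorch_transformers import TransfoXLConfig, TransfoXLLMHeadModel

# Tiny hypothetical config (first positional argument is the vocabulary size);
# cutoffs must stay below the vocab size for the adaptive softmax to build.
config = TransfoXLConfig(99, cutoffs=[10, 50, 80],
                         d_model=32, d_embed=32, n_head=4, d_head=8,
                         d_inner=128, div_val=2, n_layer=2, mem_len=30)
model = TransfoXLLMHeadModel(config)
model.eval()

input_ids = torch.randint(0, 99, (2, 10))   # [batch_size, sequence_length]
labels = torch.randint(0, 99, (2, 10))

with torch.no_grad():
    # With `labels`: first output is the per-token negative log likelihood, [batch_size, sequence_length]
    nll, new_mems = model(input_ids, labels=labels)
    # Without `labels`: first output is the log probabilities, [batch_size, sequence_length, n_token]
    log_probs, new_mems = model(input_ids)
```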
@@ -1025,14 +1025,14 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         `mems`: optional memomry of hidden states from previous forward passes
             as a list (num layers) of hidden states at the entry of each layer
             each hidden states has shape [self.config.mem_len, bsz, self.config.d_model]
-            Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `target`
+            Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `labels`
     Outputs:
         A tuple of (last_hidden_state, new_mems)
         `last_hidden_state`: the encoded-hidden-states at the top of the model
             as a torch.FloatTensor of size [batch_size, sequence_length, self.config.d_model]
         `new_mems`: list (num layers) of updated mem states at the entry of each layer
             each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]
-            Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `target`
+            Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `labels`

     Example usage:
     ```python
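The transposed memory layout called out in this docstring is the part callers most often trip over. A short hedged sketch of carrying `mems` across two segments, reusing the hypothetical `config` and import path assumed in the sketch above:

```python
from pytorch_transformers import TransfoXLModel   # assumed import path, as above

model = TransfoXLModel(config)   # base model without the LM head
model.eval()

seg_1 = torch.randint(0, 99, (2, 10))   # [batch_size, sequence_length]
seg_2 = torch.randint(0, 99, (2, 10))

with torch.no_grad():
    hidden_1, mems_1 = model(seg_1)               # hidden_1: [batch_size, sequence_length, d_model]
    hidden_2, mems_2 = model(seg_2, mems=mems_1)  # memory from segment 1 extends the context of segment 2

# Each entry of mems_1 has shape [mem_len, batch_size, d_model]: batch and time are
# transposed relative to input_ids, exactly as the note in the docstring warns.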
@@ -1265,7 +1265,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                 mems :: optional mems from previous forwar passes (or init_mems)
                     list (num layers) of mem states at the entry of each layer
                         shape :: [self.config.mem_len, bsz, self.config.d_model]
-                Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `target`
+                Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `labels`
             Returns:
                 tuple (last_hidden, new_mems) where:
                     new_mems: list (num layers) of mem states at the entry of each layer
@@ -1303,23 +1303,23 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
     Inputs:
         `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
             with the token indices selected in the range [0, self.config.n_token[
-        `target`: an optional torch.LongTensor of shape [batch_size, sequence_length]
-            with the target token indices selected in the range [0, self.config.n_token[
+        `labels`: an optional torch.LongTensor of shape [batch_size, sequence_length]
+            with the labels token indices selected in the range [0, self.config.n_token[
         `mems`: an optional memory of hidden states from previous forward passes
             as a list (num layers) of hidden states at the entry of each layer
             each hidden states has shape [self.config.mem_len, bsz, self.config.d_model]
-            Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `target`
+            Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `labels`

     Outputs:
         A tuple of (last_hidden_state, new_mems)
         `softmax_output`: output of the (adaptive) softmax:
-            if target is None:
+            if labels is None:
                 Negative log likelihood of shape [batch_size, sequence_length]
             else:
                 log probabilities of tokens, shape [batch_size, sequence_length, n_tokens]
         `new_mems`: list (num layers) of updated mem states at the entry of each layer
             each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]
-            Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `target`
+            Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `labels`

     Example usage:
     ```python
@@ -1375,16 +1375,16 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
     def init_mems(self, data):
         return self.transformer.init_mems(data)

-    def forward(self, input_ids, target=None, mems=None):
+    def forward(self, input_ids, labels=None, mems=None):
         """ Params:
                 input_ids :: [bsz, len]
-                target :: [bsz, len]
+                labels :: [bsz, len]
             Returns:
                 tuple(softmax_output, new_mems) where:
                     new_mems: list (num layers) of hidden states at the entry of each layer
                         shape :: [mem_len, bsz, self.config.d_model] :: Warning: shapes are transposed here w. regards to input_ids
                     softmax_output: output of the (adaptive) softmax:
-                        if target is None:
+                        if labels is None:
                             Negative log likelihood of shape :: [bsz, len]
                         else:
                             log probabilities of tokens, shape :: [bsz, len, n_tokens]
@@ -1397,11 +1397,11 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
         pred_hid = last_hidden[:, -tgt_len:]
         if self.sample_softmax > 0 and self.training:
             assert self.config.tie_weight
-            logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, target, pred_hid, self.sampler)
+            logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, labels, pred_hid, self.sampler)
             softmax_output = -F.log_softmax(logit, -1)[:, :, 0]
         else:
-            softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target)
-            if target is None:
+            softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), labels)
+            if labels is None:
                 softmax_output = softmax_output.view(bsz, tgt_len, -1)
             else:
                 softmax_output = softmax_output.view(bsz, tgt_len)
@@ -89,13 +89,13 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):

         return logit

-    def forward(self, hidden, target=None, keep_order=False):
+    def forward(self, hidden, labels=None, keep_order=False):
         '''
             Params:
                 hidden :: [len*bsz x d_proj]
-                target :: [len*bsz]
+                labels :: [len*bsz]
             Return:
-                if target is None:
+                if labels is None:
                     out :: [len*bsz] Negative log likelihood
                 else:
                     out :: [len*bsz x n_tokens] log probabilities of tokens over the vocabulary
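The same rename applies to the adaptive softmax module itself. Note that the if/else in the docstring above reads inverted relative to the implementation in the hunks below: passing `labels` yields the negative log likelihood, omitting it yields the full log probabilities. A hedged sketch of both modes of `forward`; the module path and constructor arguments are assumptions taken from the upstream Transformer-XL utility, not quoted from this diff:

```python
import torch
# Assumed module path; the class is defined in the Transformer-XL utilities file of this repo.
from pytorch_transformers.modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax

# Positional args assumed as (n_token, d_embed, d_proj, cutoffs); values are illustrative.
crit = ProjectedAdaptiveLogSoftmax(99, 32, 32, [10, 50, 80], div_val=2)

hidden = torch.randn(20, 32)            # [len*bsz, d_proj]
labels = torch.randint(0, 99, (20,))    # [len*bsz]

nll = crit(hidden, labels=labels)       # negative log likelihood per position, [len*bsz]
log_probs = crit(hidden)                # log probabilities over the vocabulary, [len*bsz, n_token]
```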
@@ -104,18 +104,18 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
                 here: https://github.com/pytorch/pytorch/blob/dbe6a7a9ff1a364a8706bf5df58a1ca96d2fd9da/torch/nn/modules/adaptive.py#L138
         '''

-        if target is not None:
-            target = target.view(-1)
-            if hidden.size(0) != target.size(0):
-                raise RuntimeError('Input and target should have the same size '
+        if labels is not None:
+            labels = labels.view(-1)
+            if hidden.size(0) != labels.size(0):
+                raise RuntimeError('Input and labels should have the same size '
                                    'in the batch dimension.')

         if self.n_clusters == 0:
             logit = self._compute_logit(hidden, self.out_layers[0].weight,
                                         self.out_layers[0].bias, self.out_projs[0])
-            if target is not None:
+            if labels is not None:
                 out = -F.log_softmax(logit, dim=-1) \
-                        .gather(1, target.unsqueeze(1)).squeeze(1)
+                        .gather(1, labels.unsqueeze(1)).squeeze(1)
             else:
                 out = F.log_softmax(logit, dim=-1)
         else:
@@ -144,31 +144,31 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
             head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
             head_logprob = F.log_softmax(head_logit, dim=1)

-            if target is None:
+            if labels is None:
                 out = hidden.new_empty((head_logit.size(0), self.n_token))
             else:
-                out = torch.zeros_like(target, dtype=hidden.dtype, device=hidden.device)
+                out = torch.zeros_like(labels, dtype=hidden.dtype, device=hidden.device)

             offset = 0
             cutoff_values = [0] + self.cutoffs
             for i in range(len(cutoff_values) - 1):
                 l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]

-                if target is not None:
-                    mask_i = (target >= l_idx) & (target < r_idx)
+                if labels is not None:
+                    mask_i = (labels >= l_idx) & (labels < r_idx)
                     indices_i = mask_i.nonzero().squeeze()

                     if indices_i.numel() == 0:
                         continue

-                    target_i = target.index_select(0, indices_i) - l_idx
+                    target_i = labels.index_select(0, indices_i) - l_idx
                     head_logprob_i = head_logprob.index_select(0, indices_i)
                     hidden_i = hidden.index_select(0, indices_i)
                 else:
                     hidden_i = hidden

                 if i == 0:
-                    if target is not None:
+                    if labels is not None:
                         logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
                     else:
                         out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]]
@@ -178,14 +178,14 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
                     tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
                     tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
                     cluster_prob_idx = self.cutoffs[0] + i - 1  # No probability for the head cluster
-                    if target is not None:
+                    if labels is not None:
                         logprob_i = head_logprob_i[:, cluster_prob_idx] \
                             + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)
                     else:
                         logprob_i = head_logprob[:, cluster_prob_idx, None] + tail_logprob_i
                         out[:, l_idx:r_idx] = logprob_i

-                if target is not None:
+                if labels is not None:
                     if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
                         out.index_copy_(0, indices_i, -logprob_i)
                     else:
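For reference, the arithmetic in this branch is the usual two-level factorization of the adaptive softmax: log p(w | h) = log p(cluster(w) | h) + log p(w | cluster(w), h), i.e. the head distribution picks the cluster and the tail distribution picks the token within it. A tiny numeric illustration (not the module's code):

```python
import torch

head_logprob_i = torch.log(torch.tensor([0.3]))   # log p(cluster(w) | h), from the head softmax
tail_logprob_i = torch.log(torch.tensor([0.2]))   # log p(w | cluster(w), h), from the tail softmax

logprob_w = head_logprob_i + tail_logprob_i       # log p(w | h)
print(torch.exp(logprob_w))                       # tensor([0.0600]), i.e. 0.3 * 0.2
```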
@@ -1,111 +0,0 @@
-# coding=utf-8
-# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
-# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-""" Utilities for PyTorch XLNet model.
-"""
-
-from collections import defaultdict
-
-import numpy as np
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-special_symbols = {
-    "<unk>"  : 0,
-    "<s>"    : 1,
-    "</s>"   : 2,
-    "<cls>"  : 3,
-    "<sep>"  : 4,
-    "<pad>"  : 5,
-    "<mask>" : 6,
-    "<eod>"  : 7,
-    "<eop>"  : 8,
-}
-
-VOCAB_SIZE = 32000
-UNK_ID = special_symbols["<unk>"]
-CLS_ID = special_symbols["<cls>"]
-SEP_ID = special_symbols["<sep>"]
-MASK_ID = special_symbols["<mask>"]
-EOD_ID = special_symbols["<eod>"]
-
-
-def permutation_mask(inputs, targets, is_masked, perm_size, seq_len):
-    """
-    Sample a permutation of the factorization order, and create an
-    attention mask accordingly.
-    Args:
-        inputs: int64 Tensor in shape [seq_len], input ids.
-        targets: int64 Tensor in shape [seq_len], target ids.
-        is_masked: bool Tensor in shape [seq_len]. True means being selected
-            for partial prediction.
-        perm_size: the length of longest permutation. Could be set to be reuse_len.
-            Should not be larger than reuse_len or there will be data leaks.
-        seq_len: int, sequence length.
-    """
-
-    # Generate permutation indices
-    index = np.arange(10)
-    index = np.transpose(np.reshape(index, [-1, perm_size]))
-    index = np.random.shuffle(index)
-    index = np.reshape(np.transpose(index), [-1])
-
-    # `perm_mask` and `target_mask`
-    # non-functional tokens
-    non_func_tokens = tf.logical_not(tf.logical_or(
-        tf.equal(inputs, SEP_ID),
-        tf.equal(inputs, CLS_ID)))
-
-    non_mask_tokens = tf.logical_and(tf.logical_not(is_masked), non_func_tokens)
-    masked_or_func_tokens = tf.logical_not(non_mask_tokens)
-
-    # Set the permutation indices of non-masked (& non-funcional) tokens to the
-    # smallest index (-1):
-    # (1) they can be seen by all other positions
-    # (2) they cannot see masked positions, so there won"t be information leak
-    smallest_index = -tf.ones([seq_len], dtype=tf.int64)
-    rev_index = tf.where(non_mask_tokens, smallest_index, index)
-
-    # Create `target_mask`: non-funcional and maksed tokens
-    # 1: use mask as input and have loss
-    # 0: use token (or [SEP], [CLS]) as input and do not have loss
-    target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
-    target_mask = tf.cast(target_tokens, tf.float32)
-
-    # Create `perm_mask`
-    # `target_tokens` cannot see themselves
-    self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)
-
-    # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
-    # 0: can attend if i > j or j is non-masked
-    perm_mask = tf.logical_and(
-        self_rev_index[:, None] <= rev_index[None, :],
-        masked_or_func_tokens)
-    perm_mask = tf.cast(perm_mask, tf.float32)
-
-    # new target: [next token] for LM and [curr token] (self) for PLM
-    new_targets = tf.concat([inputs[0: 1], targets[: -1]],
-                            axis=0)
-
-    # construct inputs_k
-    inputs_k = inputs
-
-    # construct inputs_q
-    inputs_q = target_mask
-
-    return perm_mask, new_targets, target_mask, inputs_k, inputs_q
-
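The removed helper mixes NumPy and TensorFlow calls (`tf.*` is used without an import) and contains a couple of visible bugs: the hard-coded `np.arange(10)` instead of `seq_len`, and `index = np.random.shuffle(index)`, which assigns None because `np.random.shuffle` works in place. For context, a hedged NumPy-only sketch of the block permutation its docstring describes (not the removed code):

```python
import numpy as np

def sample_permutation_index(seq_len, perm_size, rng=None):
    """Sample a factorization-order permutation in blocks of perm_size positions.

    Hedged sketch of the intent described in the deleted docstring above.
    Assumes seq_len is a multiple of perm_size.
    """
    rng = np.random.default_rng() if rng is None else rng
    index = np.arange(seq_len, dtype=np.int64)
    index = index.reshape(-1, perm_size).T   # [perm_size, seq_len // perm_size]
    rng.shuffle(index)                       # shuffles rows in place along the first axis, returns None
    return index.T.reshape(-1)

# Example: a length-8 sequence permuted in blocks of 4; both blocks share the same relative order.
print(sample_permutation_index(8, 4))
```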
@@ -129,10 +129,10 @@ class TransfoXLModelTest(unittest.TestCase):
             model = TransfoXLLMHeadModel(config)
             model.eval()

-            loss_1, mems_1a = model(input_ids_1, target=lm_labels)
+            loss_1, mems_1a = model(input_ids_1, labels=lm_labels)
             lm_logits_1, mems_1b = model(input_ids_1)

-            loss_2, mems_2a = model(input_ids_2, target=lm_labels, mems=mems_1a)
+            loss_2, mems_2a = model(input_ids_2, labels=lm_labels, mems=mems_1a)
             lm_logits_2, mems_2b = model(input_ids_2, mems=mems_1b)

             outputs = {
@@ -138,10 +138,10 @@ class XLNetModelTest(unittest.TestCase):
             model = XLNetLMHeadModel(config)
             model.eval()

-            loss_1, mems_1a = model(input_ids_1, token_type_ids=segment_ids, target=lm_labels)
+            loss_1, mems_1a = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)
             all_logits_1, mems_1b = model(input_ids_1, token_type_ids=segment_ids)

-            loss_2, mems_2a = model(input_ids_2, token_type_ids=segment_ids, target=lm_labels, mems=mems_1a)
+            loss_2, mems_2a = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1a)
             all_logits_2, mems_2b = model(input_ids_2, token_type_ids=segment_ids, mems=mems_1b)

             logits, _ = model(input_ids_q,