XLNet outputs (#5883)

Slightly breaking change to the `use_cache` behaviour in XLNet: if `use_cache` is True and `mem_len` is 0 or None (which is the case in the base model config), the model behaves like GPT-2 and returns mems to be used as past in generation. At training time, `use_cache` is overridden and always set to True.
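For illustration (not part of the diff), a minimal sketch of the new default behaviour, assuming the standard `xlnet-base-cased` checkpoint: the first forward pass returns `mems`, and later steps only need the new tokens.

import torch
from transformers import XLNetModel, XLNetTokenizer

tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
model = XLNetModel.from_pretrained("xlnet-base-cased")
model.eval()

input_ids = tokenizer.encode("In 1991, the remains of Russian Tsar Nicholas II", return_tensors="pt")

# First pass: with `mem_len` unset in the config, `use_cache=True` returns mems
# covering every position, like `past` in GPT-2.
outputs = model(input_ids, use_cache=True)
hidden_states, mems = outputs[0], outputs[1]

# Later steps pass only the new tokens; the cached states come in through `mems`.
next_ids = tokenizer.encode(" and his family", add_special_tokens=False, return_tensors="pt")
outputs = model(next_ids, mems=mems, use_cache=True)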
Teven 2020-07-18 17:33:13 +02:00 committed by GitHub
parent a55809241f
commit 4b506a37e3
4 changed files with 131 additions and 47 deletions

View File

@ -47,7 +47,7 @@ class PretrainedConfig(object):
Whether or not the model should return all hidden-states.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the model should return all attentions.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`False`):
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/value attention states (not used by all models).
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the model should return tuples instead of :obj:`ModelOutput` objects.

View File

@ -110,6 +110,8 @@ class XLNetConfig(PretrainedConfig):
Used in the SQuAD evaluation script for XLM and XLNet.
end_n_top (:obj:`int`, optional, defaults to 5):
Used in the SQuAD evaluation script for XLM and XLNet.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Differs slightly from other models in that it is always turned on at training time.
Example::

View File

@ -575,7 +575,7 @@ class XLNetModelOutput(ModelOutput):
``num_predict`` corresponds to ``target_mapping.shape[1]``. If ``target_mapping`` is ``None``, then
``num_predict`` corresponds to ``sequence_length``.
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Contains pre-computed hidden-states.
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
@ -611,7 +611,7 @@ class XLNetLMHeadModelOutput(ModelOutput):
``num_predict`` corresponds to ``target_mapping.shape[1]``. If ``target_mapping`` is ``None``, then
``num_predict`` corresponds to ``sequence_length``.
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Contains pre-computed hidden-states.
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
@ -645,7 +645,7 @@ class XLNetForSequenceClassificationOutput(ModelOutput):
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Contains pre-computed hidden-states.
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
@ -679,7 +679,7 @@ class XLNetForTokenClassificationOutput(ModelOutput):
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax).
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Contains pre-computed hidden-states.
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
@ -715,7 +715,7 @@ class XLNetForMultipleChoiceOutput(ModelOutput):
Classification scores (before SoftMax).
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Contains pre-computed hidden-states.
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
@ -751,7 +751,7 @@ class XLNetForQuestionAnsweringSimpleOutput(ModelOutput):
end_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Contains pre-computed hidden-states.
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
@ -794,7 +794,7 @@ class XLNetForQuestionAnsweringOutput(ModelOutput):
cls_logits (``torch.FloatTensor`` of shape ``(batch_size,)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Log probabilities for the ``is_impossible`` label of the answers.
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Contains pre-computed hidden-states.
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
@ -850,7 +850,7 @@ XLNET_INPUTS_DOCSTRING = r"""
`What are attention masks? <../glossary.html#attention-mask>`__
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
Contains pre-computed hidden-states as computed by the model
(see `mems` output below). Can be used to speed up sequential decoding. The token ids which have their mems
given to this model should not be passed as input ids as they have already been computed.
`use_cache` has to be set to `True` to make use of `mems`.
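A rough sketch (outside the diff) of producing and inspecting `mems`, using a deliberately tiny toy config whose sizes are purely illustrative:

import torch
from transformers import XLNetConfig, XLNetModel

# Toy config; the defaults keep use_cache=True and mem_len=None.
config = XLNetConfig(vocab_size=32, d_model=16, n_layer=2, n_head=2, d_inner=32)
model = XLNetModel(config)
model.eval()

input_ids = torch.randint(0, 32, (1, 10))
mems = model(input_ids)[1]
print(len(mems), mems[0].shape)  # 2 layers, each [10, 1, 16]: the cache covers all 10 positions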
@ -964,10 +964,19 @@ class XLNetModel(XLNetPreTrainedModel):
if self.reuse_len is not None and self.reuse_len > 0:
curr_out = curr_out[: self.reuse_len]
if prev_mem is None:
new_mem = curr_out[-self.mem_len :]
if self.mem_len is None or self.mem_len == 0:
# If `use_cache` is active but no `mem_len` is defined, the model behaves like GPT-2 at inference time
# and returns all of the past and current hidden states.
cutoff = 0
else:
new_mem = torch.cat([prev_mem, curr_out], dim=0)[-self.mem_len :]
# If `use_cache` is active and `mem_len` is defined, the model returns the last `mem_len` hidden
# states. This is the preferred setting for training and long-form generation.
cutoff = -self.mem_len
if prev_mem is None:
# with no previous memory, the cache is just the current (possibly `reuse_len`-truncated) hidden states
new_mem = curr_out[cutoff:]
else:
new_mem = torch.cat([prev_mem, curr_out], dim=0)[cutoff:]
return new_mem.detach()
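Condensed, and ignoring the `reuse_len` truncation above, the new caching rule amounts to the following sketch (assuming `curr_out` and `prev_mem` are `[seq_len, batch, hidden]` tensors):

import torch

def cache_mem_sketch(curr_out, prev_mem, mem_len):
    # mem_len None or 0: keep everything, a GPT-2-style cache that grows with the sequence.
    # mem_len > 0: keep only the last `mem_len` positions, the classic XLNet memory.
    cutoff = 0 if not mem_len else -mem_len
    if prev_mem is None:
        new_mem = curr_out[cutoff:]
    else:
        new_mem = torch.cat([prev_mem, curr_out], dim=0)[cutoff:]
    return new_mem.detach()

curr_out = torch.randn(10, 1, 16)  # 10 new positions
prev_mem = torch.randn(7, 1, 16)   # 7 previously cached positions
print(cache_mem_sketch(curr_out, prev_mem, None).shape[0])  # 17: the cache grows
print(cache_mem_sketch(curr_out, prev_mem, 4).shape[0])     # 4: bounded memory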
@ -1039,7 +1048,7 @@ class XLNetModel(XLNetPreTrainedModel):
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
@ -1049,6 +1058,7 @@ class XLNetModel(XLNetPreTrainedModel):
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
# but we want a unified interface in the library with the batch size on the first dimension
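The `use_cache` resolution a few lines above gives training the highest precedence, then an explicit call argument, then the config default; a minimal sketch of that precedence:

def resolve_use_cache(training, use_cache_arg, config_use_cache=True):
    # The forward pass forces caching on whenever self.training is True;
    # at inference an explicit argument wins, otherwise the config default applies.
    return training or (use_cache_arg if use_cache_arg is not None else config_use_cache)

assert resolve_use_cache(training=True, use_cache_arg=False) is True    # overridden at train time
assert resolve_use_cache(training=False, use_cache_arg=False) is False  # explicit argument respected
assert resolve_use_cache(training=False, use_cache_arg=None) is True    # falls back to the config default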
@ -1179,7 +1189,7 @@ class XLNetModel(XLNetPreTrainedModel):
attentions = [] if output_attentions else None
hidden_states = [] if output_hidden_states else None
for i, layer_module in enumerate(self.layer):
if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
if use_cache:
# cache new mems
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
if output_hidden_states:
@ -1211,7 +1221,7 @@ class XLNetModel(XLNetPreTrainedModel):
output = output.permute(1, 0, 2).contiguous()
# TODO Teven: fix this test to only use use_cache.
if not (self.mem_len is not None and self.mem_len > 0 and use_cache is True):
if not use_cache:
new_mems = None
if output_hidden_states:
@ -1312,7 +1322,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=True,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
@ -1360,6 +1370,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
transformer_outputs = self.transformer(
input_ids,
@ -1433,7 +1444,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=True,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
@ -1446,6 +1457,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
transformer_outputs = self.transformer(
input_ids,
@ -1524,7 +1536,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=True,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
@ -1536,6 +1548,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
of the input tensors. (see `input_ids` above)
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
outputs = self.transformer(
input_ids,
@ -1618,7 +1631,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=True,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
@ -1630,6 +1643,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
of the input tensors. (see `input_ids` above)
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@ -1717,7 +1731,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
inputs_embeds=None,
start_positions=None,
end_positions=None,
use_cache=True,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
@ -1733,6 +1747,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
Position outside of the sequence are not taken into account for computing the loss.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
outputs = self.transformer(
input_ids,
@ -1824,7 +1839,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
is_impossible=None,
cls_index=None,
p_mask=None,
use_cache=True,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
@ -1864,6 +1879,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
>>> loss = outputs[0]
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
transformer_outputs = self.transformer(
input_ids,

View File

@ -191,8 +191,8 @@ class XLNetModelTester:
model = XLNetModel(config)
model.to(torch_device)
model.eval()
no_mems_outputs = model(input_ids_1)
self.parent.assertEqual(len(no_mems_outputs), 1)
base_model_output = model(input_ids_1)
self.parent.assertEqual(len(base_model_output), 2)
self.parent.assertListEqual(
list(result["outputs"].size()), [self.batch_size, self.seq_length, self.hidden_size],
@ -202,6 +202,68 @@ class XLNetModelTester:
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
)
def create_and_check_xlnet_model_use_cache(
self,
config,
input_ids_1,
input_ids_2,
input_ids_q,
perm_mask,
input_mask,
target_mapping,
segment_ids,
lm_labels,
sequence_labels,
is_impossible_labels,
token_labels,
):
model = XLNetModel(config=config)
model.to(torch_device)
model.eval()
# first forward pass
causal_mask = torch.ones(
input_ids_1.shape[0], input_ids_1.shape[1], input_ids_1.shape[1], dtype=torch.float, device=torch_device,
)
causal_mask = torch.triu(causal_mask, diagonal=0)
outputs_cache = model(input_ids_1, use_cache=True, perm_mask=causal_mask)
outputs_no_cache = model(input_ids_1, use_cache=False, perm_mask=causal_mask)
outputs_conf = model(input_ids_1)
self.parent.assertTrue(len(outputs_cache) == len(outputs_conf))
self.parent.assertTrue(len(outputs_cache) == len(outputs_no_cache) + 1)
output, mems = outputs_cache
# create hypothetical next token and extend to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# append to next input_ids and token_type_ids
next_input_ids = torch.cat([input_ids_1, next_tokens], dim=-1)
# causal mask
causal_mask = torch.ones(
input_ids_1.shape[0],
input_ids_1.shape[1] + 1,
input_ids_1.shape[1] + 1,
dtype=torch.float,
device=torch_device,
)
causal_mask = torch.triu(causal_mask, diagonal=0)
single_mask = torch.ones(input_ids_1.shape[0], 1, 1, dtype=torch.float, device=torch_device)
# second forward pass
output_from_no_past, _ = model(next_input_ids, perm_mask=causal_mask)
output_from_past, _ = model(next_tokens, mems=mems, perm_mask=single_mask)
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_xlnet_base_model_with_att_output(
self,
config,
@ -451,7 +513,6 @@ class XLNetModelTester:
@require_torch
class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
XLNetModel,
@ -482,6 +543,12 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs)
def test_xlnet_base_model_use_cache(self):
# checking that in auto-regressive mode, `use_cache` gives the same results
self.model_tester.set_seed()
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_model_use_cache(*config_and_inputs)
def test_xlnet_base_model_with_att_output(self):
self.model_tester.set_seed()
config_and_inputs = self.model_tester.prepare_config_and_inputs()
@ -874,33 +941,33 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
9,
69,
27,
50,
551,
442,
22,
2771,
4901,
19,
21,
45,
668,
21,
24,
11335,
20,
18,
416,
41,
1499,
22,
755,
18,
14285,
9225,
2198,
9,
12943,
4354,
153,
69,
27,
1499,
442,
22,
642,
2771,
24,
11335,
20,
18,
9225,
2198,
9,
69,
27,
442,
22,
2771,
]
# In 1991, the remains of Russian Tsar Nicholas II and his family (except for Alexei and Maria)
# are discovered. The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich,
@ -910,9 +977,8 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
# him for making such an accusation, Rasputin watches as the man is chased outside and beaten.
# Twenty years later, Rasputin sees a vision of the Virgin Mary, prompting him to become a priest.
# Rasputin quickly becomes famous, with people, even a bishop, begging for his blessing.
# <sep><cls>, Rasputin is asked to perform magic.
# He is not able to perform magic, and his father and
# the men are forced to leave the monastery. Rasputin is forced to return to
# <sep><cls>, Rasputin is asked to perform magic. He is asked to perform a ritual of the Virgin Mary.
# He is asked to perform a ritual of the Virgin Mary. He is asked to perform
output_ids = model.generate(input_ids, max_length=200, do_sample=False)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)