diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index bce6a29e7a..9a7b154cef 100644 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -47,7 +47,7 @@ class PretrainedConfig(object): Whether or not the model should return all hidden-states. output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not the model should returns all attentions. - use_cache (:obj:`bool`, `optional`, defaults to :obj:`False`): + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). return_tuple (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not the model should return tuples instead of :obj:`ModelOutput` objects. diff --git a/src/transformers/configuration_xlnet.py b/src/transformers/configuration_xlnet.py index a48ac7f48b..edd8925592 100644 --- a/src/transformers/configuration_xlnet.py +++ b/src/transformers/configuration_xlnet.py @@ -110,6 +110,8 @@ class XLNetConfig(PretrainedConfig): Used in the SQuAD evaluation script for XLM and XLNet. end_n_top (:obj:`int`, optional, defaults to 5): Used in the SQuAD evaluation script for XLM and XLNet. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + Differs slightly from other models as it is always turned on at training time. Example:: diff --git a/src/transformers/modeling_xlnet.py b/src/transformers/modeling_xlnet.py index f599522f04..1748271521 100644 --- a/src/transformers/modeling_xlnet.py +++ b/src/transformers/modeling_xlnet.py @@ -575,7 +575,7 @@ class XLNetModelOutput(ModelOutput): ``num_predict`` corresponds to ``target_mapping.shape[1]``. If ``target_mapping`` is ``None``, then ``num_predict`` corresponds to ``sequence_length``. mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`): - Contains pre-computed hidden-states (key and values in the attention blocks). + Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model should not be passed as input ids as they have already been computed. hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): @@ -611,7 +611,7 @@ class XLNetLMHeadModelOutput(ModelOutput): ``num_predict`` corresponds to ``target_mapping.shape[1]``. If ``target_mapping`` is ``None``, then ``num_predict`` corresponds to ``sequence_length``. mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`): - Contains pre-computed hidden-states (key and values in the attention blocks). + Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model should not be passed as input ids as they have already been computed. hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): @@ -645,7 +645,7 @@ class XLNetForSequenceClassificationOutput(ModelOutput): logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`): Classification (or regression if config.num_labels==1) scores (before SoftMax). mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`): - Contains pre-computed hidden-states (key and values in the attention blocks). + Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model should not be passed as input ids as they have already been computed. hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): @@ -679,7 +679,7 @@ class XLNetForTokenClassificationOutput(ModelOutput): logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`): Classification scores (before SoftMax). mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`): - Contains pre-computed hidden-states (key and values in the attention blocks). + Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model should not be passed as input ids as they have already been computed. hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): @@ -715,7 +715,7 @@ class XLNetForMultipleChoiceOutput(ModelOutput): Classification scores (before SoftMax). mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`): - Contains pre-computed hidden-states (key and values in the attention blocks). + Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model should not be passed as input ids as they have already been computed. hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): @@ -751,7 +751,7 @@ class XLNetForQuestionAnsweringSimpleOutput(ModelOutput): end_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`): Span-end scores (before SoftMax). mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`): - Contains pre-computed hidden-states (key and values in the attention blocks). + Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model should not be passed as input ids as they have already been computed. hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): @@ -794,7 +794,7 @@ class XLNetForQuestionAnsweringOutput(ModelOutput): cls_logits (``torch.FloatTensor`` of shape ``(batch_size,)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided): Log probabilities for the ``is_impossible`` label of the answers. mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`): - Contains pre-computed hidden-states (key and values in the attention blocks). + Contains pre-computed hidden-states. Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model should not be passed as input ids as they have already been computed. hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): @@ -850,7 +850,7 @@ XLNET_INPUTS_DOCSTRING = r""" `What are attention masks? <../glossary.html#attention-mask>`__ mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`): - Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model + Contains pre-computed hidden-states as computed by the model (see `mems` output below). Can be used to speed up sequential decoding. The token ids which have their mems given to this model should not be passed as input ids as they have already been computed. `use_cache` has to be set to `True` to make use of `mems`. @@ -964,10 +964,19 @@ class XLNetModel(XLNetPreTrainedModel): if self.reuse_len is not None and self.reuse_len > 0: curr_out = curr_out[: self.reuse_len] - if prev_mem is None: - new_mem = curr_out[-self.mem_len :] + if self.mem_len is None or self.mem_len == 0: + # If `use_cache` is active but no `mem_len` is defined, the model behaves like GPT-2 at inference time + # and returns all of the past and current hidden states. + cutoff = 0 else: - new_mem = torch.cat([prev_mem, curr_out], dim=0)[-self.mem_len :] + # If `use_cache` is active and `mem_len` is defined, the model returns the last `mem_len` hidden + # states. This is the preferred setting for training and long-form generation. + cutoff = -self.mem_len + if prev_mem is None: + # if `use_cache` is active and `mem_len` is defined, the model + new_mem = curr_out[cutoff:] + else: + new_mem = torch.cat([prev_mem, curr_out], dim=0)[cutoff:] return new_mem.detach() @@ -1039,7 +1048,7 @@ class XLNetModel(XLNetPreTrainedModel): input_mask=None, head_mask=None, inputs_embeds=None, - use_cache=True, + use_cache=None, output_attentions=None, output_hidden_states=None, return_tuple=None, @@ -1049,6 +1058,7 @@ class XLNetModel(XLNetPreTrainedModel): output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple + use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache) # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end # but we want a unified interface in the library with the batch size on the first dimension @@ -1179,7 +1189,7 @@ class XLNetModel(XLNetPreTrainedModel): attentions = [] if output_attentions else None hidden_states = [] if output_hidden_states else None for i, layer_module in enumerate(self.layer): - if self.mem_len is not None and self.mem_len > 0 and use_cache is True: + if use_cache: # cache new mems new_mems = new_mems + (self.cache_mem(output_h, mems[i]),) if output_hidden_states: @@ -1211,7 +1221,7 @@ class XLNetModel(XLNetPreTrainedModel): output = output.permute(1, 0, 2).contiguous() # TODO Teven: fix this test to only use use_cache. - if not (self.mem_len is not None and self.mem_len > 0 and use_cache is True): + if not use_cache: new_mems = None if output_hidden_states: @@ -1312,7 +1322,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): head_mask=None, inputs_embeds=None, labels=None, - use_cache=True, + use_cache=None, output_attentions=None, output_hidden_states=None, return_tuple=None, @@ -1360,6 +1370,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): """ return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple + use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache) transformer_outputs = self.transformer( input_ids, @@ -1433,7 +1444,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): head_mask=None, inputs_embeds=None, labels=None, - use_cache=True, + use_cache=None, output_attentions=None, output_hidden_states=None, return_tuple=None, @@ -1446,6 +1457,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy). """ return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple + use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache) transformer_outputs = self.transformer( input_ids, @@ -1524,7 +1536,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel): head_mask=None, inputs_embeds=None, labels=None, - use_cache=True, + use_cache=None, output_attentions=None, output_hidden_states=None, return_tuple=None, @@ -1536,6 +1548,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel): of the input tensors. (see `input_ids` above) """ return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple + use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache) outputs = self.transformer( input_ids, @@ -1618,7 +1631,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel): head_mask=None, inputs_embeds=None, labels=None, - use_cache=True, + use_cache=None, output_attentions=None, output_hidden_states=None, return_tuple=None, @@ -1630,6 +1643,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel): of the input tensors. (see `input_ids` above) """ return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple + use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache) num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None @@ -1717,7 +1731,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel): inputs_embeds=None, start_positions=None, end_positions=None, - use_cache=True, + use_cache=None, output_attentions=None, output_hidden_states=None, return_tuple=None, @@ -1733,6 +1747,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel): Position outside of the sequence are not taken into account for computing the loss. """ return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple + use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache) outputs = self.transformer( input_ids, @@ -1824,7 +1839,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel): is_impossible=None, cls_index=None, p_mask=None, - use_cache=True, + use_cache=None, output_attentions=None, output_hidden_states=None, return_tuple=None, @@ -1864,6 +1879,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel): >>> loss = outputs[0] """ return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple + use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache) transformer_outputs = self.transformer( input_ids, diff --git a/tests/test_modeling_xlnet.py b/tests/test_modeling_xlnet.py index 5ca428fb25..94ca1fea33 100644 --- a/tests/test_modeling_xlnet.py +++ b/tests/test_modeling_xlnet.py @@ -191,8 +191,8 @@ class XLNetModelTester: model = XLNetModel(config) model.to(torch_device) model.eval() - no_mems_outputs = model(input_ids_1) - self.parent.assertEqual(len(no_mems_outputs), 1) + base_model_output = model(input_ids_1) + self.parent.assertEqual(len(base_model_output), 2) self.parent.assertListEqual( list(result["outputs"].size()), [self.batch_size, self.seq_length, self.hidden_size], @@ -202,6 +202,72 @@ class XLNetModelTester: [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers, ) + def create_and_check_xlnet_model_use_cache( + self, + config, + input_ids_1, + input_ids_2, + input_ids_q, + perm_mask, + input_mask, + target_mapping, + segment_ids, + lm_labels, + sequence_labels, + is_impossible_labels, + token_labels, + ): + model = XLNetModel(config=config) + model.to(torch_device) + model.eval() + + # first forward pass + causal_mask = torch.ones( + input_ids_1.shape[0], + input_ids_1.shape[1], + input_ids_1.shape[1], + dtype=torch.float, + device=input_ids_1.device, + ) + causal_mask = torch.triu(causal_mask, diagonal=0) + outputs_cache = model(input_ids_1, use_cache=True, perm_mask=causal_mask) + outputs_no_cache = model(input_ids_1, use_cache=False, perm_mask=causal_mask) + outputs_conf = model(input_ids_1) + + self.parent.assertTrue(len(outputs_cache) == len(outputs_conf)) + self.parent.assertTrue(len(outputs_cache) == len(outputs_no_cache) + 1) + + output, mems = outputs_cache + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and token_type_ids + next_input_ids = torch.cat([input_ids_1, next_tokens], dim=-1) + + # causal mask + causal_mask = torch.ones( + input_ids_1.shape[0], + input_ids_1.shape[1] + 1, + input_ids_1.shape[1] + 1, + dtype=torch.float, + device=input_ids_1.device, + ) + causal_mask = torch.triu(causal_mask, diagonal=0) + single_mask = torch.ones(input_ids_1.shape[0], 1, 1) + + # second forward pass + output_from_no_past, _ = model(next_input_ids, perm_mask=causal_mask) + output_from_past, _ = model(next_tokens, mems=mems, perm_mask=single_mask) + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + def create_and_check_xlnet_base_model_with_att_output( self, config, @@ -451,7 +517,6 @@ class XLNetModelTester: @require_torch class XLNetModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = ( ( XLNetModel, @@ -482,6 +547,12 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs) + def test_xlnet_base_model_use_cache(self): + # checking that in auto-regressive mode, `use_cache` gives the same results + self.model_tester.set_seed() + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_xlnet_model_use_cache(*config_and_inputs) + def test_xlnet_base_model_with_att_output(self): self.model_tester.set_seed() config_and_inputs = self.model_tester.prepare_config_and_inputs() @@ -874,33 +945,33 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase): 9, 69, 27, - 50, - 551, + 442, 22, 2771, - 4901, - 19, - 21, - 45, - 668, - 21, + 24, + 11335, + 20, 18, - 416, - 41, - 1499, - 22, - 755, - 18, - 14285, + 9225, + 2198, 9, - 12943, - 4354, - 153, + 69, 27, - 1499, + 442, 22, - 642, + 2771, + 24, + 11335, + 20, + 18, + 9225, + 2198, + 9, + 69, + 27, + 442, 22, + 2771, ] # In 1991, the remains of Russian Tsar Nicholas II and his family (except for Alexei and Maria) # are discovered. The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, @@ -910,9 +981,8 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase): # him for making such an accusation, Rasputin watches as the man is chased outside and beaten. # Twenty years later, Rasputin sees a vision of the Virgin Mary, prompting him to become a priest. # Rasputin quickly becomes famous, with people, even a bishop, begging for his blessing. - # , Rasputin is asked to perform magic. - # He is not able to perform magic, and his father and - # the men are forced to leave the monastery. Rasputin is forced to return to + # , Rasputin is asked to perform magic. He is asked to perform a ritual of the Virgin Mary. + # He is asked to perform a ritual of the Virgin Mary. He is asked to perform output_ids = model.generate(input_ids, max_length=200, do_sample=False) self.assertListEqual(output_ids[0].tolist(), expected_output_ids)