[T5] Add 3D attention mask to T5 model (2) (#9643) (#11197)

* Add 3D attention mask to T5 model (#9643)

Added code for 3D attention mask in T5 model. Similar to BERT model.

* Add test for 3D attention mask

Added test for 3D attention mask: test_decoder_model_past_with_3d_attn_mask()
3D attention mask of the shape [Batch_size, Seq_length, Seq_length] both for
attention mask and decoder attention mask. Test is passing.
This commit is contained in:
lexhuismans 2021-05-13 13:02:27 +02:00 committed by GitHub
parent 6ee1a4fd3e
commit 91cf29153b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 1 deletions

View File

@ -914,7 +914,13 @@ class T5Stack(T5PreTrainedModel):
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device)
if self.is_decoder and encoder_attention_mask is not None:
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None

View File

@ -530,6 +530,34 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
def test_decoder_model_past_with_3d_attn_mask(self):
(
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
) = self.model_tester.prepare_config_and_inputs()
attention_mask = ids_tensor(
[self.model_tester.batch_size, self.model_tester.encoder_seq_length, self.model_tester.encoder_seq_length],
vocab_size=2,
)
decoder_attention_mask = ids_tensor(
[self.model_tester.batch_size, self.model_tester.decoder_seq_length, self.model_tester.decoder_seq_length],
vocab_size=2,
)
self.model_tester.create_and_check_decoder_model_attention_mask_past(
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
)
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)