* Add 3D attention mask to T5 model (#9643)

Added support for a 3D attention mask in the T5 model, mirroring the existing handling in the BERT model.

* Add test for 3D attention mask

Added test_decoder_model_past_with_3d_attn_mask(), which exercises 3D attention masks of shape [batch_size, seq_length, seq_length] for both the attention mask and the decoder attention mask. The test passes.
This commit is contained in:
parent 6ee1a4fd3e
commit 91cf29153b
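For context, a minimal sketch of what this change enables, assuming the transformers T5Model API ("t5-small" and the toy shapes below are illustrative only, not part of this commit): a 3D mask of shape [batch_size, seq_length, seq_length] can now be passed for both the encoder and the decoder.

import torch
from transformers import T5Model

model = T5Model.from_pretrained("t5-small")

# Equal encoder/decoder lengths keep the broadcast of the 3D cross-attention
# mask simple in this sketch.
batch_size, seq_len = 2, 6
input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_len))
decoder_input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_len))

# 3D masks of shape [batch_size, seq_length, seq_length]: entry (b, i, j) says
# whether query position i may attend to key position j in example b.
attention_mask = torch.ones(batch_size, seq_len, seq_len, dtype=torch.long)
decoder_attention_mask = torch.tril(torch.ones(batch_size, seq_len, seq_len, dtype=torch.long))

outputs = model(
    input_ids=input_ids,
    decoder_input_ids=decoder_input_ids,
    attention_mask=attention_mask,
    decoder_attention_mask=decoder_attention_mask,
)
print(outputs.last_hidden_state.shape)  # [batch_size, seq_len, d_model]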
@@ -914,7 +914,13 @@ class T5Stack(T5PreTrainedModel):
         # ourselves in which case we just need to make it broadcastable to all heads.
         extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device)
 
-        if self.is_decoder and encoder_attention_mask is not None:
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.is_decoder and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
             encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
         else:
             encoder_extended_attention_mask = None
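The hunk above routes both 2D and 3D masks through the existing helpers. As a rough sketch of the broadcasting those helpers perform (simplified from ModuleUtilsMixin; the -1e9 fill value stands in for the dtype-dependent minimum the library actually uses, and the function name is ours, not the library's):

import torch

def invert_attention_mask_sketch(mask: torch.Tensor) -> torch.Tensor:
    if mask.dim() == 3:    # [batch, seq, seq] -> [batch, 1, seq, seq]
        extended = mask[:, None, :, :]
    elif mask.dim() == 2:  # [batch, seq] -> [batch, 1, 1, seq]
        extended = mask[:, None, None, :]
    else:
        raise ValueError(f"unsupported mask rank: {mask.dim()}")
    # 1 -> 0.0 (attend), 0 -> large negative (blocked); the result is added
    # to the raw attention scores before the softmax.
    return (1.0 - extended.float()) * -1e9

mask_2d = torch.ones(2, 5)
mask_3d = torch.ones(2, 5, 5)
print(invert_attention_mask_sketch(mask_2d).shape)  # torch.Size([2, 1, 1, 5])
print(invert_attention_mask_sketch(mask_3d).shape)  # torch.Size([2, 1, 5, 5])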
@@ -530,6 +530,34 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
 
+    def test_decoder_model_past_with_3d_attn_mask(self):
+        (
+            config,
+            input_ids,
+            decoder_input_ids,
+            attention_mask,
+            decoder_attention_mask,
+            lm_labels,
+        ) = self.model_tester.prepare_config_and_inputs()
+
+        attention_mask = ids_tensor(
+            [self.model_tester.batch_size, self.model_tester.encoder_seq_length, self.model_tester.encoder_seq_length],
+            vocab_size=2,
+        )
+        decoder_attention_mask = ids_tensor(
+            [self.model_tester.batch_size, self.model_tester.decoder_seq_length, self.model_tester.decoder_seq_length],
+            vocab_size=2,
+        )
+
+        self.model_tester.create_and_check_decoder_model_attention_mask_past(
+            config,
+            input_ids,
+            decoder_input_ids,
+            attention_mask,
+            decoder_attention_mask,
+            lm_labels,
+        )
+
     def test_decoder_model_past_with_large_inputs(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
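ids_tensor here is a helper from the transformers test suite. A hypothetical minimal stand-in (ours, for illustration only) shows what the test's masks look like: with vocab_size=2 it draws a random 0/1 mask of the requested shape.

import torch

def ids_tensor(shape, vocab_size):
    # Random integer tensor in [0, vocab_size); with vocab_size=2 this yields
    # a random binary mask of the requested shape.
    return torch.randint(0, vocab_size, tuple(shape), dtype=torch.long)

batch_size, seq_len = 13, 7
attention_mask = ids_tensor([batch_size, seq_len, seq_len], vocab_size=2)
print(attention_mask.shape)     # torch.Size([13, 7, 7])
print(attention_mask.unique())  # typically tensor([0, 1])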