diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index f27ede05e1..1eb35a2875 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -920,6 +920,8 @@ class T5Stack(T5PreTrainedModel):
         hidden_states = self.dropout(inputs_embeds)
 
         for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
+            layer_head_mask = head_mask[i]
+            encoder_layer_head_mask = encoder_head_mask[i]
             # Model parallel
             if self.model_parallel:
                 torch.cuda.set_device(hidden_states.device)
@@ -934,10 +936,10 @@ class T5Stack(T5PreTrainedModel):
                     encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
                 if encoder_decoder_position_bias is not None:
                     encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
-                if not (isinstance(head_mask, list) and head_mask[0] is None):
-                    head_mask = head_mask.to(hidden_states.device)
-                if not (isinstance(encoder_head_mask, list) and encoder_head_mask[0] is None):
-                    encoder_head_mask = encoder_head_mask.to(hidden_states.device)
+                if layer_head_mask is not None:
+                    layer_head_mask = layer_head_mask.to(hidden_states.device)
+                if encoder_layer_head_mask is not None:
+                    encoder_layer_head_mask = encoder_layer_head_mask.to(hidden_states.device)
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
@@ -948,8 +950,8 @@ class T5Stack(T5PreTrainedModel):
                 encoder_hidden_states=encoder_hidden_states,
                 encoder_attention_mask=encoder_extended_attention_mask,
                 encoder_decoder_position_bias=encoder_decoder_position_bias,
-                layer_head_mask=head_mask[i],
-                encoder_layer_head_mask=encoder_head_mask[i] if encoder_head_mask is not None else None,
+                layer_head_mask=layer_head_mask,
+                encoder_layer_head_mask=encoder_layer_head_mask,
                 past_key_value=past_key_value,
                 use_cache=use_cache,
                 output_attentions=output_attentions,
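
For reviewers, a minimal self-contained sketch (not the actual `T5Stack.forward` code) of the pattern this patch introduces: the per-layer slice `head_mask[i]` is resolved once at the top of the loop, and only a real tensor is moved to the layer's device, since a `None` entry simply means "do not mask this layer". All names, shapes, and values below are illustrative assumptions.

```python
# Illustrative sketch only; num_layers, num_heads, and the toy tensors are
# hypothetical stand-ins, not the real T5 configuration.
import torch

num_layers, num_heads = 2, 4
hidden_states = torch.zeros(1, 3, 8)  # (batch, seq_len, d_model), toy values

# head_mask is either a real mask (e.g. one row per layer) or the default
# list of Nones when no head mask was requested.
head_mask = [None] * num_layers
# head_mask = torch.ones(num_layers, num_heads)  # alternative: a real mask

for i in range(num_layers):
    # Resolve the per-layer slice once; it is either a tensor or None.
    layer_head_mask = head_mask[i]
    # Only a real tensor needs to follow hidden_states onto the layer's device.
    if layer_head_mask is not None:
        layer_head_mask = layer_head_mask.to(hidden_states.device)
    # ...the block would then be called with layer_head_mask=layer_head_mask...
```

Compared to the removed `isinstance(head_mask, list)` check, the per-entry `is not None` test behaves the same whether the mask arrives as a tensor or as the default list of `None`s, and it avoids repeatedly moving the full mask inside the loop.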