fix T5 head mask in model_parallel (#9726)

* fix head mask in model_parallel

* pass correct head mask
Suraj Patil 2021-01-21 16:46:14 +05:30 committed by GitHub
parent ca422e3d7d
commit 248fa1ae72
1 changed file with 8 additions and 6 deletions

@@ -920,6 +920,8 @@ class T5Stack(T5PreTrainedModel):
         hidden_states = self.dropout(inputs_embeds)
 
         for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
+            layer_head_mask = head_mask[i]
+            encoder_layer_head_mask = encoder_head_mask[i]
             # Model parallel
             if self.model_parallel:
                 torch.cuda.set_device(hidden_states.device)
@@ -934,10 +936,10 @@ class T5Stack(T5PreTrainedModel):
                     encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
                 if encoder_decoder_position_bias is not None:
                     encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
-                if not (isinstance(head_mask, list) and head_mask[0] is None):
-                    head_mask = head_mask.to(hidden_states.device)
-                if not (isinstance(encoder_head_mask, list) and encoder_head_mask[0] is None):
-                    encoder_head_mask = encoder_head_mask.to(hidden_states.device)
+                if layer_head_mask is not None:
+                    layer_head_mask = layer_head_mask.to(hidden_states.device)
+                if encoder_layer_head_mask is not None:
+                    encoder_layer_head_mask = encoder_layer_head_mask.to(hidden_states.device)
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
@@ -948,8 +950,8 @@ class T5Stack(T5PreTrainedModel):
                 encoder_hidden_states=encoder_hidden_states,
                 encoder_attention_mask=encoder_extended_attention_mask,
                 encoder_decoder_position_bias=encoder_decoder_position_bias,
-                layer_head_mask=head_mask[i],
-                encoder_layer_head_mask=encoder_head_mask[i] if encoder_head_mask is not None else None,
+                layer_head_mask=layer_head_mask,
+                encoder_layer_head_mask=encoder_layer_head_mask,
                 past_key_value=past_key_value,
                 use_cache=use_cache,
                 output_attentions=output_attentions,
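
Taken together, the three hunks make the loop select each layer's mask once at the top of each iteration and, under model parallelism, move only that per-layer slice to the current layer's device, instead of repeatedly rebinding the whole head_mask object. Below is a minimal standalone sketch of that per-layer pattern; run_blocks, blocks, and the block call signature are illustrative assumptions, not the actual T5Stack code.

import torch

def run_blocks(blocks, hidden_states, head_mask):
    # Sketch of the pattern in the diff above, assuming:
    #   blocks    - layer modules, each already placed on its own device
    #   head_mask - one entry per layer: a head-mask tensor or None
    for i, block in enumerate(blocks):
        # Pick out this layer's mask once, at the top of the iteration.
        layer_head_mask = head_mask[i]
        # Model parallel: follow the block onto its device.
        device = next(block.parameters()).device
        hidden_states = hidden_states.to(device)
        # Move only the per-layer slice, and only when a mask was given,
        # rather than calling .to() on the whole head_mask object.
        if layer_head_mask is not None:
            layer_head_mask = layer_head_mask.to(device)
        hidden_states = block(hidden_states, layer_head_mask)
    return hidden_states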