fix T5 head mask in model_parallel (#9726)
* fix head mask in model_parallel * pass correct head mask
This commit is contained in:
parent
ca422e3d7d
commit
248fa1ae72
|
@ -920,6 +920,8 @@ class T5Stack(T5PreTrainedModel):
|
|||
hidden_states = self.dropout(inputs_embeds)
|
||||
|
||||
for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
|
||||
layer_head_mask = head_mask[i]
|
||||
encoder_layer_head_mask = encoder_head_mask[i]
|
||||
# Model parallel
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(hidden_states.device)
|
||||
|
@ -934,10 +936,10 @@ class T5Stack(T5PreTrainedModel):
|
|||
encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
|
||||
if encoder_decoder_position_bias is not None:
|
||||
encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
|
||||
if not (isinstance(head_mask, list) and head_mask[0] is None):
|
||||
head_mask = head_mask.to(hidden_states.device)
|
||||
if not (isinstance(encoder_head_mask, list) and encoder_head_mask[0] is None):
|
||||
encoder_head_mask = encoder_head_mask.to(hidden_states.device)
|
||||
if layer_head_mask is not None:
|
||||
layer_head_mask = layer_head_mask.to(hidden_states.device)
|
||||
if encoder_layer_head_mask is not None:
|
||||
encoder_layer_head_mask = encoder_layer_head_mask.to(hidden_states.device)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
|
@ -948,8 +950,8 @@ class T5Stack(T5PreTrainedModel):
|
|||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
||||
layer_head_mask=head_mask[i],
|
||||
encoder_layer_head_mask=encoder_head_mask[i] if encoder_head_mask is not None else None,
|
||||
layer_head_mask=layer_head_mask,
|
||||
encoder_layer_head_mask=encoder_layer_head_mask,
|
||||
past_key_value=past_key_value,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
|
|
Loading…
Reference in New Issue