parent
2aa9c2f204
commit
ccd1923f46
|
@ -640,6 +640,11 @@ class T5Block(nn.Module):
|
|||
hidden_states, present_key_value_state = self_attention_outputs[:2]
|
||||
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
|
||||
|
||||
# clamp inf values to enable fp16 training
|
||||
if torch.isinf(hidden_states).any():
|
||||
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
||||
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
|
||||
|
||||
do_cross_attention = self.is_decoder and encoder_hidden_states is not None
|
||||
if do_cross_attention:
|
||||
# the actual query length is unknown for cross attention
|
||||
|
@ -661,6 +666,10 @@ class T5Block(nn.Module):
|
|||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = cross_attention_outputs[0]
|
||||
if torch.isinf(hidden_states).any():
|
||||
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
||||
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
|
||||
|
||||
# Combine self attn and cross attn key value states
|
||||
if present_key_value_state is not None:
|
||||
present_key_value_state = present_key_value_state + cross_attention_outputs[1]
|
||||
|
@ -670,6 +679,9 @@ class T5Block(nn.Module):
|
|||
|
||||
# Apply Feed Forward layer
|
||||
hidden_states = self.layer[-1](hidden_states)
|
||||
if torch.isinf(hidden_states).any():
|
||||
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
||||
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
|
||||
outputs = (hidden_states,)
|
||||
|
||||
outputs = outputs + (present_key_value_state,) + attention_outputs
|
||||
|
|
Loading…
Reference in New Issue