Flax dtype-dependent numerical masking (#21197)

Commit cbaaa2f6ac (parent 0b86e330b1)
Joao Gante, 2023-01-19 16:43:42 +00:00, committed by GitHub
21 changed files with 22 additions and 22 deletions
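
Each Flax attention module turns the boolean attention mask into an additive
attention bias: kept positions get 0.0 and masked positions get a large
negative value, so their scores vanish after the softmax. Previously every
model hardcoded its own constant (-1e4, -1e9, -1e10, or float("-inf")); this
commit switches all of them to jnp.finfo(self.dtype).min, the most negative
finite value of the module's compute dtype. The hardcoded constants misbehave
in reduced precision: float16 cannot represent -1e9 or -1e10, the cast
saturates to -inf, and a fully masked row of -inf scores turns into NaNs after
the softmax. A minimal sketch of the failure mode and of the fix (illustrative
only, not part of the diff):

import jax
import jax.numpy as jnp

# float16 cannot hold -1e10: the cast overflows to -inf ...
bias = jnp.full((2,), -1e10).astype(jnp.float16)
print(bias)                  # [-inf -inf]

# ... and a fully masked row of -inf scores is NaN after softmax
print(jax.nn.softmax(bias))  # [nan nan]

# dtype-dependent mask: the most negative *finite* float16 value
bias = jnp.full((2,), jnp.finfo(jnp.float16).min).astype(jnp.float16)
print(jax.nn.softmax(bias))  # [0.5 0.5]

With a finite mask value, the softmax stays well defined even when every
position in a row happens to be masked.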

@@ -245,7 +245,7 @@ class FlaxAlbertSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -371,7 +371,7 @@ class FlaxBartAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -358,7 +358,7 @@ class FlaxBertSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -380,7 +380,7 @@ class FlaxBigBirdSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -359,7 +359,7 @@ class FlaxBlenderbotAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -371,7 +371,7 @@ class FlaxBlenderbotSmallAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -315,7 +315,7 @@ class FlaxCLIPAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -326,7 +326,7 @@ class FlaxElectraSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -255,7 +255,7 @@ class FlaxGPT2Attention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -223,7 +223,7 @@ class FlaxGPTNeoSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )

         # usual dot product attention

@@ -270,7 +270,7 @@ class FlaxGPTJAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )

         # usual dot product attention

@@ -381,7 +381,7 @@ class FlaxMarianAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -383,7 +383,7 @@ class FlaxMBartAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -245,7 +245,7 @@ class FlaxOPTAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -375,7 +375,7 @@ class FlaxPegasusAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -319,7 +319,7 @@ class FlaxRobertaSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -321,7 +321,7 @@ class FlaxRobertaPreLayerNormSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -240,7 +240,7 @@ class FlaxRoFormerSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -497,7 +497,7 @@ class FlaxWav2Vec2Attention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -329,7 +329,7 @@ class FlaxXLMRobertaSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -312,7 +312,7 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
@@ -1859,7 +1859,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
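
For reference, the finite minimum that jnp.finfo reports for the usual compute
dtypes, i.e. the mask value each configuration now gets (a quick check; exact
printed values may vary with the jax/ml_dtypes version):

import jax.numpy as jnp

# most negative finite value per dtype -- the new mask constant
for dt in (jnp.float32, jnp.bfloat16, jnp.float16):
    print(dt.__name__, jnp.finfo(dt).min)
# float32 -3.4028235e+38
# bfloat16 -3.3895314e+38
# float16 -65504.0

Because the fill value is derived from self.dtype itself, it is always
representable in that dtype; the trailing .astype(self.dtype) mainly keeps
both branches of the lax.select at the same dtype.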