Flax dtype-dependent numerical masking (#21197)
parent 0b86e330b1
commit cbaaa2f6ac
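Every hunk below makes the same one-line change: the hardcoded additive-mask constant (-1e10, -1e9, -1e4, or float("-inf"), depending on the model) is replaced by jnp.finfo(self.dtype).min, the most negative finite value representable in the module's compute dtype. A minimal standalone sketch of the pattern (the helper name and example values are illustrative, not part of the commit):

import jax.numpy as jnp
from jax import lax

def attention_bias_from_mask(attention_mask, dtype):
    # Positions with mask > 0 contribute 0.0 to the attention logits;
    # masked positions contribute the most negative finite value for
    # `dtype`, so softmax drives their weights to ~0 while every
    # intermediate value stays finite.
    return lax.select(
        attention_mask > 0,
        jnp.full(attention_mask.shape, 0.0).astype(dtype),
        jnp.full(attention_mask.shape, jnp.finfo(dtype).min).astype(dtype),
    )

mask = jnp.array([[1, 1, 0]])
print(attention_bias_from_mask(mask, jnp.float32))  # [[0. 0. -3.4028235e+38]]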
@@ -245,7 +245,7 @@ class FlaxAlbertSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -371,7 +371,7 @@ class FlaxBartAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -358,7 +358,7 @@ class FlaxBertSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -380,7 +380,7 @@ class FlaxBigBirdSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -359,7 +359,7 @@ class FlaxBlenderbotAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -371,7 +371,7 @@ class FlaxBlenderbotSmallAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -315,7 +315,7 @@ class FlaxCLIPAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -326,7 +326,7 @@ class FlaxElectraSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -255,7 +255,7 @@ class FlaxGPT2Attention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -223,7 +223,7 @@ class FlaxGPTNeoSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )

         # usual dot product attention

@@ -270,7 +270,7 @@ class FlaxGPTJAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )

         # usual dot product attention

@@ -381,7 +381,7 @@ class FlaxMarianAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -383,7 +383,7 @@ class FlaxMBartAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -245,7 +245,7 @@ class FlaxOPTAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -375,7 +375,7 @@ class FlaxPegasusAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -319,7 +319,7 @@ class FlaxRobertaSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -321,7 +321,7 @@ class FlaxRobertaPreLayerNormSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -240,7 +240,7 @@ class FlaxRoFormerSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -497,7 +497,7 @@ class FlaxWav2Vec2Attention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -329,7 +329,7 @@ class FlaxXLMRobertaSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -312,7 +312,7 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

@@ -1859,7 +1859,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

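Why a dtype-dependent minimum instead of a fixed constant: the old values are only safe in float32. A quick demonstration, under the assumption that these modules may run in float16 or bfloat16:

import jax
import jax.numpy as jnp

# float16 tops out near 6.5e4, so the old hardcoded constants overflow
# to -inf when cast down:
print(jnp.array(-1e10, dtype=jnp.float16))   # -inf
print(jnp.finfo(jnp.float16).min)            # -65504.0

# A fully masked row of -inf logits turns into NaN after softmax,
# while the finite dtype-dependent minimum stays well-behaved:
print(jax.nn.softmax(jnp.full((4,), -jnp.inf)))                    # [nan nan nan nan]
print(jax.nn.softmax(jnp.full((4,), jnp.finfo(jnp.float32).min)))  # [0.25 0.25 0.25 0.25]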