diff --git a/src/transformers/models/albert/modeling_flax_albert.py b/src/transformers/models/albert/modeling_flax_albert.py
index 84b86fa563..03ff7d0170 100644
--- a/src/transformers/models/albert/modeling_flax_albert.py
+++ b/src/transformers/models/albert/modeling_flax_albert.py
@@ -245,7 +245,7 @@ class FlaxAlbertSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
diff --git a/src/transformers/models/bart/modeling_flax_bart.py b/src/transformers/models/bart/modeling_flax_bart.py
index a17e4185fd..4963f95480 100644
--- a/src/transformers/models/bart/modeling_flax_bart.py
+++ b/src/transformers/models/bart/modeling_flax_bart.py
@@ -371,7 +371,7 @@ class FlaxBartAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py
index f7c78632e5..0206e56c17 100644
--- a/src/transformers/models/bert/modeling_flax_bert.py
+++ b/src/transformers/models/bert/modeling_flax_bert.py
@@ -358,7 +358,7 @@ class FlaxBertSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
diff --git a/src/transformers/models/big_bird/modeling_flax_big_bird.py b/src/transformers/models/big_bird/modeling_flax_big_bird.py
index b38492f61f..49fb61118d 100644
--- a/src/transformers/models/big_bird/modeling_flax_big_bird.py
+++ b/src/transformers/models/big_bird/modeling_flax_big_bird.py
@@ -380,7 +380,7 @@ class FlaxBigBirdSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
diff --git a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py
index 8abacc090b..8f26e80c9f 100644
--- a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py
+++ b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py
@@ -359,7 +359,7 @@ class FlaxBlenderbotAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
diff --git a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py
index fc7c938512..4115460b7a 100644
--- a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py
+++ b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py
@@ -371,7 +371,7 @@ class FlaxBlenderbotSmallAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
diff --git a/src/transformers/models/clip/modeling_flax_clip.py b/src/transformers/models/clip/modeling_flax_clip.py
index aa8ef87d5b..9548557fb6 100644
--- a/src/transformers/models/clip/modeling_flax_clip.py
+++ b/src/transformers/models/clip/modeling_flax_clip.py
@@ -315,7 +315,7 @@ class FlaxCLIPAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py
index 99f193f590..a7ecb97414 100644
--- a/src/transformers/models/electra/modeling_flax_electra.py
+++ b/src/transformers/models/electra/modeling_flax_electra.py
@@ -326,7 +326,7 @@ class FlaxElectraSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
diff --git a/src/transformers/models/gpt2/modeling_flax_gpt2.py b/src/transformers/models/gpt2/modeling_flax_gpt2.py
index 6a2d6f553c..796ca8ac1a 100644
--- a/src/transformers/models/gpt2/modeling_flax_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_flax_gpt2.py
@@ -255,7 +255,7 @@ class FlaxGPT2Attention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
diff --git a/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py
index 20505c511f..bb29b3c649 100644
--- a/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py
+++ b/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py
@@ -223,7 +223,7 @@ class FlaxGPTNeoSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
 
         # usual dot product attention
diff --git a/src/transformers/models/gptj/modeling_flax_gptj.py b/src/transformers/models/gptj/modeling_flax_gptj.py
index e7683c169d..a10e9598ce 100644
--- a/src/transformers/models/gptj/modeling_flax_gptj.py
+++ b/src/transformers/models/gptj/modeling_flax_gptj.py
@@ -270,7 +270,7 @@ class FlaxGPTJAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
 
         # usual dot product attention
diff --git a/src/transformers/models/marian/modeling_flax_marian.py b/src/transformers/models/marian/modeling_flax_marian.py
index 74f6b1fe8d..064879bc40 100644
--- a/src/transformers/models/marian/modeling_flax_marian.py
+++ b/src/transformers/models/marian/modeling_flax_marian.py
@@ -381,7 +381,7 @@ class FlaxMarianAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
diff --git a/src/transformers/models/mbart/modeling_flax_mbart.py b/src/transformers/models/mbart/modeling_flax_mbart.py
index a99b58fb5b..4d3156b9d1 100644
--- a/src/transformers/models/mbart/modeling_flax_mbart.py
+++ b/src/transformers/models/mbart/modeling_flax_mbart.py
@@ -383,7 +383,7 @@ class FlaxMBartAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
diff --git a/src/transformers/models/opt/modeling_flax_opt.py b/src/transformers/models/opt/modeling_flax_opt.py
index 83af4da8b0..79d1ff7193 100644
--- a/src/transformers/models/opt/modeling_flax_opt.py
+++ b/src/transformers/models/opt/modeling_flax_opt.py
@@ -245,7 +245,7 @@ class FlaxOPTAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
diff --git a/src/transformers/models/pegasus/modeling_flax_pegasus.py b/src/transformers/models/pegasus/modeling_flax_pegasus.py
index ab246b5ef8..3eddce4dd4 100644
--- a/src/transformers/models/pegasus/modeling_flax_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_flax_pegasus.py
@@ -375,7 +375,7 @@ class FlaxPegasusAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py
index b7494e19f4..56ba89b01e 100644
--- a/src/transformers/models/roberta/modeling_flax_roberta.py
+++ b/src/transformers/models/roberta/modeling_flax_roberta.py
@@ -319,7 +319,7 @@ class FlaxRobertaSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
diff --git a/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py
index 68cf1b7ca5..a01dfc520a 100644
--- a/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py
+++ b/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py
@@ -321,7 +321,7 @@ class FlaxRobertaPreLayerNormSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
diff --git a/src/transformers/models/roformer/modeling_flax_roformer.py b/src/transformers/models/roformer/modeling_flax_roformer.py
index 011f161048..13dd8548e4 100644
--- a/src/transformers/models/roformer/modeling_flax_roformer.py
+++ b/src/transformers/models/roformer/modeling_flax_roformer.py
@@ -240,7 +240,7 @@ class FlaxRoFormerSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
diff --git a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
index 03496b8210..0d718b8c1f 100644
--- a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
@@ -497,7 +497,7 @@ class FlaxWav2Vec2Attention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
diff --git a/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py
index 6b86ac47af..fdb58bce49 100644
--- a/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py
+++ b/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py
@@ -329,7 +329,7 @@ class FlaxXLMRobertaSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py
index 676270c131..f6283197b0 100644
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py
@@ -312,7 +312,7 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
@@ -1859,7 +1859,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
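
Not part of the patch: a minimal standalone sketch of the motivation, assuming a toy `masked_softmax` helper and an illustrative mask (both names are hypothetical, not from the repo). Hard-coded fills such as `-1e10` are outside the float16 range (max ≈ 6.5e4) and become `-inf` after the cast, and `float("-inf")` can propagate NaNs, whereas `jnp.finfo(dtype).min` is always the most negative finite value of the compute dtype, so masked positions still end up with (near-)zero probability after softmax in every dtype.

```python
import jax.numpy as jnp
from jax import lax, nn


def masked_softmax(scores, attention_mask, dtype):
    # Same selection pattern as in the patch: 0.0 where attention is allowed,
    # the most negative finite value of `dtype` where it is masked out.
    attention_bias = lax.select(
        attention_mask > 0,
        jnp.full(attention_mask.shape, 0.0).astype(dtype),
        jnp.full(attention_mask.shape, jnp.finfo(dtype).min).astype(dtype),
    )
    return nn.softmax(scores.astype(dtype) + attention_bias, axis=-1)


scores = jnp.zeros((1, 4))
mask = jnp.array([[1, 1, 0, 0]])

for dtype in (jnp.float32, jnp.float16, jnp.bfloat16):
    # finfo(dtype).min stays finite in every dtype, unlike a fixed -1e10,
    # which cannot be represented in float16 and would be cast to -inf.
    print(jnp.dtype(dtype).name, jnp.finfo(dtype).min, masked_softmax(scores, mask, dtype))
```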