From 7eee950ac3135476c811daf23eca8cedbbaa3879 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Thu, 6 May 2021 14:24:19 -0400
Subject: [PATCH] Re-styling in seq2seq attention (#11613)

---
 src/transformers/models/bart/modeling_bart.py | 43 ++++++++-----------
 .../models/blenderbot/modeling_blenderbot.py  | 43 ++++++++-----------
 .../modeling_blenderbot_small.py              | 43 ++++++++-----------
 .../models/m2m_100/modeling_m2m_100.py        | 43 ++++++++-----------
 .../models/marian/modeling_marian.py          | 43 ++++++++-----------
 .../models/mbart/modeling_mbart.py            | 43 ++++++++-----------
 .../models/pegasus/modeling_pegasus.py        | 43 ++++++++-----------
 .../speech_to_text/modeling_speech_to_text.py | 43 ++++++++-----------
 .../models/wav2vec2/modeling_wav2vec2.py      | 43 ++++++++-----------
 ...ng_{{cookiecutter.lowercase_modelname}}.py | 43 ++++++++-----------
 10 files changed, 190 insertions(+), 240 deletions(-)

diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index 89e078bd9e..8f72c64d43 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -210,28 +210,26 @@ class BartAttention(nn.Module):
         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
 
-        assert attn_weights.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            src_len,
-        ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+            )
 
         if attention_mask is not None:
-            assert attention_mask.size() == (
-                bsz,
-                1,
-                tgt_len,
-                src_len,
-            ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
         attn_weights = F.softmax(attn_weights, dim=-1)
 
         if layer_head_mask is not None:
-            assert layer_head_mask.size() == (
-                self.num_heads,
-            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+                )
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
@@ -249,17 +247,14 @@ class BartAttention(nn.Module):
 
         attn_output = torch.bmm(attn_probs, value_states)
 
-        assert attn_output.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            self.head_dim,
-        ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+            )
 
-        attn_output = (
-            attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(bsz, tgt_len, embed_dim)
-        )
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
 
         attn_output = self.out_proj(attn_output)
 
diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py
index 461084ea73..5620c77887 100755
--- a/src/transformers/models/blenderbot/modeling_blenderbot.py
+++ b/src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -211,28 +211,26 @@ class BlenderbotAttention(nn.Module):
         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
 
-        assert attn_weights.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            src_len,
-        ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+            )
 
         if attention_mask is not None:
-            assert attention_mask.size() == (
-                bsz,
-                1,
-                tgt_len,
-                src_len,
-            ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
         attn_weights = F.softmax(attn_weights, dim=-1)
 
         if layer_head_mask is not None:
-            assert layer_head_mask.size() == (
-                self.num_heads,
-            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+                )
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
@@ -250,17 +248,14 @@ class BlenderbotAttention(nn.Module):
 
         attn_output = torch.bmm(attn_probs, value_states)
 
-        assert attn_output.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            self.head_dim,
-        ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+            )
 
-        attn_output = (
-            attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(bsz, tgt_len, embed_dim)
-        )
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
 
         attn_output = self.out_proj(attn_output)
 
diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
index d32a98ec73..7ddc2e7650 100755
--- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
+++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
@@ -209,28 +209,26 @@ class BlenderbotSmallAttention(nn.Module):
         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
 
-        assert attn_weights.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            src_len,
-        ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+            )
 
         if attention_mask is not None:
-            assert attention_mask.size() == (
-                bsz,
-                1,
-                tgt_len,
-                src_len,
-            ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
         attn_weights = F.softmax(attn_weights, dim=-1)
 
         if layer_head_mask is not None:
-            assert layer_head_mask.size() == (
-                self.num_heads,
-            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+                )
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
@@ -248,17 +246,14 @@ class BlenderbotSmallAttention(nn.Module):
 
         attn_output = torch.bmm(attn_probs, value_states)
 
-        assert attn_output.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            self.head_dim,
-        ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+            )
 
-        attn_output = (
-            attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(bsz, tgt_len, embed_dim)
-        )
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
 
         attn_output = self.out_proj(attn_output)
 
diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py
index 20c4aea990..4db2be333b 100755
--- a/src/transformers/models/m2m_100/modeling_m2m_100.py
+++ b/src/transformers/models/m2m_100/modeling_m2m_100.py
@@ -280,28 +280,26 @@ class M2M100Attention(nn.Module):
         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
 
-        assert attn_weights.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            src_len,
-        ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+            )
 
         if attention_mask is not None:
-            assert attention_mask.size() == (
-                bsz,
-                1,
-                tgt_len,
-                src_len,
-            ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
         attn_weights = F.softmax(attn_weights, dim=-1)
 
         if layer_head_mask is not None:
-            assert layer_head_mask.size() == (
-                self.num_heads,
-            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+                )
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
@@ -319,17 +317,14 @@ class M2M100Attention(nn.Module):
 
         attn_output = torch.bmm(attn_probs, value_states)
 
-        assert attn_output.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            self.head_dim,
-        ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+            )
 
-        attn_output = (
-            attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(bsz, tgt_len, embed_dim)
-        )
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
 
         attn_output = self.out_proj(attn_output)
 
diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py
index c99d4aa832..dc40dacc40 100755
--- a/src/transformers/models/marian/modeling_marian.py
+++ b/src/transformers/models/marian/modeling_marian.py
@@ -226,28 +226,26 @@ class MarianAttention(nn.Module):
         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
 
-        assert attn_weights.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            src_len,
-        ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+            )
 
         if attention_mask is not None:
-            assert attention_mask.size() == (
-                bsz,
-                1,
-                tgt_len,
-                src_len,
-            ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
         attn_weights = F.softmax(attn_weights, dim=-1)
 
         if layer_head_mask is not None:
-            assert layer_head_mask.size() == (
-                self.num_heads,
-            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+                )
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
@@ -265,17 +263,14 @@ class MarianAttention(nn.Module):
 
         attn_output = torch.bmm(attn_probs, value_states)
 
-        assert attn_output.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            self.head_dim,
-        ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+            )
 
-        attn_output = (
-            attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(bsz, tgt_len, embed_dim)
-        )
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
 
         attn_output = self.out_proj(attn_output)
 
diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index dd76e65129..a445539be7 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -217,28 +217,26 @@ class MBartAttention(nn.Module):
         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
 
-        assert attn_weights.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            src_len,
-        ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+            )
 
         if attention_mask is not None:
-            assert attention_mask.size() == (
-                bsz,
-                1,
-                tgt_len,
-                src_len,
-            ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
         attn_weights = F.softmax(attn_weights, dim=-1)
 
         if layer_head_mask is not None:
-            assert layer_head_mask.size() == (
-                self.num_heads,
-            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+                )
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
@@ -256,17 +254,14 @@ class MBartAttention(nn.Module):
 
         attn_output = torch.bmm(attn_probs, value_states)
 
-        assert attn_output.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            self.head_dim,
-        ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+            )
 
-        attn_output = (
-            attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(bsz, tgt_len, embed_dim)
-        )
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
 
         attn_output = self.out_proj(attn_output)
 
diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py
index 66a15964e6..e43a0bcbb4 100755
--- a/src/transformers/models/pegasus/modeling_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_pegasus.py
@@ -226,28 +226,26 @@ class PegasusAttention(nn.Module):
         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
 
-        assert attn_weights.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            src_len,
-        ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+            )
 
         if attention_mask is not None:
-            assert attention_mask.size() == (
-                bsz,
-                1,
-                tgt_len,
-                src_len,
-            ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
         attn_weights = F.softmax(attn_weights, dim=-1)
 
         if layer_head_mask is not None:
-            assert layer_head_mask.size() == (
-                self.num_heads,
-            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+                )
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
@@ -265,17 +263,14 @@ class PegasusAttention(nn.Module):
 
         attn_output = torch.bmm(attn_probs, value_states)
 
-        assert attn_output.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            self.head_dim,
-        ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+            )
 
-        attn_output = (
-            attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(bsz, tgt_len, embed_dim)
-        )
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
 
         attn_output = self.out_proj(attn_output)
 
diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py
index ff50202b35..3bd21831c9 100755
--- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py
+++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py
@@ -293,28 +293,26 @@ class Speech2TextAttention(nn.Module):
         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
 
-        assert attn_weights.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            src_len,
-        ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+            )
 
         if attention_mask is not None:
-            assert attention_mask.size() == (
-                bsz,
-                1,
-                tgt_len,
-                src_len,
-            ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
         attn_weights = F.softmax(attn_weights, dim=-1)
 
         if layer_head_mask is not None:
-            assert layer_head_mask.size() == (
-                self.num_heads,
-            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+                )
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
@@ -332,17 +330,14 @@ class Speech2TextAttention(nn.Module):
 
         attn_output = torch.bmm(attn_probs, value_states)
 
-        assert attn_output.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            self.head_dim,
-        ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+            )
 
-        attn_output = (
-            attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(bsz, tgt_len, embed_dim)
-        )
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
 
         attn_output = self.out_proj(attn_output)
 
diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
index 98123bdd31..e55e6179ed 100755
--- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -356,28 +356,26 @@ class Wav2Vec2Attention(nn.Module):
         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
 
-        assert attn_weights.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            src_len,
-        ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+            )
 
         if attention_mask is not None:
-            assert attention_mask.size() == (
-                bsz,
-                1,
-                tgt_len,
-                src_len,
-            ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
         attn_weights = F.softmax(attn_weights, dim=-1)
 
         if layer_head_mask is not None:
-            assert layer_head_mask.size() == (
-                self.num_heads,
-            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+                )
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
@@ -395,17 +393,14 @@ class Wav2Vec2Attention(nn.Module):
 
         attn_output = torch.bmm(attn_probs, value_states)
 
-        assert attn_output.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            self.head_dim,
-        ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+            )
 
-        attn_output = (
-            attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(bsz, tgt_len, embed_dim)
-        )
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
 
         attn_output = self.out_proj(attn_output)
 
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
index 1e6f833a21..1d78af6d90 100755
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
@@ -1721,28 +1721,26 @@ class {{cookiecutter.camelcase_modelname}}Attention(nn.Module):
         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
 
-        assert attn_weights.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            src_len,
-        ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+            )
 
         if attention_mask is not None:
-            assert attention_mask.size() == (
-                bsz,
-                1,
-                tgt_len,
-                src_len,
-            ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
         attn_weights = F.softmax(attn_weights, dim=-1)
 
         if layer_head_mask is not None:
-            assert layer_head_mask.size() == (
-                self.num_heads,
-            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+                )
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
@@ -1760,17 +1758,14 @@ class {{cookiecutter.camelcase_modelname}}Attention(nn.Module):
 
         attn_output = torch.bmm(attn_probs, value_states)
 
-        assert attn_output.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            self.head_dim,
-        ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+            )
 
-        attn_output = (
-            attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(bsz, tgt_len, embed_dim)
-        )
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
 
         attn_output = self.out_proj(attn_output)
 
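Note (not part of the patch above): the restyling swaps `assert`-based shape checks for explicit `if`/`raise ValueError` blocks; a common motivation for this style is that asserts are stripped under `python -O`, while `ValueError` always runs and gives callers a catchable, descriptive exception. The snippet below is a minimal standalone sketch of that pattern only; the helper name and tensor shapes are illustrative and do not come from the patch or the library.

    import torch

    def check_attn_weights(attn_weights: torch.Tensor, bsz: int, num_heads: int, tgt_len: int, src_len: int) -> None:
        # Hypothetical helper mirroring the restyled check: explicit comparison plus
        # ValueError with an informative message, instead of an assert statement.
        if attn_weights.size() != (bsz * num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * num_heads, tgt_len, src_len)}, "
                f"but is {attn_weights.size()}"
            )

    # A correctly shaped tensor passes silently; a mis-shaped one raises ValueError.
    check_attn_weights(torch.zeros(2 * 4, 5, 7), bsz=2, num_heads=4, tgt_len=5, src_len=7)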