Re-styling in seq2seq attention (#11613)
parent cf409e5594
commit 7eee950ac3
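This commit re-styles the shape checks in the copied seq2seq attention implementations (BART, Blenderbot, BlenderbotSmall, M2M100, Marian, MBart, Pegasus, Speech2Text, Wav2Vec2, and the cookiecutter template): `assert` statements become explicit `if`/`raise ValueError` checks, and the chained `attn_output` reshape is split into one statement per step. The same change repeats verbatim in each file below. As a minimal standalone sketch of the assert-to-ValueError pattern (the tensor sizes are hypothetical toy values, not taken from any model):

import torch

# Hypothetical toy sizes for illustration only.
bsz, num_heads, tgt_len, src_len = 2, 4, 3, 5
attn_weights = torch.randn(bsz * num_heads, tgt_len, src_len)

# Old style: `assert` is stripped entirely under `python -O`, so the
# shape check can silently disappear in optimized mode.
assert attn_weights.size() == (
    bsz * num_heads,
    tgt_len,
    src_len,
), f"Attention weights should be of size {(bsz * num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"

# New style: an explicit check that always runs and raises ValueError,
# a catchable input-validation error rather than an AssertionError.
if attn_weights.size() != (bsz * num_heads, tgt_len, src_len):
    raise ValueError(
        f"Attention weights should be of size {(bsz * num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
    )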
@@ -210,28 +210,26 @@ class BartAttention(nn.Module):
         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

-        assert attn_weights.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            src_len,
-        ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+            )

         if attention_mask is not None:
-            assert attention_mask.size() == (
-                bsz,
-                1,
-                tgt_len,
-                src_len,
-            ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

         attn_weights = F.softmax(attn_weights, dim=-1)

         if layer_head_mask is not None:
-            assert layer_head_mask.size() == (
-                self.num_heads,
-            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+                )
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

@@ -249,17 +247,14 @@ class BartAttention(nn.Module):
         attn_output = torch.bmm(attn_probs, value_states)

-        assert attn_output.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            self.head_dim,
-        ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+            )

-        attn_output = (
-            attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(bsz, tgt_len, embed_dim)
-        )
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

         attn_output = self.out_proj(attn_output)
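As the BartAttention hunks above show, the second hunk in each file also unchains the view/transpose/reshape expression into one statement per step. A small sketch of that equivalence, again with hypothetical toy sizes:

import torch

# Hypothetical toy sizes for illustration only.
bsz, num_heads, tgt_len, head_dim = 2, 4, 3, 8
embed_dim = num_heads * head_dim
attn_output = torch.randn(bsz * num_heads, tgt_len, head_dim)

# Old style: one chained expression wrapped in parentheses.
chained = (
    attn_output.view(bsz, num_heads, tgt_len, head_dim)
    .transpose(1, 2)
    .reshape(bsz, tgt_len, embed_dim)
)

# New style: one statement per step, easier to read and to step
# through in a debugger; the result is identical.
unchained = attn_output.view(bsz, num_heads, tgt_len, head_dim)
unchained = unchained.transpose(1, 2)
unchained = unchained.reshape(bsz, tgt_len, embed_dim)

print(torch.equal(chained, unchained))  # True; both are (bsz, tgt_len, embed_dim)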
@@ -211,28 +211,26 @@ class BlenderbotAttention(nn.Module):
         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

-        assert attn_weights.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            src_len,
-        ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+            )

         if attention_mask is not None:
-            assert attention_mask.size() == (
-                bsz,
-                1,
-                tgt_len,
-                src_len,
-            ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

         attn_weights = F.softmax(attn_weights, dim=-1)

         if layer_head_mask is not None:
-            assert layer_head_mask.size() == (
-                self.num_heads,
-            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+                )
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

@@ -250,17 +248,14 @@ class BlenderbotAttention(nn.Module):
         attn_output = torch.bmm(attn_probs, value_states)

-        assert attn_output.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            self.head_dim,
-        ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+            )

-        attn_output = (
-            attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(bsz, tgt_len, embed_dim)
-        )
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

         attn_output = self.out_proj(attn_output)
@@ -209,28 +209,26 @@ class BlenderbotSmallAttention(nn.Module):
         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

-        assert attn_weights.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            src_len,
-        ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+            )

         if attention_mask is not None:
-            assert attention_mask.size() == (
-                bsz,
-                1,
-                tgt_len,
-                src_len,
-            ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

         attn_weights = F.softmax(attn_weights, dim=-1)

         if layer_head_mask is not None:
-            assert layer_head_mask.size() == (
-                self.num_heads,
-            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+                )
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

@@ -248,17 +246,14 @@ class BlenderbotSmallAttention(nn.Module):
         attn_output = torch.bmm(attn_probs, value_states)

-        assert attn_output.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            self.head_dim,
-        ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+            )

-        attn_output = (
-            attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(bsz, tgt_len, embed_dim)
-        )
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

         attn_output = self.out_proj(attn_output)
@@ -280,28 +280,26 @@ class M2M100Attention(nn.Module):
         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

-        assert attn_weights.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            src_len,
-        ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+            )

         if attention_mask is not None:
-            assert attention_mask.size() == (
-                bsz,
-                1,
-                tgt_len,
-                src_len,
-            ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

         attn_weights = F.softmax(attn_weights, dim=-1)

         if layer_head_mask is not None:
-            assert layer_head_mask.size() == (
-                self.num_heads,
-            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+                )
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

@@ -319,17 +317,14 @@ class M2M100Attention(nn.Module):
         attn_output = torch.bmm(attn_probs, value_states)

-        assert attn_output.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            self.head_dim,
-        ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+            )

-        attn_output = (
-            attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(bsz, tgt_len, embed_dim)
-        )
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

         attn_output = self.out_proj(attn_output)
@@ -226,28 +226,26 @@ class MarianAttention(nn.Module):
         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

-        assert attn_weights.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            src_len,
-        ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+            )

         if attention_mask is not None:
-            assert attention_mask.size() == (
-                bsz,
-                1,
-                tgt_len,
-                src_len,
-            ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

         attn_weights = F.softmax(attn_weights, dim=-1)

         if layer_head_mask is not None:
-            assert layer_head_mask.size() == (
-                self.num_heads,
-            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+                )
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

@@ -265,17 +263,14 @@ class MarianAttention(nn.Module):
         attn_output = torch.bmm(attn_probs, value_states)

-        assert attn_output.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            self.head_dim,
-        ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+            )

-        attn_output = (
-            attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(bsz, tgt_len, embed_dim)
-        )
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

         attn_output = self.out_proj(attn_output)
@@ -217,28 +217,26 @@ class MBartAttention(nn.Module):
         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

-        assert attn_weights.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            src_len,
-        ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+            )

         if attention_mask is not None:
-            assert attention_mask.size() == (
-                bsz,
-                1,
-                tgt_len,
-                src_len,
-            ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

         attn_weights = F.softmax(attn_weights, dim=-1)

         if layer_head_mask is not None:
-            assert layer_head_mask.size() == (
-                self.num_heads,
-            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+                )
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

@@ -256,17 +254,14 @@ class MBartAttention(nn.Module):
         attn_output = torch.bmm(attn_probs, value_states)

-        assert attn_output.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            self.head_dim,
-        ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+            )

-        attn_output = (
-            attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(bsz, tgt_len, embed_dim)
-        )
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

         attn_output = self.out_proj(attn_output)
@@ -226,28 +226,26 @@ class PegasusAttention(nn.Module):
         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

-        assert attn_weights.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            src_len,
-        ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+            )

         if attention_mask is not None:
-            assert attention_mask.size() == (
-                bsz,
-                1,
-                tgt_len,
-                src_len,
-            ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

         attn_weights = F.softmax(attn_weights, dim=-1)

         if layer_head_mask is not None:
-            assert layer_head_mask.size() == (
-                self.num_heads,
-            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+                )
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

@@ -265,17 +263,14 @@ class PegasusAttention(nn.Module):
         attn_output = torch.bmm(attn_probs, value_states)

-        assert attn_output.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            self.head_dim,
-        ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+            )

-        attn_output = (
-            attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(bsz, tgt_len, embed_dim)
-        )
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

         attn_output = self.out_proj(attn_output)
@@ -293,28 +293,26 @@ class Speech2TextAttention(nn.Module):
         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

-        assert attn_weights.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            src_len,
-        ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+            )

         if attention_mask is not None:
-            assert attention_mask.size() == (
-                bsz,
-                1,
-                tgt_len,
-                src_len,
-            ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

         attn_weights = F.softmax(attn_weights, dim=-1)

         if layer_head_mask is not None:
-            assert layer_head_mask.size() == (
-                self.num_heads,
-            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+                )
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

@@ -332,17 +330,14 @@ class Speech2TextAttention(nn.Module):
         attn_output = torch.bmm(attn_probs, value_states)

-        assert attn_output.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            self.head_dim,
-        ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+            )

-        attn_output = (
-            attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(bsz, tgt_len, embed_dim)
-        )
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

         attn_output = self.out_proj(attn_output)
@@ -356,28 +356,26 @@ class Wav2Vec2Attention(nn.Module):
         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

-        assert attn_weights.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            src_len,
-        ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+            )

         if attention_mask is not None:
-            assert attention_mask.size() == (
-                bsz,
-                1,
-                tgt_len,
-                src_len,
-            ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

         attn_weights = F.softmax(attn_weights, dim=-1)

         if layer_head_mask is not None:
-            assert layer_head_mask.size() == (
-                self.num_heads,
-            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+                )
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

@@ -395,17 +393,14 @@ class Wav2Vec2Attention(nn.Module):
         attn_output = torch.bmm(attn_probs, value_states)

-        assert attn_output.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            self.head_dim,
-        ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+            )

-        attn_output = (
-            attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(bsz, tgt_len, embed_dim)
-        )
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

         attn_output = self.out_proj(attn_output)
@@ -1721,28 +1721,26 @@ class {{cookiecutter.camelcase_modelname}}Attention(nn.Module):
         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

-        assert attn_weights.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            src_len,
-        ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+            )

         if attention_mask is not None:
-            assert attention_mask.size() == (
-                bsz,
-                1,
-                tgt_len,
-                src_len,
-            ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

         attn_weights = F.softmax(attn_weights, dim=-1)

         if layer_head_mask is not None:
-            assert layer_head_mask.size() == (
-                self.num_heads,
-            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+                )
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

@@ -1760,17 +1758,14 @@ class {{cookiecutter.camelcase_modelname}}Attention(nn.Module):
         attn_output = torch.bmm(attn_probs, value_states)

-        assert attn_output.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            self.head_dim,
-        ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+            )

-        attn_output = (
-            attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(bsz, tgt_len, embed_dim)
-        )
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

         attn_output = self.out_proj(attn_output)