Prophetnet batch dimension inversion fix (#21870)

* decoder forward pass is working

* no model has forward pass returning attentions

* decoder ngram changed to not mix batch size

* current basic forward pass returns identical result

* passed test_model attentions

* passed test_encoder_decoder_model_generate

* passed test_headmasking

* removed old block

* removed comments bug/fixme

* removed bug comments

* applied styling

* applied fix-copies

* applied ngram forward comments

* corrected dimension notation

* applied styling and comment fixes

* changed asserts for raise ValueError

* changed question gen test

* updated hidden_states integration test

* applied styling
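The core of the fix, shown as a minimal standalone sketch (all names and sizes below are illustrative, not code from this commit): before this change, attention tensors were flattened to (batch_size * num_attn_heads, seq_len, head_dim), and later chunk()/view() calls on that merged axis could mix entries from different batch elements; after it, batch and heads remain separate axes and torch.einsum replaces torch.bmm. The two formulations are numerically equivalent:

import torch

batch_size, num_heads, tgt_len, src_len, head_dim = 2, 4, 5, 7, 8
query = torch.randn(batch_size, num_heads, tgt_len, head_dim)
key = torch.randn(batch_size, num_heads, src_len, head_dim)

# old layout: batch and head dims merged before the bmm
old_weights = torch.bmm(
    query.reshape(batch_size * num_heads, tgt_len, head_dim),
    key.reshape(batch_size * num_heads, src_len, head_dim).transpose(1, 2),
).view(batch_size, num_heads, tgt_len, src_len)

# new layout: the batch dim stays an explicit axis end to end
new_weights = torch.einsum("bsij,bsjk->bsik", query, key.transpose(2, 3))

assert torch.allclose(old_weights, new_weights, atol=1e-6)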
Kian Sierra McGettigan 2023-03-02 18:07:45 +01:00 committed by GitHub
parent 99ba36e72f
commit 6bf885375a
3 changed files with 262 additions and 314 deletions

src/transformers/models/prophetnet/modeling_prophetnet.py

@@ -701,44 +701,27 @@ class ProphetNetAttention(nn.Module):
past_key_value = (key_states, value_states)
# project states into the correct shape
proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim)
proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, batch_size).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
assert attn_weights.size() == (
batch_size * self.num_attn_heads,
tgt_len,
src_len,
), (
f"`attn_weights` should be of size {batch_size * self.num_attn_heads, tgt_len, src_len}, but is of size"
f" {attn_weights.shape}"
)
src_len = key_states.size(2)
attn_weights = torch.einsum("bsij,bsjk->bsik", query_states, key_states.transpose(2, 3))
expected_shape = (batch_size, self.num_attn_heads, tgt_len, src_len)
if attn_weights.size() != expected_shape:
raise ValueError(f"Attention weights should have size {expected_shape}, but is {attn_weights.size()}")
# This is part of a workaround to get around fork/join parallelism not supporting Optional types.
if attention_mask is not None and attention_mask.dim() == 0:
attention_mask = None
assert attention_mask is None or attention_mask.size() == (
self.num_attn_heads * batch_size,
1,
src_len,
), (
"`attention_mask` should be `None` or of shape attention_mask.size() =="
f" {batch_size * self.num_attn_heads, 1, src_len}, but is {attention_mask.shape}"
)
expected_shape = (batch_size, self.num_attn_heads, 1, src_len)
if attention_mask is not None and attention_mask.size() != expected_shape:
raise ValueError(f"Attention mask should have size {expected_shape}, but is {attention_mask.size()}")
if attention_mask is not None: # don't attend to padding symbols
attn_weights = attn_weights + attention_mask
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(batch_size, self.num_attn_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(batch_size * self.num_attn_heads, tgt_len, src_len)
attn_weights_reshaped = attn_weights
else:
attn_weights_reshaped = None
@@ -752,7 +735,6 @@ class ProphetNetAttention(nn.Module):
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
batch_size, self.num_attn_heads, tgt_len, src_len
)
attn_weights = attn_weights.view(batch_size * self.num_attn_heads, tgt_len, src_len)
# apply head_mask also on attn_weights_reshaped which is used for n-gram attention inside the model
attn_weights_reshaped = layer_head_mask.view(1, -1, 1, 1) * attn_weights_reshaped
@@ -762,23 +744,12 @@ class ProphetNetAttention(nn.Module):
p=self.attention_dropout,
training=self.training,
)
attn_output = torch.einsum("bsij,bsjk->bsik", attn_probs, value_states)
expected_shape = (batch_size, self.num_attn_heads, tgt_len, self.head_dim)
if attn_output.size() != expected_shape:
raise ValueError(f"`attn_output` should have shape {expected_shape}, but is of shape {attn_output.size()}")
attn_output = torch.bmm(attn_probs, value_states)
assert attn_output.size() == (
batch_size * self.num_attn_heads,
tgt_len,
self.head_dim,
), (
f"`attn_output` should be of shape {batch_size * self.num_attn_heads, tgt_len, self.head_dim}, but is of"
f" shape {attn_output.size()}"
)
attn_output = (
attn_output.view(batch_size, self.num_attn_heads, tgt_len, self.head_dim)
.transpose(1, 2)
.reshape(batch_size, tgt_len, hidden_size)
)
attn_output = attn_output.transpose(1, 2).reshape(batch_size, tgt_len, hidden_size)
attn_output = self.out_proj(attn_output)
attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
@@ -856,7 +827,6 @@ class ProphetNetNgramSelfAttention(nn.Module):
position_ids=None,
):
batch_size, ngram_sequence_length, hidden_size = hidden_states.size()
assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], (
f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape"
f" {hidden_states.shape}"
@@ -874,8 +844,7 @@ class ProphetNetNgramSelfAttention(nn.Module):
query_states = self._shape(query_states, ngram_sequence_length, batch_size)
key_states = self._shape(key_states, -1, batch_size)
value_states = self._shape(value_states, -1, batch_size)
proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim)
proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim)
query_states = query_states.view(*proj_shape)
key_states = key_states.view(*proj_shape)
@@ -883,10 +852,9 @@ class ProphetNetNgramSelfAttention(nn.Module):
# chunk into main stream and predict stream
hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=1)
query_states_list = query_states.chunk(1 + self.ngram, dim=1)
key_states_list = key_states.chunk(1 + self.ngram, dim=1)
value_states_list = value_states.chunk(1 + self.ngram, dim=1)
query_states_list = query_states.chunk(1 + self.ngram, dim=2)
key_states_list = key_states.chunk(1 + self.ngram, dim=2)
value_states_list = value_states.chunk(1 + self.ngram, dim=2)
main_hidden_states, hidden_states_predict_list = hidden_states_list[0], hidden_states_list[1:]
main_query_states, predict_query_states_list = query_states_list[0], query_states_list[1:]
@@ -895,28 +863,29 @@ class ProphetNetNgramSelfAttention(nn.Module):
# saved states are stored with shape (batch_size, num_attn_heads, seq_len, head_dim)
if past_key_value is not None:
prev_main_key_states = past_key_value[0].view(batch_size * self.num_attn_heads, -1, self.head_dim)
main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=1)
prev_main_value_states = past_key_value[1].view(batch_size * self.num_attn_heads, -1, self.head_dim)
main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=1)
prev_main_key_states = past_key_value[0]
main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=2)
prev_main_value_states = past_key_value[1]
main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=2)
# Update cache
past_key_value = (
main_key_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
main_value_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
)
past_key_value = (main_key_states, main_value_states)
# get seq_length of main stream only
sequence_length = ngram_sequence_length // (1 + self.ngram)
# MAIN-STREAM
# main attn weights
main_attn_weights = torch.bmm(main_query_states, main_key_states.transpose(1, 2))
# [batch_size, number_heads, sequence_length, head_dimension]
# x [batch_size, number_heads, head_dimension, sequence_length]
# -> [batch_size, number_heads, sequence_length, sequence_length]
main_attn_weights = torch.einsum("bntc,bncs->bnts", main_query_states, main_key_states.transpose(2, 3))
# retrieve relative position embeddings for each layer -> see paper for more details
main_relative_pos_embeddings = self.get_main_relative_pos_embeddings(
main_hidden_states, main_attn_weights, position_ids, main_relative_position_buckets
)
main_attn_weights = main_attn_weights + main_relative_pos_embeddings
if attention_mask is not None:
@@ -936,55 +905,53 @@ class ProphetNetNgramSelfAttention(nn.Module):
main_attn_probs = layer_head_mask.view(1, -1, 1, 1) * main_attn_probs.view(
batch_size, self.num_attn_heads, -1, sequence_length
)
main_attn_probs = main_attn_probs.view(batch_size * self.num_attn_heads, -1, sequence_length)
main_attn_probs = nn.functional.dropout(main_attn_probs, p=self.attention_dropout, training=self.training)
# project to attn_output
main_attn_output = torch.bmm(main_attn_probs, main_value_states)
# [batch_size, number_heads, sequence_length, sequence_length]
# x [batch_size, number_heads, sequence_length, head_dimension]
# -> [batch_size, number_heads, sequence_length, head_dimension]
main_attn_output = torch.einsum("bntc,bncs->bnts", main_attn_probs, main_value_states)
# reshape so that num_heads dim is merged into last `head_dim` axis
main_attn_output = (
main_attn_output.view(batch_size, self.num_attn_heads, sequence_length, self.head_dim)
.transpose(1, 2)
.reshape(batch_size, 1, sequence_length, hidden_size)
)
main_attn_output = main_attn_output.transpose(1, 2).reshape(batch_size, 1, sequence_length, hidden_size)
main_attn_output = self.out_proj(main_attn_output)
# PREDICT-STREAM
# [ngram, B*head, T, c]
predict_query_states = torch.cat(predict_query_states_list, 0).view(
self.ngram, -1, sequence_length, self.head_dim
)
# [ngram, B*head, 2*T, c]
predict_key_states = torch.cat(
[torch.cat([main_key_states, key], 1).unsqueeze(0) for key in predict_key_states_list], 0
# [batch_size, ngram, number_heads, sequence_length, head_dimension]
predict_query_states = torch.stack(predict_query_states_list, 1).view(
batch_size, self.ngram, self.num_attn_heads, sequence_length, self.head_dim
)
# [ngram, T, B, C]
predict_hidden_states = torch.cat(hidden_states_predict_list, 0).view(
self.ngram, sequence_length, batch_size, hidden_size
)
# [batch_size, ngram, number_heads, 2*sequence_length, head_dimension]
predict_key_states = torch.stack([torch.cat([main_key_states, key], 2) for key in predict_key_states_list], 1)
# [ngram, B*head, 2*T, c]
# [batch_size, sequence_length, ngram, hidden_size]
predict_hidden_states = torch.stack(hidden_states_predict_list, dim=2)
# [batch_size, number_heads, ngram, 2*sequence_length, head_dimension]
predict_value_states = torch.cat(
[torch.cat([main_value_states, v_p], 1).unsqueeze(0) for v_p in predict_value_states_list], 0
[torch.cat([main_value_states, v_p], 2).unsqueeze(2) for v_p in predict_value_states_list], 2
)
# [ngram, B*head, T, 2*T]
predict_attn_weights = torch.einsum("nbtc,nbsc->nbts", (predict_query_states, predict_key_states))
# [ngram, B*head, T, S]
# [batch_size, ngram, number_heads, sequence_length, head_dimension]
# x [batch_size, ngram, number_heads, 2*sequence_length, head_dimension]
# -> [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
predict_attn_weights = torch.einsum("bnhtc,bnhsc->bnhts", (predict_query_states, predict_key_states))
# retrieve relative position embeddings for each layer -> see paper for more details
# [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
predict_relative_pos_embeddings = self.get_predict_relative_pos_embeddings(
predict_hidden_states, predict_attn_weights, position_ids, predict_relative_position_buckets
)
# [ngram, B*head, T, 2*T]
# [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
predict_attn_weights = predict_attn_weights + predict_relative_pos_embeddings
if extended_predict_attention_mask is not None:
predict_attn_weights = predict_attn_weights + extended_predict_attention_mask.to(
predict_attn_weights.dtype
)
# Permuting Predict attention mask to [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
extended_predict_attention_mask = extended_predict_attention_mask.permute(0, 2, 1, 3, 4)
extended_predict_attention_mask = extended_predict_attention_mask.to(predict_attn_weights.dtype)
predict_attn_weights = predict_attn_weights + extended_predict_attention_mask
predict_attn_probs = softmax(
predict_attn_weights,
@@ -997,37 +964,30 @@ class ProphetNetNgramSelfAttention(nn.Module):
f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
f" {layer_head_mask.size()}"
)
predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs.view(
self.ngram, batch_size, self.num_attn_heads, sequence_length, 2 * sequence_length
)
predict_attn_probs = predict_attn_probs.view(
self.ngram, batch_size * self.num_attn_heads, sequence_length, 2 * sequence_length
)
predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs
predict_attn_probs = nn.functional.dropout(
predict_attn_probs, p=self.attention_dropout, training=self.training
)
# project to attention output
# [ngram, B*head, T, c]
predict_attn_output = torch.einsum("nbts,nbsc->nbtc", (predict_attn_probs, predict_value_states))
# [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
# x [batch_size, ngram, number_heads, 2*sequence_length, head_dimension]
# -> [batch_size, ngram, number_heads, sequence_length, head_dimension]
predict_attn_output = torch.einsum(
"bnhts,bnhsc->bnhtc", (predict_attn_probs, predict_value_states.transpose(1, 2))
)
# reshape so that num_heads dim is merged into last `head_dim` axis
# [ngram, B, T, C]
predict_attn_output = (
predict_attn_output.view(self.ngram, batch_size, self.num_attn_heads, sequence_length, self.head_dim)
.permute(1, 0, 3, 2, 4)
.reshape(batch_size, self.ngram, sequence_length, hidden_size)
)
# [batch_size, ngram, number_heads, sequence_length, head_dimension] -> [batch_size, ngram, sequence_length, hidden_size]
predict_attn_output = predict_attn_output.transpose(2, 3)
predict_attn_output = predict_attn_output.reshape(batch_size, self.ngram, sequence_length, hidden_size)
predict_attn_output = self.out_proj(predict_attn_output)
# concat to single attn output
# [B, 1+ngram*T, C]
# [batch_size, (1+ngram)*sequence_length, hidden_size]
attn_output = torch.cat([main_attn_output, predict_attn_output], 1).view(batch_size, -1, hidden_size)
# reshape into better form for `config.output_attentions`
main_attn_probs = main_attn_probs.view(batch_size, self.num_attn_heads, sequence_length, -1)
predict_attn_probs = predict_attn_probs.view(
self.ngram, batch_size, self.num_attn_heads, sequence_length, -1
).transpose(0, 1)
attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
@@ -1036,8 +996,11 @@ class ProphetNetNgramSelfAttention(nn.Module):
def get_main_relative_pos_embeddings(
self, hidden_states, attn_weights, position_ids, main_relative_position_buckets
):
# input hidden_states [B,T,C], input attn_weights [T*head,T,S], input position_ids [B,T] or [1,1]
# input hidden_states [batch_size, sequence_length, hidden_size]
# input attn_weights [batch_size, num_heads, sequence_length, sequence_length]
# input position_ids [batch_size, sequence_length] or [1,1]
batch_size, num_attn_heads, tgt_len, src_len = attn_weights.shape
attn_weights = attn_weights.view(batch_size, num_attn_heads, tgt_len, src_len)
if main_relative_position_buckets is None:
batch_size, sequence_length = hidden_states.shape[:2]
relative_positions = (
@@ -1047,39 +1010,42 @@ class ProphetNetNgramSelfAttention(nn.Module):
.repeat(batch_size, sequence_length, 1)
.to(position_ids.device)
)
relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(
batch_size, sequence_length, 1
) # [B, T, s]
# [batch_size, sequence_length, sequence_length+1]
relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1)
main_relative_position_buckets = compute_relative_buckets(
self.num_buckets, self.relative_max_distance, relative_positions, False
)
rel_pos_embeddings = self.relative_pos_embeddings(hidden_states) # [B,T,Buckets*head]
# [batch_size, sequence_length, num_buckets * num_heads]
rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)
rel_pos_embeddings = rel_pos_embeddings.view(
rel_pos_embeddings.shape[:2] + (self.num_buckets, self.num_attn_heads)
).permute(
0, 3, 1, 2
) # [B,T,Buckets,head]
rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:2] + (-1,)) # [B*head,T,Buckets]
)
rel_pos_embeddings = rel_pos_embeddings.permute(0, 3, 1, 2)
# [batch_size, num_heads, sequence_length, num_buckets]
rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:3] + (-1,))
main_relative_position_buckets = (
main_relative_position_buckets.repeat(1, self.num_attn_heads, 1)
.view(-1, main_relative_position_buckets.shape[-1])
.long()
) # [B*head*T, T]
rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1)) # [B*head*T,Buckets]
main_relative_pos_embeddings = torch.gather(
rel_pos_embeddings, dim=1, index=main_relative_position_buckets
).view(attn_weights.shape[:2] + (-1,))
main_relative_position_buckets = main_relative_position_buckets.repeat(1, self.num_attn_heads, 1)
# [batch_size * num_heads * sequence_length, sequence_length]
main_relative_position_buckets = main_relative_position_buckets.view(
-1, main_relative_position_buckets.shape[-1]
)
main_relative_position_buckets = main_relative_position_buckets.long()
# [batch_size * num_heads * sequence_length, sequence_length]
rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1))
main_relative_pos_embeddings = torch.gather(rel_pos_embeddings, dim=1, index=main_relative_position_buckets)
main_relative_pos_embeddings = main_relative_pos_embeddings.view(batch_size, num_attn_heads, tgt_len, -1)
return main_relative_pos_embeddings
def get_predict_relative_pos_embeddings(
self, hidden_states, attn_weights, position_ids, predict_relative_position_buckets
):
# input hidden_states [ngram, T,B,C], input attn_weights [ngram, B*head,T,S], input position_ids [B,T] or [1,1], input predict_relative_position_buckets [B,T, 2*T] or None
sequence_length, batch_size = hidden_states.shape[1:3]
# input hidden_states [batch_size, sequence_length, ngram, hidden_size]
# input attn_weights [batch_size, ngram, num_heads, sequence_length, 2*sequence_length]
# input position_ids [batch_size, sequence_length] or [1,1]
# input predict_relative_position_buckets [batch_size, sequence_length, 2*sequence_length] or None
batch_size, sequence_length = hidden_states.shape[0:2]
if predict_relative_position_buckets is None:
key_sequence_length = attn_weights.shape[-1]
@@ -1099,28 +1065,35 @@ class ProphetNetNgramSelfAttention(nn.Module):
self.num_buckets, self.relative_max_distance, relative_positions, False
)
hidden_states = hidden_states.transpose(1, 2) # [ngram, B, T, C]
rel_pos_embeddings = self.relative_pos_embeddings(hidden_states).view(
# [batch_size, ngram, sequence_length, hidden_size]
hidden_states = hidden_states.transpose(1, 2)
rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)
# [batch_size, ngram, sequence_length, num_buckets, num_heads]
rel_pos_embeddings = rel_pos_embeddings.view(
hidden_states.shape[:-1] + (self.num_buckets, self.num_attn_heads)
) # [ngram, B, T, bucket, head]
rel_pos_embeddings = rel_pos_embeddings.permute(0, 1, 4, 2, 3).reshape(
self.ngram * batch_size * self.num_attn_heads, sequence_length, -1
) # [ngram*B*head, T, bucket]
predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0).repeat(
)
rel_pos_embeddings = rel_pos_embeddings.permute(0, 2, 1, 4, 3)
# [batch_size * ngram * sequence_length * num_heads, num_buckets]
rel_pos_embeddings = rel_pos_embeddings.reshape(-1, self.num_buckets)
# [ngram, batch_size, num_heads * sequence_length, -1]
predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0)
predict_relative_position_buckets = predict_relative_position_buckets.repeat(
self.ngram, 1, self.num_attn_heads, 1
) # [ngram, B, head*T, S]
rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1))
)
# [ngram * batch_size * num_heads * sequence_length, -1]
predict_relative_position_buckets = predict_relative_position_buckets.view(
-1, predict_relative_position_buckets.size(-1)
).long() # [ngram*B*head*T, S]
).long()
predict_relative_pos_embeddings = torch.gather(
rel_pos_embeddings, dim=1, index=predict_relative_position_buckets
).view(
self.ngram, batch_size * self.num_attn_heads, sequence_length, -1
) # [ngram, B*head, T, S]
)
# [batch_size, ngram, num_heads, sequence_length, -1]
predict_relative_pos_embeddings = predict_relative_pos_embeddings.view(
batch_size, self.ngram, self.num_attn_heads, sequence_length, -1
)
return predict_relative_pos_embeddings
@@ -1331,7 +1304,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
# prepare attention mask
if attention_mask is not None:
extended_attention_mask = (
1.0 - attention_mask[:, None, :].repeat(self.config.num_encoder_attention_heads, 1, 1)
1.0 - attention_mask[:, None, None, :].repeat(1, self.config.num_encoder_attention_heads, 1, 1)
) * torch.finfo(self.dtype).min
extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype)
else:
@@ -1549,7 +1522,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
# prepare encoder attention mask
if encoder_attention_mask is not None:
extended_encoder_attention_mask = (
1.0 - encoder_attention_mask[:, None, :].repeat(self.config.num_decoder_attention_heads, 1, 1)
1.0 - encoder_attention_mask[:, None, None, :].repeat(1, self.config.num_decoder_attention_heads, 1, 1)
) * torch.finfo(self.dtype).min
extended_encoder_attention_mask = extended_encoder_attention_mask.to(inputs_embeds.dtype)
else:
@@ -1717,17 +1690,18 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
device=hidden_states.device,
)
causal_mask = torch.triu(causal_mask, 1)
extended_causal_mask = causal_mask[:seq_length, :seq_length][None, :, :].expand(
(batch_size,) + causal_mask.shape
extended_causal_mask = causal_mask[:seq_length, :seq_length][None, None, :, :].expand(
(batch_size, self.config.num_decoder_attention_heads) + causal_mask.shape
)
# add usual attention mask
if attention_mask is not None:
extended_attention_mask = (1.0 - attention_mask[:, None, :]) * torch.finfo(self.dtype).min
extended_attention_mask = (1.0 - attention_mask[:, None, None, :]) * torch.finfo(self.dtype).min
extended_attention_mask = extended_causal_mask + extended_attention_mask
else:
extended_attention_mask = extended_causal_mask
return extended_attention_mask.repeat(self.config.num_decoder_attention_heads, 1, 1).to(hidden_states.dtype)
return extended_attention_mask.to(hidden_states.dtype)
def prepare_predict_attention_mask(self, hidden_states, attention_mask):
batch_size, seq_length = hidden_states.shape[:2]
@@ -1745,14 +1719,16 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
],
dim=-1,
)
extended_predict_causal_mask = predict_causal_mask[:, None, :, :].expand(
predict_causal_mask.shape[:1] + (batch_size,) + predict_causal_mask.shape[1:]
extended_predict_causal_mask = predict_causal_mask[None, None, :, :, :].expand(
(batch_size, self.config.num_decoder_attention_heads) + predict_causal_mask.shape
)
# add usual attention mask
if attention_mask is not None:
extended_attention_mask = (1.0 - attention_mask[None, :, None, :]) * torch.finfo(self.dtype).min
extended_attention_mask = extended_attention_mask.expand((self.ngram, batch_size, seq_length, seq_length))
extended_attention_mask = (1.0 - attention_mask[:, None, None, None, :]) * torch.finfo(self.dtype).min
extended_attention_mask = extended_attention_mask.expand(
(batch_size, self.config.num_decoder_attention_heads, self.ngram, seq_length, seq_length)
)
# predicted stream attention_mask should always be 0
extended_attention_mask = torch.cat(
[extended_attention_mask, torch.zeros_like(extended_attention_mask)], dim=-1
@@ -1760,9 +1736,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
extended_predict_attention_mask = extended_predict_causal_mask + extended_attention_mask
else:
extended_predict_attention_mask = extended_predict_causal_mask
return extended_predict_attention_mask.repeat(1, self.config.num_decoder_attention_heads, 1, 1).to(
hidden_states.dtype
)
return extended_predict_attention_mask.to(hidden_states.dtype)
@add_start_docstrings(
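For reference, a small sketch of the decoder mask layout after this change (toy sizes, illustrative names only, not code from the commit): the causal mask is expanded to (batch_size, num_heads, seq_length, seq_length) up front, and the padding mask broadcasts over the head axis, so the final repeat over a merged batch*heads dimension is no longer needed:

import torch

batch_size, num_heads, seq_length = 2, 4, 6
dtype = torch.float32

# upper-triangular causal mask, expanded over batch and head axes
causal_mask = torch.full((seq_length, seq_length), torch.finfo(dtype).min, dtype=dtype)
causal_mask = torch.triu(causal_mask, 1)
extended_causal_mask = causal_mask[None, None, :, :].expand((batch_size, num_heads) + causal_mask.shape)

# toy padding mask: the last two positions of the second example are padding
attention_mask = torch.ones(batch_size, seq_length)
attention_mask[1, -2:] = 0
extended_attention_mask = (1.0 - attention_mask[:, None, None, :]) * torch.finfo(dtype).min

combined = extended_causal_mask + extended_attention_mask
print(combined.shape)  # torch.Size([2, 4, 6, 6])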

src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py

@@ -716,44 +716,27 @@ class XLMProphetNetAttention(nn.Module):
past_key_value = (key_states, value_states)
# project states into the correct shape
proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim)
proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, batch_size).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
assert attn_weights.size() == (
batch_size * self.num_attn_heads,
tgt_len,
src_len,
), (
f"`attn_weights` should be of size {batch_size * self.num_attn_heads, tgt_len, src_len}, but is of size"
f" {attn_weights.shape}"
)
src_len = key_states.size(2)
attn_weights = torch.einsum("bsij,bsjk->bsik", query_states, key_states.transpose(2, 3))
expected_shape = (batch_size, self.num_attn_heads, tgt_len, src_len)
if attn_weights.size() != expected_shape:
raise ValueError(f"Attention weights should have size {expected_shape}, but is {attn_weights.size()}")
# This is part of a workaround to get around fork/join parallelism not supporting Optional types.
if attention_mask is not None and attention_mask.dim() == 0:
attention_mask = None
assert attention_mask is None or attention_mask.size() == (
self.num_attn_heads * batch_size,
1,
src_len,
), (
"`attention_mask` should be `None` or of shape attention_mask.size() =="
f" {batch_size * self.num_attn_heads, 1, src_len}, but is {attention_mask.shape}"
)
expected_shape = (batch_size, self.num_attn_heads, 1, src_len)
if attention_mask is not None and attention_mask.size() != expected_shape:
raise ValueError(f"Attention mask should have size {expected_shape}, but is {attention_mask.size()}")
if attention_mask is not None: # don't attend to padding symbols
attn_weights = attn_weights + attention_mask
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(batch_size, self.num_attn_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(batch_size * self.num_attn_heads, tgt_len, src_len)
attn_weights_reshaped = attn_weights
else:
attn_weights_reshaped = None
@@ -767,7 +750,6 @@ class XLMProphetNetAttention(nn.Module):
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
batch_size, self.num_attn_heads, tgt_len, src_len
)
attn_weights = attn_weights.view(batch_size * self.num_attn_heads, tgt_len, src_len)
# apply head_mask also on attn_weights_reshaped which is used for n-gram attention inside the model
attn_weights_reshaped = layer_head_mask.view(1, -1, 1, 1) * attn_weights_reshaped
@@ -777,23 +759,12 @@ class XLMProphetNetAttention(nn.Module):
p=self.attention_dropout,
training=self.training,
)
attn_output = torch.einsum("bsij,bsjk->bsik", attn_probs, value_states)
expected_shape = (batch_size, self.num_attn_heads, tgt_len, self.head_dim)
if attn_output.size() != expected_shape:
raise ValueError(f"`attn_output` should have shape {expected_shape}, but is of shape {attn_output.size()}")
attn_output = torch.bmm(attn_probs, value_states)
assert attn_output.size() == (
batch_size * self.num_attn_heads,
tgt_len,
self.head_dim,
), (
f"`attn_output` should be of shape {batch_size * self.num_attn_heads, tgt_len, self.head_dim}, but is of"
f" shape {attn_output.size()}"
)
attn_output = (
attn_output.view(batch_size, self.num_attn_heads, tgt_len, self.head_dim)
.transpose(1, 2)
.reshape(batch_size, tgt_len, hidden_size)
)
attn_output = attn_output.transpose(1, 2).reshape(batch_size, tgt_len, hidden_size)
attn_output = self.out_proj(attn_output)
attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
@@ -873,7 +844,6 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
position_ids=None,
):
batch_size, ngram_sequence_length, hidden_size = hidden_states.size()
assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], (
f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape"
f" {hidden_states.shape}"
@@ -891,8 +861,7 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
query_states = self._shape(query_states, ngram_sequence_length, batch_size)
key_states = self._shape(key_states, -1, batch_size)
value_states = self._shape(value_states, -1, batch_size)
proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim)
proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim)
query_states = query_states.view(*proj_shape)
key_states = key_states.view(*proj_shape)
@@ -900,10 +869,9 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
# chunk into main stream and predict stream
hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=1)
query_states_list = query_states.chunk(1 + self.ngram, dim=1)
key_states_list = key_states.chunk(1 + self.ngram, dim=1)
value_states_list = value_states.chunk(1 + self.ngram, dim=1)
query_states_list = query_states.chunk(1 + self.ngram, dim=2)
key_states_list = key_states.chunk(1 + self.ngram, dim=2)
value_states_list = value_states.chunk(1 + self.ngram, dim=2)
main_hidden_states, hidden_states_predict_list = hidden_states_list[0], hidden_states_list[1:]
main_query_states, predict_query_states_list = query_states_list[0], query_states_list[1:]
@@ -912,28 +880,29 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
# saved states are stored with shape (batch_size, num_attn_heads, seq_len, head_dim)
if past_key_value is not None:
prev_main_key_states = past_key_value[0].view(batch_size * self.num_attn_heads, -1, self.head_dim)
main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=1)
prev_main_value_states = past_key_value[1].view(batch_size * self.num_attn_heads, -1, self.head_dim)
main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=1)
prev_main_key_states = past_key_value[0]
main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=2)
prev_main_value_states = past_key_value[1]
main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=2)
# Update cache
past_key_value = (
main_key_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
main_value_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
)
past_key_value = (main_key_states, main_value_states)
# get seq_length of main stream only
sequence_length = ngram_sequence_length // (1 + self.ngram)
# MAIN-STREAM
# main attn weights
main_attn_weights = torch.bmm(main_query_states, main_key_states.transpose(1, 2))
# [batch_size, number_heads, sequence_length, head_dimension]
# x [batch_size, number_heads, head_dimension, sequence_length]
# -> [batch_size, number_heads, sequence_length, sequence_length]
main_attn_weights = torch.einsum("bntc,bncs->bnts", main_query_states, main_key_states.transpose(2, 3))
# retrieve relative position embeddings for each layer -> see paper for more details
main_relative_pos_embeddings = self.get_main_relative_pos_embeddings(
main_hidden_states, main_attn_weights, position_ids, main_relative_position_buckets
)
main_attn_weights = main_attn_weights + main_relative_pos_embeddings
if attention_mask is not None:
@@ -953,55 +922,53 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
main_attn_probs = layer_head_mask.view(1, -1, 1, 1) * main_attn_probs.view(
batch_size, self.num_attn_heads, -1, sequence_length
)
main_attn_probs = main_attn_probs.view(batch_size * self.num_attn_heads, -1, sequence_length)
main_attn_probs = nn.functional.dropout(main_attn_probs, p=self.attention_dropout, training=self.training)
# project to attn_output
main_attn_output = torch.bmm(main_attn_probs, main_value_states)
# [batch_size, number_heads, sequence_length, sequence_length]
# x [batch_size, number_heads, sequence_length, head_dimension]
# -> [batch_size, number_heads, sequence_length, head_dimension]
main_attn_output = torch.einsum("bntc,bncs->bnts", main_attn_probs, main_value_states)
# reshape so that num_heads dim is merged into last `head_dim` axis
main_attn_output = (
main_attn_output.view(batch_size, self.num_attn_heads, sequence_length, self.head_dim)
.transpose(1, 2)
.reshape(batch_size, 1, sequence_length, hidden_size)
)
main_attn_output = main_attn_output.transpose(1, 2).reshape(batch_size, 1, sequence_length, hidden_size)
main_attn_output = self.out_proj(main_attn_output)
# PREDICT-STREAM
# [ngram, B*head, T, c]
predict_query_states = torch.cat(predict_query_states_list, 0).view(
self.ngram, -1, sequence_length, self.head_dim
)
# [ngram, B*head, 2*T, c]
predict_key_states = torch.cat(
[torch.cat([main_key_states, key], 1).unsqueeze(0) for key in predict_key_states_list], 0
# [batch_size, ngram, number_heads, sequence_length, head_dimension]
predict_query_states = torch.stack(predict_query_states_list, 1).view(
batch_size, self.ngram, self.num_attn_heads, sequence_length, self.head_dim
)
# [ngram, T, B, C]
predict_hidden_states = torch.cat(hidden_states_predict_list, 0).view(
self.ngram, sequence_length, batch_size, hidden_size
)
# [batch_size, ngram, number_heads, 2*sequence_length, head_dimension]
predict_key_states = torch.stack([torch.cat([main_key_states, key], 2) for key in predict_key_states_list], 1)
# [ngram, B*head, 2*T, c]
# [batch_size, sequence_length, ngram, hidden_size]
predict_hidden_states = torch.stack(hidden_states_predict_list, dim=2)
# [batch_size, number_heads, ngram, 2*sequence_length, head_dimension]
predict_value_states = torch.cat(
[torch.cat([main_value_states, v_p], 1).unsqueeze(0) for v_p in predict_value_states_list], 0
[torch.cat([main_value_states, v_p], 2).unsqueeze(2) for v_p in predict_value_states_list], 2
)
# [ngram, B*head, T, 2*T]
predict_attn_weights = torch.einsum("nbtc,nbsc->nbts", (predict_query_states, predict_key_states))
# [ngram, B*head, T, S]
# [batch_size, ngram, number_heads, sequence_length, head_dimension]
# x [batch_size, ngram, number_heads, 2*sequence_length, head_dimension]
# -> [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
predict_attn_weights = torch.einsum("bnhtc,bnhsc->bnhts", (predict_query_states, predict_key_states))
# retrieve relative position embeddings for each layer -> see paper for more details
# [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
predict_relative_pos_embeddings = self.get_predict_relative_pos_embeddings(
predict_hidden_states, predict_attn_weights, position_ids, predict_relative_position_buckets
)
# [ngram, B*head, T, 2*T]
# [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
predict_attn_weights = predict_attn_weights + predict_relative_pos_embeddings
if extended_predict_attention_mask is not None:
predict_attn_weights = predict_attn_weights + extended_predict_attention_mask.to(
predict_attn_weights.dtype
)
# Permuting Predict attention mask to [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
extended_predict_attention_mask = extended_predict_attention_mask.permute(0, 2, 1, 3, 4)
extended_predict_attention_mask = extended_predict_attention_mask.to(predict_attn_weights.dtype)
predict_attn_weights = predict_attn_weights + extended_predict_attention_mask
predict_attn_probs = softmax(
predict_attn_weights,
@@ -1014,37 +981,30 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
f" {layer_head_mask.size()}"
)
predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs.view(
self.ngram, batch_size, self.num_attn_heads, sequence_length, 2 * sequence_length
)
predict_attn_probs = predict_attn_probs.view(
self.ngram, batch_size * self.num_attn_heads, sequence_length, 2 * sequence_length
)
predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs
predict_attn_probs = nn.functional.dropout(
predict_attn_probs, p=self.attention_dropout, training=self.training
)
# project to attention output
# [ngram, B*head, T, c]
predict_attn_output = torch.einsum("nbts,nbsc->nbtc", (predict_attn_probs, predict_value_states))
# [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
# x [batch_size, ngram, number_heads, 2*sequence_length, head_dimension]
# -> [batch_size, ngram, number_heads, sequence_length, head_dimension]
predict_attn_output = torch.einsum(
"bnhts,bnhsc->bnhtc", (predict_attn_probs, predict_value_states.transpose(1, 2))
)
# reshape so that num_heads dim is merged into last `head_dim` axis
# [ngram, B, T, C]
predict_attn_output = (
predict_attn_output.view(self.ngram, batch_size, self.num_attn_heads, sequence_length, self.head_dim)
.permute(1, 0, 3, 2, 4)
.reshape(batch_size, self.ngram, sequence_length, hidden_size)
)
# [batch_size, ngram, number_heads, sequence_length, head_dimension] -> [batch_size, ngram, sequence_length, hidden_size]
predict_attn_output = predict_attn_output.transpose(2, 3)
predict_attn_output = predict_attn_output.reshape(batch_size, self.ngram, sequence_length, hidden_size)
predict_attn_output = self.out_proj(predict_attn_output)
# concat to single attn output
# [B, 1+ngram*T, C]
# [batch_size, (1+ngram)*sequence_length, hidden_size]
attn_output = torch.cat([main_attn_output, predict_attn_output], 1).view(batch_size, -1, hidden_size)
# reshape into better form for `config.output_attentions`
main_attn_probs = main_attn_probs.view(batch_size, self.num_attn_heads, sequence_length, -1)
predict_attn_probs = predict_attn_probs.view(
self.ngram, batch_size, self.num_attn_heads, sequence_length, -1
).transpose(0, 1)
attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
@@ -1053,8 +1013,11 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
def get_main_relative_pos_embeddings(
self, hidden_states, attn_weights, position_ids, main_relative_position_buckets
):
# input hidden_states [B,T,C], input attn_weights [T*head,T,S], input position_ids [B,T] or [1,1]
# input hidden_states [batch_size, sequence_length, hidden_size]
# input attn_weights [batch_size, num_heads, sequence_length, sequence_length]
# input position_ids [batch_size, sequence_length] or [1,1]
batch_size, num_attn_heads, tgt_len, src_len = attn_weights.shape
attn_weights = attn_weights.view(batch_size, num_attn_heads, tgt_len, src_len)
if main_relative_position_buckets is None:
batch_size, sequence_length = hidden_states.shape[:2]
relative_positions = (
@@ -1064,39 +1027,42 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
.repeat(batch_size, sequence_length, 1)
.to(position_ids.device)
)
relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(
batch_size, sequence_length, 1
) # [B, T, s]
# [batch_size, sequence_length, sequence_length+1]
relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1)
main_relative_position_buckets = compute_relative_buckets(
self.num_buckets, self.relative_max_distance, relative_positions, False
)
rel_pos_embeddings = self.relative_pos_embeddings(hidden_states) # [B,T,Buckets*head]
# [batch_size, sequence_length, num_buckets * num_heads]
rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)
rel_pos_embeddings = rel_pos_embeddings.view(
rel_pos_embeddings.shape[:2] + (self.num_buckets, self.num_attn_heads)
).permute(
0, 3, 1, 2
) # [B,T,Buckets,head]
rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:2] + (-1,)) # [B*head,T,Buckets]
)
rel_pos_embeddings = rel_pos_embeddings.permute(0, 3, 1, 2)
# [batch_size, num_heads, sequence_length, num_buckets]
rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:3] + (-1,))
main_relative_position_buckets = (
main_relative_position_buckets.repeat(1, self.num_attn_heads, 1)
.view(-1, main_relative_position_buckets.shape[-1])
.long()
) # [B*head*T, T]
rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1)) # [B*head*T,Buckets]
main_relative_pos_embeddings = torch.gather(
rel_pos_embeddings, dim=1, index=main_relative_position_buckets
).view(attn_weights.shape[:2] + (-1,))
main_relative_position_buckets = main_relative_position_buckets.repeat(1, self.num_attn_heads, 1)
# [batch_size * num_heads * sequence_length, sequence_length]
main_relative_position_buckets = main_relative_position_buckets.view(
-1, main_relative_position_buckets.shape[-1]
)
main_relative_position_buckets = main_relative_position_buckets.long()
# [batch_size * num_heads * sequence_length, sequence_length]
rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1))
main_relative_pos_embeddings = torch.gather(rel_pos_embeddings, dim=1, index=main_relative_position_buckets)
main_relative_pos_embeddings = main_relative_pos_embeddings.view(batch_size, num_attn_heads, tgt_len, -1)
return main_relative_pos_embeddings
def get_predict_relative_pos_embeddings(
self, hidden_states, attn_weights, position_ids, predict_relative_position_buckets
):
# input hidden_states [ngram, T,B,C], input attn_weights [ngram, B*head,T,S], input position_ids [B,T] or [1,1], input predict_relative_position_buckets [B,T, 2*T] or None
sequence_length, batch_size = hidden_states.shape[1:3]
# input hidden_states [batch_size, sequence_length, ngram, hidden_size]
# input attn_weights [batch_size, ngram, num_heads, sequence_length, 2*sequence_length]
# input position_ids [batch_size, sequence_length] or [1,1]
# input predict_relative_position_buckets [batch_size, sequence_length, 2*sequence_length] or None
batch_size, sequence_length = hidden_states.shape[0:2]
if predict_relative_position_buckets is None:
key_sequence_length = attn_weights.shape[-1]
@@ -1116,28 +1082,35 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
self.num_buckets, self.relative_max_distance, relative_positions, False
)
hidden_states = hidden_states.transpose(1, 2) # [ngram, B, T, C]
rel_pos_embeddings = self.relative_pos_embeddings(hidden_states).view(
# [batch_size, ngram, sequence_length, hidden_size]
hidden_states = hidden_states.transpose(1, 2)
rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)
# [batch_size, ngram, sequence_length, num_buckets, num_heads]
rel_pos_embeddings = rel_pos_embeddings.view(
hidden_states.shape[:-1] + (self.num_buckets, self.num_attn_heads)
) # [ngram, B, T, bucket, head]
rel_pos_embeddings = rel_pos_embeddings.permute(0, 1, 4, 2, 3).reshape(
self.ngram * batch_size * self.num_attn_heads, sequence_length, -1
) # [ngram*B*head, T, bucket]
predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0).repeat(
)
rel_pos_embeddings = rel_pos_embeddings.permute(0, 2, 1, 4, 3)
# [batch_size * ngram * sequence_length * num_heads, num_buckets]
rel_pos_embeddings = rel_pos_embeddings.reshape(-1, self.num_buckets)
# [ngram, batch_size, num_heads * sequence_length, -1]
predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0)
predict_relative_position_buckets = predict_relative_position_buckets.repeat(
self.ngram, 1, self.num_attn_heads, 1
) # [ngram, B, head*T, S]
rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1))
)
# [ngram * batch_size * num_heads * sequence_length, -1]
predict_relative_position_buckets = predict_relative_position_buckets.view(
-1, predict_relative_position_buckets.size(-1)
).long() # [ngram*B*head*T, S]
).long()
predict_relative_pos_embeddings = torch.gather(
rel_pos_embeddings, dim=1, index=predict_relative_position_buckets
).view(
self.ngram, batch_size * self.num_attn_heads, sequence_length, -1
) # [ngram, B*head, T, S]
)
# [batch_size, ngram, num_heads, sequence_length, -1]
predict_relative_pos_embeddings = predict_relative_pos_embeddings.view(
batch_size, self.ngram, self.num_attn_heads, sequence_length, -1
)
return predict_relative_pos_embeddings
@@ -1351,7 +1324,7 @@ class XLMProphetNetEncoder(XLMProphetNetPreTrainedModel):
# prepare attention mask
if attention_mask is not None:
extended_attention_mask = (
1.0 - attention_mask[:, None, :].repeat(self.config.num_encoder_attention_heads, 1, 1)
1.0 - attention_mask[:, None, None, :].repeat(1, self.config.num_encoder_attention_heads, 1, 1)
) * torch.finfo(self.dtype).min
extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype)
else:
@@ -1572,7 +1545,7 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
# prepare encoder attention mask
if encoder_attention_mask is not None:
extended_encoder_attention_mask = (
1.0 - encoder_attention_mask[:, None, :].repeat(self.config.num_decoder_attention_heads, 1, 1)
1.0 - encoder_attention_mask[:, None, None, :].repeat(1, self.config.num_decoder_attention_heads, 1, 1)
) * torch.finfo(self.dtype).min
extended_encoder_attention_mask = extended_encoder_attention_mask.to(inputs_embeds.dtype)
else:
@@ -1740,17 +1713,18 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
device=hidden_states.device,
)
causal_mask = torch.triu(causal_mask, 1)
extended_causal_mask = causal_mask[:seq_length, :seq_length][None, :, :].expand(
(batch_size,) + causal_mask.shape
extended_causal_mask = causal_mask[:seq_length, :seq_length][None, None, :, :].expand(
(batch_size, self.config.num_decoder_attention_heads) + causal_mask.shape
)
# add usual attention mask
if attention_mask is not None:
extended_attention_mask = (1.0 - attention_mask[:, None, :]) * torch.finfo(self.dtype).min
extended_attention_mask = (1.0 - attention_mask[:, None, None, :]) * torch.finfo(self.dtype).min
extended_attention_mask = extended_causal_mask + extended_attention_mask
else:
extended_attention_mask = extended_causal_mask
return extended_attention_mask.repeat(self.config.num_decoder_attention_heads, 1, 1).to(hidden_states.dtype)
return extended_attention_mask.to(hidden_states.dtype)
def prepare_predict_attention_mask(self, hidden_states, attention_mask):
batch_size, seq_length = hidden_states.shape[:2]
@@ -1768,14 +1742,16 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
],
dim=-1,
)
extended_predict_causal_mask = predict_causal_mask[:, None, :, :].expand(
predict_causal_mask.shape[:1] + (batch_size,) + predict_causal_mask.shape[1:]
extended_predict_causal_mask = predict_causal_mask[None, None, :, :, :].expand(
(batch_size, self.config.num_decoder_attention_heads) + predict_causal_mask.shape
)
# add usual attention mask
if attention_mask is not None:
extended_attention_mask = (1.0 - attention_mask[None, :, None, :]) * torch.finfo(self.dtype).min
extended_attention_mask = extended_attention_mask.expand((self.ngram, batch_size, seq_length, seq_length))
extended_attention_mask = (1.0 - attention_mask[:, None, None, None, :]) * torch.finfo(self.dtype).min
extended_attention_mask = extended_attention_mask.expand(
(batch_size, self.config.num_decoder_attention_heads, self.ngram, seq_length, seq_length)
)
# predicted stream attention_mask should always be 0
extended_attention_mask = torch.cat(
[extended_attention_mask, torch.zeros_like(extended_attention_mask)], dim=-1
@@ -1783,9 +1759,7 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
extended_predict_attention_mask = extended_predict_causal_mask + extended_attention_mask
else:
extended_predict_attention_mask = extended_predict_causal_mask
return extended_predict_attention_mask.repeat(1, self.config.num_decoder_attention_heads, 1, 1).to(
hidden_states.dtype
)
return extended_predict_attention_mask.to(hidden_states.dtype)
@add_start_docstrings(
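A quick shape check for the predict-stream einsum used in both modeling files above, again as an illustrative sketch with toy sizes (not code from the commit): queries are (batch, ngram, heads, seq, head_dim), keys are (batch, ngram, heads, 2*seq, head_dim) after the main stream is concatenated in, and the weights come out as (batch, ngram, heads, seq, 2*seq):

import torch

b, ngram, h, t, c = 2, 2, 4, 5, 8
predict_query_states = torch.randn(b, ngram, h, t, c)
predict_key_states = torch.randn(b, ngram, h, 2 * t, c)

predict_attn_weights = torch.einsum("bnhtc,bnhsc->bnhts", predict_query_states, predict_key_states)
print(predict_attn_weights.shape)  # torch.Size([2, 2, 4, 5, 10])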

tests/models/prophetnet/test_modeling_prophetnet.py

@@ -1206,7 +1206,7 @@ class ProphetNetModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 12, 30522))
self.assertEqual(output_predited_logits.shape, expected_shape)
expected_slice = torch.tensor(
[[[-7.6213, -7.9008, -7.9979], [-7.6834, -7.8467, -8.2187], [-7.5326, -7.4762, -8.1914]]]
[[[-7.7729, -8.0343, -8.26001], [-7.74213, -7.8629, -8.6000], [-7.7328, -7.8269, -8.5264]]]
).to(torch_device)
# self.assertTrue(torch.allclose(output_predited_logits[:, :3, :3], expected_slice, atol=1e-4))
assert torch.allclose(output_predited_logits[:, :3, :3], expected_slice, atol=1e-4)
@@ -1306,7 +1306,7 @@ class ProphetNetModelIntegrationTest(unittest.TestCase):
EXPECTED_QUESTIONS = [
"along with paul allen, who founded microsoft?",
"what year was microsoft founded?",
"on what date was microsoft founded?",
"when was microsoft founded?",
]
self.assertListEqual(