Prophetnet batch dimension inversion fix (#21870)
* decoder forward pass is working
* no model has forward pass returning attentions
* decoder ngram changed to not mix batch size
* current basic forward pass returns identical result
* passed test_model attentions
* passed test_encoder_decoder_model_generate
* passed test_headmasking
* removed old block
* removed comments bug/fixme
* removed bug comments
* applied styling
* applied fix-copies
* applied ngram forward comments
* corrected dimension notation
* applied styling and comment fixes
* changed asserts for raise ValueError
* changed question gen test
* updated hidden_states integration test
* applied styling
This commit is contained in:
parent 99ba36e72f
commit 6bf885375a
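For orientation (this note is not part of the commit): the change replaces attention computed on tensors whose batch and head axes were flattened into a single dimension ([batch_size * num_attn_heads, ...], combined with torch.bmm) by tensors that keep the two axes separate ([batch_size, num_attn_heads, ...], combined with torch.einsum). A minimal sketch of why the two layouts carry the same numbers, using made-up toy sizes that are not taken from the model config:

import torch

# Hypothetical toy sizes, chosen only for illustration.
batch_size, num_attn_heads, tgt_len, src_len, head_dim = 2, 4, 5, 5, 8

query = torch.randn(batch_size, num_attn_heads, tgt_len, head_dim)
key = torch.randn(batch_size, num_attn_heads, src_len, head_dim)

# Old layout: batch and head dims flattened together, scores via bmm
# -> [batch_size * num_attn_heads, tgt_len, src_len]
old_weights = torch.bmm(
    query.reshape(batch_size * num_attn_heads, tgt_len, head_dim),
    key.reshape(batch_size * num_attn_heads, src_len, head_dim).transpose(1, 2),
)

# New layout: keep [batch_size, num_attn_heads, ...] and use einsum
# -> [batch_size, num_attn_heads, tgt_len, src_len]
new_weights = torch.einsum("bsij,bsjk->bsik", query, key.transpose(2, 3))

# Both layouts hold the same values, only the batch/head axes differ.
assert torch.allclose(new_weights, old_weights.view(batch_size, num_attn_heads, tgt_len, src_len), atol=1e-6)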
@@ -701,44 +701,27 @@ class ProphetNetAttention(nn.Module):
             past_key_value = (key_states, value_states)
 
         # project states into the correct shape
-        proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim)
+        proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, batch_size).view(*proj_shape)
         key_states = key_states.view(*proj_shape)
         value_states = value_states.view(*proj_shape)
 
-        src_len = key_states.size(1)
-        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-        assert attn_weights.size() == (
-            batch_size * self.num_attn_heads,
-            tgt_len,
-            src_len,
-        ), (
-            f"`attn_weights` should be of size {batch_size * self.num_attn_heads, tgt_len, src_len}, but is of size"
-            f" {attn_weights.shape}"
-        )
+        src_len = key_states.size(2)
+        attn_weights = torch.einsum("bsij,bsjk->bsik", query_states, key_states.transpose(2, 3))
+        expected_shape = (batch_size, self.num_attn_heads, tgt_len, src_len)
+        if attn_weights.size() != expected_shape:
+            raise ValueError(f"Attention weights should have size {expected_shape}, but is {attn_weights.size()}")
 
         # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
         if attention_mask is not None and attention_mask.dim() == 0:
             attention_mask = None
-        assert attention_mask is None or attention_mask.size() == (
-            self.num_attn_heads * batch_size,
-            1,
-            src_len,
-        ), (
-            "`attention_mask` should be `None` or of shape attention_mask.size() =="
-            f" {batch_size * self.num_attn_heads, 1, src_len}, but is {attention_mask.shape}"
-        )
+
+        expected_shape = (batch_size, self.num_attn_heads, 1, src_len)
+        if attention_mask is not None and attention_mask.size() != expected_shape:
+            raise ValueError(f"Attention mask should have size {expected_shape}, but is {attention_mask.size()}")
         if attention_mask is not None:  # don't attend to padding symbols
             attn_weights = attn_weights + attention_mask
 
         if output_attentions:
             # this operation is a bit awkward, but it's required to
             # make sure that attn_weights keeps its gradient.
             # In order to do so, attn_weights have to be reshaped
             # twice and have to be reused in the following
-            attn_weights_reshaped = attn_weights.view(batch_size, self.num_attn_heads, tgt_len, src_len)
-            attn_weights = attn_weights_reshaped.view(batch_size * self.num_attn_heads, tgt_len, src_len)
+            attn_weights_reshaped = attn_weights
         else:
             attn_weights_reshaped = None
@@ -752,7 +735,6 @@ class ProphetNetAttention(nn.Module):
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
                 batch_size, self.num_attn_heads, tgt_len, src_len
             )
-            attn_weights = attn_weights.view(batch_size * self.num_attn_heads, tgt_len, src_len)
 
             # apply head_mask also on attn_weights_reshaped which is used for n-gram attention inside the model
             attn_weights_reshaped = layer_head_mask.view(1, -1, 1, 1) * attn_weights_reshaped
@@ -762,23 +744,12 @@ class ProphetNetAttention(nn.Module):
             p=self.attention_dropout,
             training=self.training,
         )
-
-        attn_output = torch.bmm(attn_probs, value_states)
-        assert attn_output.size() == (
-            batch_size * self.num_attn_heads,
-            tgt_len,
-            self.head_dim,
-        ), (
-            f"`attn_output` should be of shape {batch_size * self.num_attn_heads, tgt_len, self.head_dim}, but is of"
-            f" shape {attn_output.size()}"
-        )
-
-        attn_output = (
-            attn_output.view(batch_size, self.num_attn_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(batch_size, tgt_len, hidden_size)
-        )
-
+        attn_output = torch.einsum("bsij,bsjk->bsik", attn_probs, value_states)
+        expected_shape = (batch_size, self.num_attn_heads, tgt_len, self.head_dim)
+        if attn_output.size() != expected_shape:
+            raise ValueError(f"`attn_output` should have shape {expected_shape}, but is of shape {attn_output.size()}")
+
+        attn_output = attn_output.transpose(1, 2).reshape(batch_size, tgt_len, hidden_size)
         attn_output = self.out_proj(attn_output)
 
         attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
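Illustrative note (not part of the diff): with the 4-D layout, merging the heads back into hidden_size no longer needs the intermediate view through a flattened [batch_size * num_attn_heads, ...] tensor. A small sketch with invented sizes showing the two output paths agree:

import torch

batch_size, num_attn_heads, tgt_len, head_dim = 2, 4, 5, 8  # toy sizes, not from the model config
hidden_size = num_attn_heads * head_dim

attn_output = torch.randn(batch_size, num_attn_heads, tgt_len, head_dim)

# old path: round-trip through the flattened [batch*heads, tgt_len, head_dim] layout
old = (
    attn_output.reshape(batch_size * num_attn_heads, tgt_len, head_dim)
    .view(batch_size, num_attn_heads, tgt_len, head_dim)
    .transpose(1, 2)
    .reshape(batch_size, tgt_len, hidden_size)
)

# new path: the tensor already carries [batch, heads, tgt_len, head_dim], so heads are merged directly
new = attn_output.transpose(1, 2).reshape(batch_size, tgt_len, hidden_size)

assert torch.equal(old, new)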
@@ -856,7 +827,6 @@ class ProphetNetNgramSelfAttention(nn.Module):
         position_ids=None,
     ):
         batch_size, ngram_sequence_length, hidden_size = hidden_states.size()
-
         assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], (
             f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape"
             f" {hidden_states.shape}"
@@ -874,8 +844,7 @@ class ProphetNetNgramSelfAttention(nn.Module):
         query_states = self._shape(query_states, ngram_sequence_length, batch_size)
         key_states = self._shape(key_states, -1, batch_size)
         value_states = self._shape(value_states, -1, batch_size)
-
-        proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim)
+        proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim)
 
         query_states = query_states.view(*proj_shape)
         key_states = key_states.view(*proj_shape)
@@ -883,10 +852,9 @@ class ProphetNetNgramSelfAttention(nn.Module):
 
         # chunk into main stream and predict stream
         hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=1)
-
-        query_states_list = query_states.chunk(1 + self.ngram, dim=1)
-        key_states_list = key_states.chunk(1 + self.ngram, dim=1)
-        value_states_list = value_states.chunk(1 + self.ngram, dim=1)
+        query_states_list = query_states.chunk(1 + self.ngram, dim=2)
+        key_states_list = key_states.chunk(1 + self.ngram, dim=2)
+        value_states_list = value_states.chunk(1 + self.ngram, dim=2)
 
         main_hidden_states, hidden_states_predict_list = hidden_states_list[0], hidden_states_list[1:]
         main_query_states, predict_query_states_list = query_states_list[0], query_states_list[1:]
@@ -895,28 +863,29 @@ class ProphetNetNgramSelfAttention(nn.Module):
 
         # saved states are stored with shape (batch_size, num_attn_heads, seq_len, head_dim)
         if past_key_value is not None:
-            prev_main_key_states = past_key_value[0].view(batch_size * self.num_attn_heads, -1, self.head_dim)
-            main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=1)
-            prev_main_value_states = past_key_value[1].view(batch_size * self.num_attn_heads, -1, self.head_dim)
-            main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=1)
+            prev_main_key_states = past_key_value[0]
+            main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=2)
+            prev_main_value_states = past_key_value[1]
+            main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=2)
 
         # Update cache
-        past_key_value = (
-            main_key_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
-            main_value_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
-        )
+        past_key_value = (main_key_states, main_value_states)
 
         # get seq_length of main stream only
         sequence_length = ngram_sequence_length // (1 + self.ngram)
 
         # MAIN-STREAM
         # main attn weights
-        main_attn_weights = torch.bmm(main_query_states, main_key_states.transpose(1, 2))
+        # [batch_size, number_heads, sequence_length, head_dimesion]
+        # x [batch_size, number_heads, head_dimesion, sequence_length]
+        # -> [batch_size, number_heads, sequence_length, sequence_length]
+        main_attn_weights = torch.einsum("bntc,bncs->bnts", main_query_states, main_key_states.transpose(2, 3))
 
         # retrieve relative position embeddings for each layer -> see paper for more details
         main_relative_pos_embeddings = self.get_main_relative_pos_embeddings(
             main_hidden_states, main_attn_weights, position_ids, main_relative_position_buckets
         )
 
         main_attn_weights = main_attn_weights + main_relative_pos_embeddings
 
         if attention_mask is not None:
@@ -936,55 +905,53 @@ class ProphetNetNgramSelfAttention(nn.Module):
             main_attn_probs = layer_head_mask.view(1, -1, 1, 1) * main_attn_probs.view(
                 batch_size, self.num_attn_heads, -1, sequence_length
             )
-            main_attn_probs = main_attn_probs.view(batch_size * self.num_attn_heads, -1, sequence_length)
 
         main_attn_probs = nn.functional.dropout(main_attn_probs, p=self.attention_dropout, training=self.training)
         # project to attn_output
-        main_attn_output = torch.bmm(main_attn_probs, main_value_states)
-
+        # [batch_size, number_heads, sequence_length, sequence_length]
+        # x [batch_size, number_heads, sequence_length, head_dimesion]
+        # -> [batch_size, number_heads, sequence_length, head_dimesion]
+        main_attn_output = torch.einsum("bntc,bncs->bnts", main_attn_probs, main_value_states)
         # reshape so that num_heads dim is merged into last `head_dim` axis
-        main_attn_output = (
-            main_attn_output.view(batch_size, self.num_attn_heads, sequence_length, self.head_dim)
-            .transpose(1, 2)
-            .reshape(batch_size, 1, sequence_length, hidden_size)
-        )
+        main_attn_output = main_attn_output.transpose(1, 2).reshape(batch_size, 1, sequence_length, hidden_size)
         main_attn_output = self.out_proj(main_attn_output)
 
         # PREDICT-STREAM
-        # [ngram, B*head, T, c]
-        predict_query_states = torch.cat(predict_query_states_list, 0).view(
-            self.ngram, -1, sequence_length, self.head_dim
-        )
-        # [ngram, B*head, 2*T, c]
-        predict_key_states = torch.cat(
-            [torch.cat([main_key_states, key], 1).unsqueeze(0) for key in predict_key_states_list], 0
-        )
+        # [batch_size, ngram, number_heads, sequence_length, head_dimesion]
+        predict_query_states = torch.stack(predict_query_states_list, 1).view(
+            batch_size, self.ngram, self.num_attn_heads, sequence_length, self.head_dim
+        )
 
-        # [ngram, T, B, C]
-        predict_hidden_states = torch.cat(hidden_states_predict_list, 0).view(
-            self.ngram, sequence_length, batch_size, hidden_size
-        )
-        # [ngram, B*head, 2*T, c]
+        # [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion]
+        predict_key_states = torch.stack([torch.cat([main_key_states, key], 2) for key in predict_key_states_list], 1)
+
+        # [batch_size, sequence_length, ngram, hidden_size]
+        predict_hidden_states = torch.stack(hidden_states_predict_list, dim=2)
+
+        # [batch_size, number_heads, ngram, 2*sequence_length, head_dimesion]
         predict_value_states = torch.cat(
-            [torch.cat([main_value_states, v_p], 1).unsqueeze(0) for v_p in predict_value_states_list], 0
+            [torch.cat([main_value_states, v_p], 2).unsqueeze(2) for v_p in predict_value_states_list], 2
         )
-        # [ngram, B*head, T, 2*T]
-        predict_attn_weights = torch.einsum("nbtc,nbsc->nbts", (predict_query_states, predict_key_states))
 
-        # [ngram, B*head, T, S]
+        # [batch_size, ngram, number_heads, sequence_length, head_dimesion]
+        # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion]
+        # -> [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
+        predict_attn_weights = torch.einsum("bnhtc,bnhsc->bnhts", (predict_query_states, predict_key_states))
+
         # retrieve relative position embeddings for each layer -> see paper for more details
+        # [batch_size, ngram, number_heads, sequence_length, predict_relative_pos_embeddings]
         predict_relative_pos_embeddings = self.get_predict_relative_pos_embeddings(
             predict_hidden_states, predict_attn_weights, position_ids, predict_relative_position_buckets
         )
 
-        # [ngram, B*head, T, 2*T]
+        # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
         predict_attn_weights = predict_attn_weights + predict_relative_pos_embeddings
 
         if extended_predict_attention_mask is not None:
-            predict_attn_weights = predict_attn_weights + extended_predict_attention_mask.to(
-                predict_attn_weights.dtype
-            )
+            # Permuting Predict attention mask to [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
+            extended_predict_attention_mask = extended_predict_attention_mask.permute(0, 2, 1, 3, 4)
+            extended_predict_attention_mask = extended_predict_attention_mask.to(predict_attn_weights.dtype)
+            predict_attn_weights = predict_attn_weights + extended_predict_attention_mask
 
         predict_attn_probs = softmax(
             predict_attn_weights,
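Illustrative note (not part of the diff): the predict stream used to be assembled with torch.cat(...).unsqueeze(0), which puts the ngram axis first and flattens batch and heads together; the new code uses torch.stack(..., dim=1) and keeps the batch dimension first. A sketch with invented sizes showing the two layouts hold the same values:

import torch

ngram, batch_size, num_heads, seq_len, head_dim = 2, 3, 4, 5, 8  # toy sizes
# per-gram key tensors in the new 4-D layout [batch, heads, seq, head_dim]
predict_keys = [torch.randn(batch_size, num_heads, seq_len, head_dim) for _ in range(ngram)]

# new layout: stack along dim=1 -> [batch_size, ngram, num_heads, seq_len, head_dim]
stacked = torch.stack(predict_keys, dim=1)

# the old code built an ngram-first tensor instead:
# [ngram, batch_size * num_heads, seq_len, head_dim]
ngram_first = torch.cat(
    [k.reshape(batch_size * num_heads, seq_len, head_dim).unsqueeze(0) for k in predict_keys], 0
)

# same numbers, different axis order
assert torch.equal(
    stacked.permute(1, 0, 2, 3, 4).reshape(ngram, batch_size * num_heads, seq_len, head_dim), ngram_first
)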
@@ -997,37 +964,30 @@ class ProphetNetNgramSelfAttention(nn.Module):
                 f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
                 f" {layer_head_mask.size()}"
             )
-            predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs.view(
-                self.ngram, batch_size, self.num_attn_heads, sequence_length, 2 * sequence_length
-            )
-            predict_attn_probs = predict_attn_probs.view(
-                self.ngram, batch_size * self.num_attn_heads, sequence_length, 2 * sequence_length
-            )
+            predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs
 
         predict_attn_probs = nn.functional.dropout(
             predict_attn_probs, p=self.attention_dropout, training=self.training
         )
         # project to attention output
-        # [ngram, B*head, T, c]
-        predict_attn_output = torch.einsum("nbts,nbsc->nbtc", (predict_attn_probs, predict_value_states))
+        # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
+        # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion]
+        # -> [batch_size, ngram, number_heads, sequence_length, head_dimesion]
+        predict_attn_output = torch.einsum(
+            "bnhts,bnhsc->bnhtc", (predict_attn_probs, predict_value_states.transpose(1, 2))
+        )
 
         # reshape so that num_heads dim is merged into last `head_dim` axis
-        # [ngram, B, T, C]
-        predict_attn_output = (
-            predict_attn_output.view(self.ngram, batch_size, self.num_attn_heads, sequence_length, self.head_dim)
-            .permute(1, 0, 3, 2, 4)
-            .reshape(batch_size, self.ngram, sequence_length, hidden_size)
-        )
+        # [batch_size, ngram, number_heads, sequence_length, head_dimesion] -> [batch_size, ngram, sequence_length, hidden_size]
+        predict_attn_output = predict_attn_output.transpose(2, 3)
+        predict_attn_output = predict_attn_output.reshape(batch_size, self.ngram, sequence_length, hidden_size)
         predict_attn_output = self.out_proj(predict_attn_output)
 
         # concat to single attn output
-        # [B, 1+ngram*T, C]
+        # [batch_size, (1+ngram)*sequence_length, hidden_size]
         attn_output = torch.cat([main_attn_output, predict_attn_output], 1).view(batch_size, -1, hidden_size)
         # reshape into better form for `config.output_attentions`
         main_attn_probs = main_attn_probs.view(batch_size, self.num_attn_heads, sequence_length, -1)
-        predict_attn_probs = predict_attn_probs.view(
-            self.ngram, batch_size, self.num_attn_heads, sequence_length, -1
-        ).transpose(0, 1)
 
         attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
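Illustrative note (not part of the diff): with the probabilities kept as [batch_size, ngram, num_heads, seq_len, 2*seq_len], the per-layer head mask applies through a single broadcast instead of reshaping through the flattened batch axis. A toy sketch with invented sizes:

import torch

batch_size, ngram, num_heads, seq_len = 2, 2, 4, 5  # toy sizes
predict_attn_probs = torch.rand(batch_size, ngram, num_heads, seq_len, 2 * seq_len)

# zero out the attention of selected heads; broadcasting the head mask as
# [1, 1, num_heads, 1, 1] applies it to every batch entry, gram and position
layer_head_mask = torch.tensor([1.0, 0.0, 1.0, 1.0])
masked_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs

assert torch.all(masked_probs[:, :, 1] == 0)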
@@ -1036,8 +996,11 @@ class ProphetNetNgramSelfAttention(nn.Module):
     def get_main_relative_pos_embeddings(
         self, hidden_states, attn_weights, position_ids, main_relative_position_buckets
     ):
-        # input hidden_states [B,T,C], input attn_weights [T*head,T,S], input position_ids [B,T] or [1,1]
-
+        # input hidden_states [batch_size, sequence_length, hidden_size]
+        # input attn_weights [batch_size, num_heads, sequence_length, sequence_length]
+        # input position_ids [batch_size, sequence_length] or [1,1]
+        batch_size, num_attn_heads, tgt_len, src_len = attn_weights.shape
+        attn_weights = attn_weights.view(batch_size, num_attn_heads, tgt_len, src_len)
         if main_relative_position_buckets is None:
             batch_size, sequence_length = hidden_states.shape[:2]
             relative_positions = (
@@ -1047,39 +1010,42 @@ class ProphetNetNgramSelfAttention(nn.Module):
                 .repeat(batch_size, sequence_length, 1)
                 .to(position_ids.device)
             )
-            relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(
-                batch_size, sequence_length, 1
-            )  # [B, T, s]
+            # [batch_size, sequence_length, sequence_length+1]
+            relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1)
             main_relative_position_buckets = compute_relative_buckets(
                 self.num_buckets, self.relative_max_distance, relative_positions, False
             )
 
-        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)  # [B,T,Buckets*head]
+        # [batch_size, sequence_length, num_buckets * num_heads]
+        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)
         rel_pos_embeddings = rel_pos_embeddings.view(
             rel_pos_embeddings.shape[:2] + (self.num_buckets, self.num_attn_heads)
-        ).permute(
-            0, 3, 1, 2
-        )  # [B,T,Buckets,head]
-        rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:2] + (-1,))  # [B*head,T,Buckets]
+        )
+        rel_pos_embeddings = rel_pos_embeddings.permute(0, 3, 1, 2)
+        # [batch_size, num_heads, sequence_length, num_buckets]
+        rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:3] + (-1,))
 
-        main_relative_position_buckets = (
-            main_relative_position_buckets.repeat(1, self.num_attn_heads, 1)
-            .view(-1, main_relative_position_buckets.shape[-1])
-            .long()
-        )  # [B*head*T, T]
-        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1))  # [B*head*T,Buckets]
-
-        main_relative_pos_embeddings = torch.gather(
-            rel_pos_embeddings, dim=1, index=main_relative_position_buckets
-        ).view(attn_weights.shape[:2] + (-1,))
-
+        main_relative_position_buckets = main_relative_position_buckets.repeat(1, self.num_attn_heads, 1)
+        # [batch_size * num_heads * sequence_length, sequence_length]
+        main_relative_position_buckets = main_relative_position_buckets.view(
+            -1, main_relative_position_buckets.shape[-1]
+        )
+        main_relative_position_buckets = main_relative_position_buckets.long()
+        # [batch_size * num_heads * sequence_length, sequence_length]
+        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1))
+
+        main_relative_pos_embeddings = torch.gather(rel_pos_embeddings, dim=1, index=main_relative_position_buckets)
+        main_relative_pos_embeddings = main_relative_pos_embeddings.view(batch_size, num_attn_heads, tgt_len, -1)
         return main_relative_pos_embeddings
 
     def get_predict_relative_pos_embeddings(
         self, hidden_states, attn_weights, position_ids, predict_relative_position_buckets
     ):
-        # input hidden_states [ngram, T,B,C], input attn_weights [ngram, B*head,T,S], input position_ids [B,T] or [1,1], input predict_relative_position_buckets [B,T, 2*T] or None
-        sequence_length, batch_size = hidden_states.shape[1:3]
+        # input hidden_states [batch_size, sequence_length, ngram, hidden_size]
+        # input attn_weights [batch_size, ngram, num_heads, sequence_length, 2*sequence_length]
+        # input position_ids [batch_size, sequence_length] or [1,1]
+        # input predict_relative_position_buckets [batch_size, sequence_length, 2*sequence_length] or None
+        batch_size, sequence_length = hidden_states.shape[0:2]
 
         if predict_relative_position_buckets is None:
             key_sequence_length = attn_weights.shape[-1]
@@ -1099,28 +1065,35 @@ class ProphetNetNgramSelfAttention(nn.Module):
                 self.num_buckets, self.relative_max_distance, relative_positions, False
             )
 
-        hidden_states = hidden_states.transpose(1, 2)  # [ngram, B, T, C]
-        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states).view(
+        # [batch_size, ngram, sequence_length, hidden_size]
+        hidden_states = hidden_states.transpose(1, 2)
+        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)
+
+        # [batch_size, ngram, sequence_length, num_buckets, num_heads]
+        rel_pos_embeddings = rel_pos_embeddings.view(
             hidden_states.shape[:-1] + (self.num_buckets, self.num_attn_heads)
-        )  # [ngram, B, T, bucket, head]
-        rel_pos_embeddings = rel_pos_embeddings.permute(0, 1, 4, 2, 3).reshape(
-            self.ngram * batch_size * self.num_attn_heads, sequence_length, -1
-        )  # [ngram*B*head, T, bucket]
-
-        predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0).repeat(
+        )
+        rel_pos_embeddings = rel_pos_embeddings.permute(0, 2, 1, 4, 3)
+        # [batch_size * ngram * sequence_length * num_heads, num_buckets]
+        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, self.num_buckets)
+        # [ngram, batch_size, num_heads * sequence_length, -1]
+        predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0)
+        predict_relative_position_buckets = predict_relative_position_buckets.repeat(
             self.ngram, 1, self.num_attn_heads, 1
-        )  # [ngram, B, head*T, S]
-
-        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1))
+        )
+        # [ngram * batch_size * num_heads * sequence_length, -1]
         predict_relative_position_buckets = predict_relative_position_buckets.view(
             -1, predict_relative_position_buckets.size(-1)
-        ).long()  # [ngram*B*head*T, S]
+        ).long()
 
         predict_relative_pos_embeddings = torch.gather(
             rel_pos_embeddings, dim=1, index=predict_relative_position_buckets
-        ).view(
-            self.ngram, batch_size * self.num_attn_heads, sequence_length, -1
-        )  # [ngram, B*head, T, S]
+        )
+
+        # [batch_size, gram, num_heads, sequence_length, -1]
+        predict_relative_pos_embeddings = predict_relative_pos_embeddings.view(
+            batch_size, self.ngram, self.num_attn_heads, sequence_length, -1
+        )
 
         return predict_relative_pos_embeddings
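Illustrative note (not part of the diff): the relative position embeddings are still selected with torch.gather after flattening everything except the bucket axis; only the shape bookkeeping moved to the batch-first layout. A self-contained sketch of the gather pattern with made-up sizes:

import torch

batch_size, num_heads, seq_len, num_buckets = 2, 4, 5, 32  # toy sizes

# per-position score for every bucket: [batch, heads, seq, num_buckets]
rel_pos_embeddings = torch.randn(batch_size, num_heads, seq_len, num_buckets)
# bucket index of every (query, key) pair: [batch, seq, seq]
buckets = torch.randint(0, num_buckets, (batch_size, seq_len, seq_len))

# flatten everything but the bucket axis, as the modeling code does,
# then gather the score of the bucket each key position falls into
flat_embeddings = rel_pos_embeddings.reshape(-1, num_buckets)                       # [batch*heads*seq, buckets]
flat_buckets = buckets.repeat(1, num_heads, 1).view(-1, buckets.shape[-1]).long()   # [batch*heads*seq, seq]
gathered = torch.gather(flat_embeddings, dim=1, index=flat_buckets)
gathered = gathered.view(batch_size, num_heads, seq_len, -1)                        # [batch, heads, seq, seq]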
@@ -1331,7 +1304,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
         # prepare attention mask
         if attention_mask is not None:
             extended_attention_mask = (
-                1.0 - attention_mask[:, None, :].repeat(self.config.num_encoder_attention_heads, 1, 1)
+                1.0 - attention_mask[:, None, None, :].repeat(1, self.config.num_encoder_attention_heads, 1, 1)
             ) * torch.finfo(self.dtype).min
             extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype)
         else:
@@ -1549,7 +1522,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
         # prepare encoder attention mask
         if encoder_attention_mask is not None:
             extended_encoder_attention_mask = (
-                1.0 - encoder_attention_mask[:, None, :].repeat(self.config.num_decoder_attention_heads, 1, 1)
+                1.0 - encoder_attention_mask[:, None, None, :].repeat(1, self.config.num_decoder_attention_heads, 1, 1)
             ) * torch.finfo(self.dtype).min
             extended_encoder_attention_mask = extended_encoder_attention_mask.to(inputs_embeds.dtype)
         else:
@@ -1717,17 +1690,18 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
             device=hidden_states.device,
         )
         causal_mask = torch.triu(causal_mask, 1)
-        extended_causal_mask = causal_mask[:seq_length, :seq_length][None, :, :].expand(
-            (batch_size,) + causal_mask.shape
+
+        extended_causal_mask = causal_mask[:seq_length, :seq_length][None, None, :, :].expand(
+            (batch_size, self.config.num_decoder_attention_heads) + causal_mask.shape
         )
 
         # add usual attention mask
         if attention_mask is not None:
-            extended_attention_mask = (1.0 - attention_mask[:, None, :]) * torch.finfo(self.dtype).min
+            extended_attention_mask = (1.0 - attention_mask[:, None, None, :]) * torch.finfo(self.dtype).min
             extended_attention_mask = extended_causal_mask + extended_attention_mask
         else:
             extended_attention_mask = extended_causal_mask
-        return extended_attention_mask.repeat(self.config.num_decoder_attention_heads, 1, 1).to(hidden_states.dtype)
+        return extended_attention_mask.to(hidden_states.dtype)
 
     def prepare_predict_attention_mask(self, hidden_states, attention_mask):
         batch_size, seq_length = hidden_states.shape[:2]
@@ -1745,14 +1719,16 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
             ],
             dim=-1,
         )
-        extended_predict_causal_mask = predict_causal_mask[:, None, :, :].expand(
-            predict_causal_mask.shape[:1] + (batch_size,) + predict_causal_mask.shape[1:]
+        extended_predict_causal_mask = predict_causal_mask[None, None, :, :, :].expand(
+            (batch_size, self.config.num_decoder_attention_heads) + predict_causal_mask.shape
         )
 
         # add usual attention mask
         if attention_mask is not None:
-            extended_attention_mask = (1.0 - attention_mask[None, :, None, :]) * torch.finfo(self.dtype).min
-            extended_attention_mask = extended_attention_mask.expand((self.ngram, batch_size, seq_length, seq_length))
+            extended_attention_mask = (1.0 - attention_mask[:, None, None, None, :]) * torch.finfo(self.dtype).min
+            extended_attention_mask = extended_attention_mask.expand(
+                (batch_size, self.config.num_decoder_attention_heads, self.ngram, seq_length, seq_length)
+            )
             # predicted stream attention_mask should always be 0
             extended_attention_mask = torch.cat(
                 [extended_attention_mask, torch.zeros_like(extended_attention_mask)], dim=-1
@@ -1760,9 +1736,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
             extended_predict_attention_mask = extended_predict_causal_mask + extended_attention_mask
         else:
             extended_predict_attention_mask = extended_predict_causal_mask
-        return extended_predict_attention_mask.repeat(1, self.config.num_decoder_attention_heads, 1, 1).to(
-            hidden_states.dtype
-        )
+        return extended_predict_attention_mask.to(hidden_states.dtype)
 
 
 @add_start_docstrings(
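Illustrative note (not part of the diff): the encoder/decoder padding masks are now built as 4-D tensors so they broadcast over the query length, instead of being repeated along a flattened batch*heads axis. A toy sketch of the new broadcasting, with invented sizes:

import torch

batch_size, num_attn_heads, tgt_len, src_len = 2, 4, 5, 7  # made-up sizes
attn_weights = torch.zeros(batch_size, num_attn_heads, tgt_len, src_len)

# 1 marks a real token, 0 marks padding (the tail positions are padded here)
padding = torch.tensor([[1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0, 0]], dtype=torch.float)

# expand to [batch_size, num_attn_heads, 1, src_len]; the singleton dim
# broadcasts over tgt_len when the mask is added to the weights
mask = (1.0 - padding[:, None, None, :]).repeat(1, num_attn_heads, 1, 1) * torch.finfo(torch.float).min

masked = attn_weights + mask
assert masked.shape == (batch_size, num_attn_heads, tgt_len, src_len)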
@@ -716,44 +716,27 @@ class XLMProphetNetAttention(nn.Module):
             past_key_value = (key_states, value_states)
 
         # project states into the correct shape
-        proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim)
+        proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, batch_size).view(*proj_shape)
         key_states = key_states.view(*proj_shape)
         value_states = value_states.view(*proj_shape)
 
-        src_len = key_states.size(1)
-        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-        assert attn_weights.size() == (
-            batch_size * self.num_attn_heads,
-            tgt_len,
-            src_len,
-        ), (
-            f"`attn_weights` should be of size {batch_size * self.num_attn_heads, tgt_len, src_len}, but is of size"
-            f" {attn_weights.shape}"
-        )
+        src_len = key_states.size(2)
+        attn_weights = torch.einsum("bsij,bsjk->bsik", query_states, key_states.transpose(2, 3))
+        expected_shape = (batch_size, self.num_attn_heads, tgt_len, src_len)
+        if attn_weights.size() != expected_shape:
+            raise ValueError(f"Attention weights should have size {expected_shape}, but is {attn_weights.size()}")
 
         # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
         if attention_mask is not None and attention_mask.dim() == 0:
             attention_mask = None
-        assert attention_mask is None or attention_mask.size() == (
-            self.num_attn_heads * batch_size,
-            1,
-            src_len,
-        ), (
-            "`attention_mask` should be `None` or of shape attention_mask.size() =="
-            f" {batch_size * self.num_attn_heads, 1, src_len}, but is {attention_mask.shape}"
-        )
+
+        expected_shape = (batch_size, self.num_attn_heads, 1, src_len)
+        if attention_mask is not None and attention_mask.size() != expected_shape:
+            raise ValueError(f"Attention mask should have size {expected_shape}, but is {attention_mask.size()}")
         if attention_mask is not None:  # don't attend to padding symbols
             attn_weights = attn_weights + attention_mask
 
         if output_attentions:
             # this operation is a bit awkward, but it's required to
             # make sure that attn_weights keeps its gradient.
             # In order to do so, attn_weights have to be reshaped
             # twice and have to be reused in the following
-            attn_weights_reshaped = attn_weights.view(batch_size, self.num_attn_heads, tgt_len, src_len)
-            attn_weights = attn_weights_reshaped.view(batch_size * self.num_attn_heads, tgt_len, src_len)
+            attn_weights_reshaped = attn_weights
         else:
             attn_weights_reshaped = None
@@ -767,7 +750,6 @@ class XLMProphetNetAttention(nn.Module):
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
                 batch_size, self.num_attn_heads, tgt_len, src_len
             )
-            attn_weights = attn_weights.view(batch_size * self.num_attn_heads, tgt_len, src_len)
 
             # apply head_mask also on attn_weights_reshaped which is used for n-gram attention inside the model
            attn_weights_reshaped = layer_head_mask.view(1, -1, 1, 1) * attn_weights_reshaped
@@ -777,23 +759,12 @@ class XLMProphetNetAttention(nn.Module):
             p=self.attention_dropout,
             training=self.training,
         )
-
-        attn_output = torch.bmm(attn_probs, value_states)
-        assert attn_output.size() == (
-            batch_size * self.num_attn_heads,
-            tgt_len,
-            self.head_dim,
-        ), (
-            f"`attn_output` should be of shape {batch_size * self.num_attn_heads, tgt_len, self.head_dim}, but is of"
-            f" shape {attn_output.size()}"
-        )
-
-        attn_output = (
-            attn_output.view(batch_size, self.num_attn_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(batch_size, tgt_len, hidden_size)
-        )
-
+        attn_output = torch.einsum("bsij,bsjk->bsik", attn_probs, value_states)
+        expected_shape = (batch_size, self.num_attn_heads, tgt_len, self.head_dim)
+        if attn_output.size() != expected_shape:
+            raise ValueError(f"`attn_output` should have shape {expected_shape}, but is of shape {attn_output.size()}")
+
+        attn_output = attn_output.transpose(1, 2).reshape(batch_size, tgt_len, hidden_size)
         attn_output = self.out_proj(attn_output)
 
         attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
@@ -873,7 +844,6 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
         position_ids=None,
     ):
         batch_size, ngram_sequence_length, hidden_size = hidden_states.size()
-
         assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], (
             f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape"
             f" {hidden_states.shape}"
@@ -891,8 +861,7 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
         query_states = self._shape(query_states, ngram_sequence_length, batch_size)
         key_states = self._shape(key_states, -1, batch_size)
         value_states = self._shape(value_states, -1, batch_size)
-
-        proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim)
+        proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim)
 
         query_states = query_states.view(*proj_shape)
         key_states = key_states.view(*proj_shape)
@@ -900,10 +869,9 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
 
         # chunk into main stream and predict stream
         hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=1)
-
-        query_states_list = query_states.chunk(1 + self.ngram, dim=1)
-        key_states_list = key_states.chunk(1 + self.ngram, dim=1)
-        value_states_list = value_states.chunk(1 + self.ngram, dim=1)
+        query_states_list = query_states.chunk(1 + self.ngram, dim=2)
+        key_states_list = key_states.chunk(1 + self.ngram, dim=2)
+        value_states_list = value_states.chunk(1 + self.ngram, dim=2)
 
         main_hidden_states, hidden_states_predict_list = hidden_states_list[0], hidden_states_list[1:]
         main_query_states, predict_query_states_list = query_states_list[0], query_states_list[1:]
@@ -912,28 +880,29 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
 
         # saved states are stored with shape (batch_size, num_attn_heads, seq_len, head_dim)
         if past_key_value is not None:
-            prev_main_key_states = past_key_value[0].view(batch_size * self.num_attn_heads, -1, self.head_dim)
-            main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=1)
-            prev_main_value_states = past_key_value[1].view(batch_size * self.num_attn_heads, -1, self.head_dim)
-            main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=1)
+            prev_main_key_states = past_key_value[0]
+            main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=2)
+            prev_main_value_states = past_key_value[1]
+            main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=2)
 
         # Update cache
-        past_key_value = (
-            main_key_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
-            main_value_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
-        )
+        past_key_value = (main_key_states, main_value_states)
 
         # get seq_length of main stream only
         sequence_length = ngram_sequence_length // (1 + self.ngram)
 
         # MAIN-STREAM
         # main attn weights
-        main_attn_weights = torch.bmm(main_query_states, main_key_states.transpose(1, 2))
+        # [batch_size, number_heads, sequence_length, head_dimesion]
+        # x [batch_size, number_heads, head_dimesion, sequence_length]
+        # -> [batch_size, number_heads, sequence_length, sequence_length]
+        main_attn_weights = torch.einsum("bntc,bncs->bnts", main_query_states, main_key_states.transpose(2, 3))
 
         # retrieve relative position embeddings for each layer -> see paper for more details
         main_relative_pos_embeddings = self.get_main_relative_pos_embeddings(
             main_hidden_states, main_attn_weights, position_ids, main_relative_position_buckets
         )
 
         main_attn_weights = main_attn_weights + main_relative_pos_embeddings
 
         if attention_mask is not None:
@@ -953,55 +922,53 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
             main_attn_probs = layer_head_mask.view(1, -1, 1, 1) * main_attn_probs.view(
                 batch_size, self.num_attn_heads, -1, sequence_length
             )
-            main_attn_probs = main_attn_probs.view(batch_size * self.num_attn_heads, -1, sequence_length)
 
         main_attn_probs = nn.functional.dropout(main_attn_probs, p=self.attention_dropout, training=self.training)
         # project to attn_output
-        main_attn_output = torch.bmm(main_attn_probs, main_value_states)
-
+        # [batch_size, number_heads, sequence_length, sequence_length]
+        # x [batch_size, number_heads, sequence_length, head_dimesion]
+        # -> [batch_size, number_heads, sequence_length, head_dimesion]
+        main_attn_output = torch.einsum("bntc,bncs->bnts", main_attn_probs, main_value_states)
        # reshape so that num_heads dim is merged into last `head_dim` axis
-        main_attn_output = (
-            main_attn_output.view(batch_size, self.num_attn_heads, sequence_length, self.head_dim)
-            .transpose(1, 2)
-            .reshape(batch_size, 1, sequence_length, hidden_size)
-        )
+        main_attn_output = main_attn_output.transpose(1, 2).reshape(batch_size, 1, sequence_length, hidden_size)
         main_attn_output = self.out_proj(main_attn_output)
 
         # PREDICT-STREAM
-        # [ngram, B*head, T, c]
-        predict_query_states = torch.cat(predict_query_states_list, 0).view(
-            self.ngram, -1, sequence_length, self.head_dim
-        )
-        # [ngram, B*head, 2*T, c]
-        predict_key_states = torch.cat(
-            [torch.cat([main_key_states, key], 1).unsqueeze(0) for key in predict_key_states_list], 0
-        )
+        # [batch_size, ngram, number_heads, sequence_length, head_dimesion]
+        predict_query_states = torch.stack(predict_query_states_list, 1).view(
+            batch_size, self.ngram, self.num_attn_heads, sequence_length, self.head_dim
+        )
 
-        # [ngram, T, B, C]
-        predict_hidden_states = torch.cat(hidden_states_predict_list, 0).view(
-            self.ngram, sequence_length, batch_size, hidden_size
-        )
-        # [ngram, B*head, 2*T, c]
+        # [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion]
+        predict_key_states = torch.stack([torch.cat([main_key_states, key], 2) for key in predict_key_states_list], 1)
+
+        # [batch_size, sequence_length, ngram, hidden_size]
+        predict_hidden_states = torch.stack(hidden_states_predict_list, dim=2)
+
+        # [batch_size, number_heads, ngram, 2*sequence_length, head_dimesion]
         predict_value_states = torch.cat(
-            [torch.cat([main_value_states, v_p], 1).unsqueeze(0) for v_p in predict_value_states_list], 0
+            [torch.cat([main_value_states, v_p], 2).unsqueeze(2) for v_p in predict_value_states_list], 2
         )
-        # [ngram, B*head, T, 2*T]
-        predict_attn_weights = torch.einsum("nbtc,nbsc->nbts", (predict_query_states, predict_key_states))
 
-        # [ngram, B*head, T, S]
+        # [batch_size, ngram, number_heads, sequence_length, head_dimesion]
+        # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion]
+        # -> [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
+        predict_attn_weights = torch.einsum("bnhtc,bnhsc->bnhts", (predict_query_states, predict_key_states))
+
         # retrieve relative position embeddings for each layer -> see paper for more details
+        # [batch_size, ngram, number_heads, sequence_length, predict_relative_pos_embeddings]
         predict_relative_pos_embeddings = self.get_predict_relative_pos_embeddings(
             predict_hidden_states, predict_attn_weights, position_ids, predict_relative_position_buckets
         )
 
-        # [ngram, B*head, T, 2*T]
+        # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
         predict_attn_weights = predict_attn_weights + predict_relative_pos_embeddings
 
         if extended_predict_attention_mask is not None:
-            predict_attn_weights = predict_attn_weights + extended_predict_attention_mask.to(
-                predict_attn_weights.dtype
-            )
+            # Permuting Predict attention mask to [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
+            extended_predict_attention_mask = extended_predict_attention_mask.permute(0, 2, 1, 3, 4)
+            extended_predict_attention_mask = extended_predict_attention_mask.to(predict_attn_weights.dtype)
+            predict_attn_weights = predict_attn_weights + extended_predict_attention_mask
 
         predict_attn_probs = softmax(
             predict_attn_weights,
@@ -1014,37 +981,30 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
                 f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
                 f" {layer_head_mask.size()}"
             )
-            predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs.view(
-                self.ngram, batch_size, self.num_attn_heads, sequence_length, 2 * sequence_length
-            )
-            predict_attn_probs = predict_attn_probs.view(
-                self.ngram, batch_size * self.num_attn_heads, sequence_length, 2 * sequence_length
-            )
+            predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs
 
         predict_attn_probs = nn.functional.dropout(
             predict_attn_probs, p=self.attention_dropout, training=self.training
         )
         # project to attention output
-        # [ngram, B*head, T, c]
-        predict_attn_output = torch.einsum("nbts,nbsc->nbtc", (predict_attn_probs, predict_value_states))
+        # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
+        # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion]
+        # -> [batch_size, ngram, number_heads, sequence_length, head_dimesion]
+        predict_attn_output = torch.einsum(
+            "bnhts,bnhsc->bnhtc", (predict_attn_probs, predict_value_states.transpose(1, 2))
+        )
 
         # reshape so that num_heads dim is merged into last `head_dim` axis
-        # [ngram, B, T, C]
-        predict_attn_output = (
-            predict_attn_output.view(self.ngram, batch_size, self.num_attn_heads, sequence_length, self.head_dim)
-            .permute(1, 0, 3, 2, 4)
-            .reshape(batch_size, self.ngram, sequence_length, hidden_size)
-        )
+        # [batch_size, ngram, number_heads, sequence_length, head_dimesion] -> [batch_size, ngram, sequence_length, hidden_size]
+        predict_attn_output = predict_attn_output.transpose(2, 3)
+        predict_attn_output = predict_attn_output.reshape(batch_size, self.ngram, sequence_length, hidden_size)
         predict_attn_output = self.out_proj(predict_attn_output)
 
         # concat to single attn output
-        # [B, 1+ngram*T, C]
+        # [batch_size, (1+ngram)*sequence_length, hidden_size]
         attn_output = torch.cat([main_attn_output, predict_attn_output], 1).view(batch_size, -1, hidden_size)
         # reshape into better form for `config.output_attentions`
         main_attn_probs = main_attn_probs.view(batch_size, self.num_attn_heads, sequence_length, -1)
-        predict_attn_probs = predict_attn_probs.view(
-            self.ngram, batch_size, self.num_attn_heads, sequence_length, -1
-        ).transpose(0, 1)
 
         attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
@@ -1053,8 +1013,11 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
     def get_main_relative_pos_embeddings(
         self, hidden_states, attn_weights, position_ids, main_relative_position_buckets
     ):
-        # input hidden_states [B,T,C], input attn_weights [T*head,T,S], input position_ids [B,T] or [1,1]
-
+        # input hidden_states [batch_size, sequence_length, hidden_size]
+        # input attn_weights [batch_size, num_heads, sequence_length, sequence_length]
+        # input position_ids [batch_size, sequence_length] or [1,1]
+        batch_size, num_attn_heads, tgt_len, src_len = attn_weights.shape
+        attn_weights = attn_weights.view(batch_size, num_attn_heads, tgt_len, src_len)
         if main_relative_position_buckets is None:
             batch_size, sequence_length = hidden_states.shape[:2]
             relative_positions = (
@@ -1064,39 +1027,42 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
                 .repeat(batch_size, sequence_length, 1)
                 .to(position_ids.device)
             )
-            relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(
-                batch_size, sequence_length, 1
-            )  # [B, T, s]
+            # [batch_size, sequence_length, sequence_length+1]
+            relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1)
             main_relative_position_buckets = compute_relative_buckets(
                 self.num_buckets, self.relative_max_distance, relative_positions, False
             )
 
-        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)  # [B,T,Buckets*head]
+        # [batch_size, sequence_length, num_buckets * num_heads]
+        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)
         rel_pos_embeddings = rel_pos_embeddings.view(
             rel_pos_embeddings.shape[:2] + (self.num_buckets, self.num_attn_heads)
-        ).permute(
-            0, 3, 1, 2
-        )  # [B,T,Buckets,head]
-        rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:2] + (-1,))  # [B*head,T,Buckets]
+        )
+        rel_pos_embeddings = rel_pos_embeddings.permute(0, 3, 1, 2)
+        # [batch_size, num_heads, sequence_length, num_buckets]
+        rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:3] + (-1,))
 
-        main_relative_position_buckets = (
-            main_relative_position_buckets.repeat(1, self.num_attn_heads, 1)
-            .view(-1, main_relative_position_buckets.shape[-1])
-            .long()
-        )  # [B*head*T, T]
-        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1))  # [B*head*T,Buckets]
-
-        main_relative_pos_embeddings = torch.gather(
-            rel_pos_embeddings, dim=1, index=main_relative_position_buckets
-        ).view(attn_weights.shape[:2] + (-1,))
-
+        main_relative_position_buckets = main_relative_position_buckets.repeat(1, self.num_attn_heads, 1)
+        # [batch_size * num_heads * sequence_length, sequence_length]
+        main_relative_position_buckets = main_relative_position_buckets.view(
+            -1, main_relative_position_buckets.shape[-1]
+        )
+        main_relative_position_buckets = main_relative_position_buckets.long()
+        # [batch_size * num_heads * sequence_length, sequence_length]
+        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1))
+
+        main_relative_pos_embeddings = torch.gather(rel_pos_embeddings, dim=1, index=main_relative_position_buckets)
+        main_relative_pos_embeddings = main_relative_pos_embeddings.view(batch_size, num_attn_heads, tgt_len, -1)
        return main_relative_pos_embeddings
 
     def get_predict_relative_pos_embeddings(
         self, hidden_states, attn_weights, position_ids, predict_relative_position_buckets
     ):
-        # input hidden_states [ngram, T,B,C], input attn_weights [ngram, B*head,T,S], input position_ids [B,T] or [1,1], input predict_relative_position_buckets [B,T, 2*T] or None
-        sequence_length, batch_size = hidden_states.shape[1:3]
+        # input hidden_states [batch_size, sequence_length, ngram, hidden_size]
+        # input attn_weights [batch_size, ngram, num_heads, sequence_length, 2*sequence_length]
+        # input position_ids [batch_size, sequence_length] or [1,1]
+        # input predict_relative_position_buckets [batch_size, sequence_length, 2*sequence_length] or None
+        batch_size, sequence_length = hidden_states.shape[0:2]
 
         if predict_relative_position_buckets is None:
             key_sequence_length = attn_weights.shape[-1]
@@ -1116,28 +1082,35 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
                 self.num_buckets, self.relative_max_distance, relative_positions, False
             )
 
-        hidden_states = hidden_states.transpose(1, 2)  # [ngram, B, T, C]
-        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states).view(
+        # [batch_size, ngram, sequence_length, hidden_size]
+        hidden_states = hidden_states.transpose(1, 2)
+        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)
+
+        # [batch_size, ngram, sequence_length, num_buckets, num_heads]
+        rel_pos_embeddings = rel_pos_embeddings.view(
             hidden_states.shape[:-1] + (self.num_buckets, self.num_attn_heads)
-        )  # [ngram, B, T, bucket, head]
-        rel_pos_embeddings = rel_pos_embeddings.permute(0, 1, 4, 2, 3).reshape(
-            self.ngram * batch_size * self.num_attn_heads, sequence_length, -1
-        )  # [ngram*B*head, T, bucket]
-
-        predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0).repeat(
+        )
+        rel_pos_embeddings = rel_pos_embeddings.permute(0, 2, 1, 4, 3)
+        # [batch_size * ngram * sequence_length * num_heads, num_buckets]
+        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, self.num_buckets)
+        # [ngram, batch_size, num_heads * sequence_length, -1]
+        predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0)
+        predict_relative_position_buckets = predict_relative_position_buckets.repeat(
             self.ngram, 1, self.num_attn_heads, 1
-        )  # [ngram, B, head*T, S]
-
-        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1))
+        )
+        # [ngram * batch_size * num_heads * sequence_length, -1]
         predict_relative_position_buckets = predict_relative_position_buckets.view(
             -1, predict_relative_position_buckets.size(-1)
-        ).long()  # [ngram*B*head*T, S]
+        ).long()
 
         predict_relative_pos_embeddings = torch.gather(
             rel_pos_embeddings, dim=1, index=predict_relative_position_buckets
-        ).view(
-            self.ngram, batch_size * self.num_attn_heads, sequence_length, -1
-        )  # [ngram, B*head, T, S]
+        )
+
+        # [batch_size, gram, num_heads, sequence_length, -1]
+        predict_relative_pos_embeddings = predict_relative_pos_embeddings.view(
+            batch_size, self.ngram, self.num_attn_heads, sequence_length, -1
+        )
 
         return predict_relative_pos_embeddings
@@ -1351,7 +1324,7 @@ class XLMProphetNetEncoder(XLMProphetNetPreTrainedModel):
         # prepare attention mask
         if attention_mask is not None:
             extended_attention_mask = (
-                1.0 - attention_mask[:, None, :].repeat(self.config.num_encoder_attention_heads, 1, 1)
+                1.0 - attention_mask[:, None, None, :].repeat(1, self.config.num_encoder_attention_heads, 1, 1)
             ) * torch.finfo(self.dtype).min
             extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype)
         else:
@@ -1572,7 +1545,7 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
         # prepare encoder attention mask
         if encoder_attention_mask is not None:
             extended_encoder_attention_mask = (
-                1.0 - encoder_attention_mask[:, None, :].repeat(self.config.num_decoder_attention_heads, 1, 1)
+                1.0 - encoder_attention_mask[:, None, None, :].repeat(1, self.config.num_decoder_attention_heads, 1, 1)
             ) * torch.finfo(self.dtype).min
             extended_encoder_attention_mask = extended_encoder_attention_mask.to(inputs_embeds.dtype)
         else:
@@ -1740,17 +1713,18 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
             device=hidden_states.device,
         )
         causal_mask = torch.triu(causal_mask, 1)
-        extended_causal_mask = causal_mask[:seq_length, :seq_length][None, :, :].expand(
-            (batch_size,) + causal_mask.shape
+
+        extended_causal_mask = causal_mask[:seq_length, :seq_length][None, None, :, :].expand(
+            (batch_size, self.config.num_decoder_attention_heads) + causal_mask.shape
        )
 
         # add usual attention mask
         if attention_mask is not None:
-            extended_attention_mask = (1.0 - attention_mask[:, None, :]) * torch.finfo(self.dtype).min
+            extended_attention_mask = (1.0 - attention_mask[:, None, None, :]) * torch.finfo(self.dtype).min
             extended_attention_mask = extended_causal_mask + extended_attention_mask
         else:
             extended_attention_mask = extended_causal_mask
-        return extended_attention_mask.repeat(self.config.num_decoder_attention_heads, 1, 1).to(hidden_states.dtype)
+        return extended_attention_mask.to(hidden_states.dtype)
 
     def prepare_predict_attention_mask(self, hidden_states, attention_mask):
         batch_size, seq_length = hidden_states.shape[:2]
@@ -1768,14 +1742,16 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
             ],
             dim=-1,
         )
-        extended_predict_causal_mask = predict_causal_mask[:, None, :, :].expand(
-            predict_causal_mask.shape[:1] + (batch_size,) + predict_causal_mask.shape[1:]
+        extended_predict_causal_mask = predict_causal_mask[None, None, :, :, :].expand(
+            (batch_size, self.config.num_decoder_attention_heads) + predict_causal_mask.shape
         )
 
         # add usual attention mask
         if attention_mask is not None:
-            extended_attention_mask = (1.0 - attention_mask[None, :, None, :]) * torch.finfo(self.dtype).min
-            extended_attention_mask = extended_attention_mask.expand((self.ngram, batch_size, seq_length, seq_length))
+            extended_attention_mask = (1.0 - attention_mask[:, None, None, None, :]) * torch.finfo(self.dtype).min
+            extended_attention_mask = extended_attention_mask.expand(
+                (batch_size, self.config.num_decoder_attention_heads, self.ngram, seq_length, seq_length)
+            )
             # predicted stream attention_mask should always be 0
             extended_attention_mask = torch.cat(
                 [extended_attention_mask, torch.zeros_like(extended_attention_mask)], dim=-1
@@ -1783,9 +1759,7 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
             extended_predict_attention_mask = extended_predict_causal_mask + extended_attention_mask
         else:
             extended_predict_attention_mask = extended_predict_causal_mask
-        return extended_predict_attention_mask.repeat(1, self.config.num_decoder_attention_heads, 1, 1).to(
-            hidden_states.dtype
-        )
+        return extended_predict_attention_mask.to(hidden_states.dtype)
 
 
 @add_start_docstrings(
@@ -1206,7 +1206,7 @@ class ProphetNetModelIntegrationTest(unittest.TestCase):
         expected_shape = torch.Size((1, 12, 30522))
         self.assertEqual(output_predited_logits.shape, expected_shape)
         expected_slice = torch.tensor(
-            [[[-7.6213, -7.9008, -7.9979], [-7.6834, -7.8467, -8.2187], [-7.5326, -7.4762, -8.1914]]]
+            [[[-7.7729, -8.0343, -8.26001], [-7.74213, -7.8629, -8.6000], [-7.7328, -7.8269, -8.5264]]]
         ).to(torch_device)
         # self.assertTrue(torch.allclose(output_predited_logits[:, :3, :3], expected_slice, atol=1e-4))
         assert torch.allclose(output_predited_logits[:, :3, :3], expected_slice, atol=1e-4)
@@ -1306,7 +1306,7 @@ class ProphetNetModelIntegrationTest(unittest.TestCase):
         EXPECTED_QUESTIONS = [
             "along with paul allen, who founded microsoft?",
             "what year was microsoft founded?",
-            "on what date was microsoft founded?",
+            "when was microsoft founded?",
         ]
 
         self.assertListEqual(