[ProphetNet] Bart-like Refactor (#10501)
* first step to refactor
* make all fast tests pass
* make all slow tests pass
* save intermediate
* correct cache
* finish PR
* make fp16 work

This commit is contained in:
parent 6290169eb3
commit c503a1c15e
@@ -92,6 +92,8 @@ class ProphetNetConfig(PretrainedConfig):
            smoothing is performed.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
    """
    model_type = "prophetnet"
    keys_to_ignore_at_inference = ["past_key_values"]

@@ -119,6 +121,7 @@ class ProphetNetConfig(PretrainedConfig):
        num_buckets=32,
        relative_max_distance=128,
        disable_ngram_loss=False,
        gradient_checkpointing=False,
        eps=0.0,
        use_cache=True,
        pad_token_id=0,

@@ -161,6 +164,9 @@ class ProphetNetConfig(PretrainedConfig):

        self.use_cache = use_cache

        # 4 Training Args (should be removed soon)
        self.gradient_checkpointing = gradient_checkpointing

    @property
    def num_attention_heads(self) -> int:
        return self.num_encoder_attention_heads
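Note: the three configuration hunks above thread `use_cache` through `ProphetNetConfig` and keep `past_key_values` out of inference output comparisons. As a minimal sketch (not part of the commit), the new flag can be toggled like any other config argument:

import torch
from transformers import ProphetNetConfig

# Illustrative only: `use_cache` is now a first-class config argument.
config = ProphetNetConfig(use_cache=False)
assert config.use_cache is False
assert "past_key_values" in ProphetNetConfig.keys_to_ignore_at_inference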
@@ -18,7 +18,7 @@ import copy
import math
import warnings
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
@@ -567,6 +567,7 @@ class ProphetNetPositionalEmbeddings(nn.Embedding):
    """

    def __init__(self, config: ProphetNetConfig):
        self.max_length = config.max_position_embeddings
        super().__init__(config.max_position_embeddings, config.hidden_size, config.pad_token_id)

    def forward(self, inputs_shape, device, attention_mask=None, past_key_values=None, position_ids=None):

@@ -578,7 +579,7 @@ class ProphetNetPositionalEmbeddings(nn.Embedding):
            if past_key_values is not None:
                # position_ids is the same for every token when decoding a single step
                # Without the int() cast, it doesn't work in some cases when exporting to ONNX
                prev_num_input_ids = past_key_values[0]["self"]["prev_key_states"].shape[2]
                prev_num_input_ids = past_key_values[0][0].shape[2]
                num_input_ids = inputs_shape[1] + prev_num_input_ids
                position_ids = torch.ones((1, 1), dtype=torch.long, device=device) * (
                    int(self.padding_idx + num_input_ids)

@@ -592,6 +593,9 @@ class ProphetNetPositionalEmbeddings(nn.Embedding):
                    torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
                ).long() + self.padding_idx

                # make sure position_ids are not bigger then max_length
                position_ids = position_ids.clamp(0, self.max_length - 1)

        return super().forward(position_ids), position_ids

    def _forward(self, position_ids):
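Note: the hunk above reflects the core cache change of this refactor: the old nested dict (`past_key_values[layer]["self"]["prev_key_states"]`) becomes a Bart-style per-layer tuple (`past_key_values[layer][0]`). An illustrative sketch of the two layouts, with made-up shapes:

import torch

# Sketch only: per-layer cache before and after the refactor.
# Shapes are (batch, num_heads, prev_seq_len, head_dim).
key = torch.zeros(2, 16, 5, 64)
value = torch.zeros(2, 16, 5, 64)

# Old, dict-based layout
old_layer_cache = {"self": {"prev_key_states": key, "prev_value_states": value}}
# New, Bart-style tuple layout
new_layer_cache = (key, value)

# shape[2] is still the cached sequence length in both layouts
assert old_layer_cache["self"]["prev_key_states"].shape[2] == new_layer_cache[0].shape[2] == 5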
@@ -624,66 +628,65 @@ class ProphetNetAttention(nn.Module):

        self.out_proj = nn.Linear(hidden_size, hidden_size)

    def _reshape(self, tensor, first_dim, batch_size):
        return tensor.reshape(first_dim, batch_size * self.num_attn_heads, self.head_dim).transpose(0, 1)
    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states,
        key_value_states: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
        past_key_value: Optional[Tuple[Tensor]] = None,
        output_attentions: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:

        sequence_length, batch_size, hidden_size = hidden_states.size()
        batch_size, tgt_len, hidden_size = hidden_states.size()

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        cache_key = "cross_attention" if is_cross_attention else "self"
        assert list(hidden_states.size()) == [
            sequence_length,
            batch_size,
            tgt_len,
            hidden_size,
        ], f"Size of hidden states should be {sequence_length, batch_size, hidden_size}, but is {hidden_states.size()}"
        ], f"Size of hidden states should be {batch_size, tgt_len, hidden_size}, but is {hidden_states.size()}"

        # previous time steps are cached - no need to recompute key and value if they are static
        if layer_state is not None:
            saved_state = layer_state.get(cache_key, None)

        query_states = self.query_proj(hidden_states) / (self.head_dim ** 0.5)
        query_states = self._reshape(query_states, sequence_length, batch_size)

        if not is_cross_attention:
            # self-attention
            key_states = self.key_proj(hidden_states)
            key_states = self._reshape(key_states, -1, batch_size)
            value_states = self.value_proj(hidden_states)
            value_states = self._reshape(value_states, -1, batch_size)
        elif saved_state is None:
            # cross-attention without layer state
            key_states = self.key_proj(key_value_states)
            key_states = self._reshape(key_states, -1, batch_size)
            value_states = self.value_proj(key_value_states)
            value_states = self._reshape(value_states, -1, batch_size)
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.key_proj(key_value_states), -1, batch_size)
            value_states = self._shape(self.value_proj(key_value_states), -1, batch_size)
        else:
            key_states = saved_state["prev_key_states"].view(batch_size * self.num_attn_heads, -1, self.head_dim)
            value_states = saved_state["prev_value_states"].view(batch_size * self.num_attn_heads, -1, self.head_dim)
            # self_attention
            key_states = self._shape(self.key_proj(hidden_states), -1, batch_size)
            value_states = self._shape(self.value_proj(hidden_states), -1, batch_size)

        # Update cache
        if is_cross_attention:
            layer_state[cache_key] = {
                "prev_key_states": key_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
                "prev_value_states": value_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
            }
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        key_sequence_length = key_states.size(1)
        # project states into the correct shape
        proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, batch_size).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
        assert attn_weights.size() == (
            batch_size * self.num_attn_heads,
            sequence_length,
            key_sequence_length,
        ), f"`attn_weights` should be of size {batch_size * self.num_attn_heads, sequence_length, key_sequence_length}, but is of size {attn_weights.shape}"
            tgt_len,
            src_len,
        ), f"`attn_weights` should be of size {batch_size * self.num_attn_heads, tgt_len, src_len}, but is of size {attn_weights.shape}"

        # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
        if attention_mask is not None and attention_mask.dim() == 0:

@@ -691,19 +694,21 @@ class ProphetNetAttention(nn.Module):
        assert attention_mask is None or attention_mask.size() == (
            self.num_attn_heads * batch_size,
            1,
            key_sequence_length,
        ), f"`attention_mask` should be `None` or of shape attention_mask.size() == {batch_size * self.num_attn_heads, 1, key_sequence_length}, but is {attention_mask.shape}"
            src_len,
        ), f"`attention_mask` should be `None` or of shape attention_mask.size() == {batch_size * self.num_attn_heads, 1, src_len}, but is {attention_mask.shape}"

        if attention_mask is not None:  # don't attend to padding symbols
            attn_weights = attn_weights + attention_mask

        # need two reshapes to keep gradient at attention weights
        attn_weights_reshaped = attn_weights.view(
            batch_size, self.num_attn_heads, sequence_length, key_sequence_length
        )
        attn_weights = attn_weights_reshaped.view(
            batch_size * self.num_attn_heads, sequence_length, key_sequence_length
        )
        if output_attentions:
            # this operation is a bit akward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(batch_size, self.num_attn_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(batch_size * self.num_attn_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_probs = F.dropout(

@@ -715,15 +720,20 @@ class ProphetNetAttention(nn.Module):
        attn_output = torch.bmm(attn_probs, value_states)
        assert attn_output.size() == (
            batch_size * self.num_attn_heads,
            sequence_length,
            tgt_len,
            self.head_dim,
        ), "`attn_output` should be of shape {batch_size * self.num_attn_heads, sequence_length, self.head_dim}, but is of shape {attn_output.size()}"
        attn_output = attn_output.transpose(0, 1).contiguous().view(sequence_length, batch_size, hidden_size)
        ), "`attn_output` should be of shape {batch_size * self.num_attn_heads, tgt_len, self.head_dim}, but is of shape {attn_output.size()}"

        attn_output = (
            attn_output.view(batch_size, self.num_attn_heads, tgt_len, self.head_dim)
            .transpose(1, 2)
            .reshape(batch_size, tgt_len, hidden_size)
        )

        attn_output = self.out_proj(attn_output)

        attn_output = F.dropout(attn_output, p=self.dropout, training=self.training)
        return attn_output, attn_weights_reshaped
        return attn_output, attn_weights_reshaped, past_key_value


class ProphetNetFeedForward(nn.Module):
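Note: the `ProphetNetAttention` hunks above switch from the `(seq_len, batch, hidden)` layout and the dict-based `layer_state` to the batch-first `_shape` helper and a `(key_states, value_states)` tuple returned alongside the output. A small self-contained sketch of what the new `_shape` reshaping produces, with made-up head counts and sizes:

import torch

# Sketch of the new `_shape` helper: (batch, seq_len, hidden) ->
# (batch, num_heads, seq_len, head_dim); this is also the layout that
# ends up in the returned `past_key_value` tuple.
def shape(tensor, seq_len, bsz, num_heads=2, head_dim=4):
    return tensor.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous()

hidden = torch.randn(3, 5, 8)                       # (batch, seq_len, hidden = num_heads * head_dim)
key_states = shape(hidden, 5, 3)                    # (3, 2, 5, 4)
past_key_value = (key_states, key_states.clone())   # tuple cache returned by forward
assert past_key_value[0].shape == (3, 2, 5, 4)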
@@ -779,8 +789,8 @@ class ProphetNetNgramSelfAttention(nn.Module):
        # for onnx runtime
        self.onnx_trace = False

    def _reshape(self, tensor, first_dim, batch_size):
        return tensor.reshape(first_dim, batch_size * self.num_attn_heads, self.head_dim).transpose(0, 1)
    def _shape(self, tensor, seq_len, batch_size):
        return tensor.view(batch_size, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous()

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

@@ -788,23 +798,20 @@ class ProphetNetNgramSelfAttention(nn.Module):
    def forward(
        self,
        hidden_states,
        layer_state=None,
        past_key_value: Optional[Tuple[Tensor]] = None,
        attention_mask=None,
        extended_predict_attention_mask=None,
        main_relative_position_buckets=None,
        predict_relative_position_buckets=None,
        position_ids=None,
    ):
        sequence_length, batch_size, hidden_size = hidden_states.size()
        batch_size, ngram_sequence_length, hidden_size = hidden_states.size()

        assert list(hidden_states.size()) == [
            sequence_length,
            batch_size,
            ngram_sequence_length,
            hidden_size,
        ], f"`hidden_states` should be of shape {sequence_length, batch_size, hidden_size}, but is of shape {hidden_states.shape}"

        # key and value of previous time steps are cached
        saved_state = layer_state.get("self", None)
        ], f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape {hidden_states.shape}"

        # project
        query_states = self.query_proj(hidden_states)

@@ -815,12 +822,18 @@ class ProphetNetNgramSelfAttention(nn.Module):
        query_states = query_states / (self.head_dim ** 0.5)

        # reshape
        query_states = self._reshape(query_states, sequence_length, batch_size)
        key_states = self._reshape(key_states, -1, batch_size)
        value_states = self._reshape(value_states, -1, batch_size)
        query_states = self._shape(query_states, ngram_sequence_length, batch_size)
        key_states = self._shape(key_states, -1, batch_size)
        value_states = self._shape(value_states, -1, batch_size)

        proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim)

        query_states = query_states.view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        # chunk into main stream and predict stream
        hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=0)
        hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=1)

        query_states_list = query_states.chunk(1 + self.ngram, dim=1)
        key_states_list = key_states.chunk(1 + self.ngram, dim=1)

@@ -832,24 +845,20 @@ class ProphetNetNgramSelfAttention(nn.Module):
        main_value_states, predict_value_states_list = value_states_list[0], value_states_list[1:]

        # saved states are stored with shape (batch_size, num_attn_heads, seq_len, head_dim)
        if saved_state is not None:
            prev_main_key_states = saved_state["prev_key_states"].view(
                batch_size * self.num_attn_heads, -1, self.head_dim
            )
        if past_key_value is not None:
            prev_main_key_states = past_key_value[0].view(batch_size * self.num_attn_heads, -1, self.head_dim)
            main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=1)
            prev_main_value_states = saved_state["prev_value_states"].view(
                batch_size * self.num_attn_heads, -1, self.head_dim
            )
            prev_main_value_states = past_key_value[1].view(batch_size * self.num_attn_heads, -1, self.head_dim)
            main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=1)

        # Update cache
        layer_state["self"] = {
            "prev_key_states": main_key_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
            "prev_value_states": main_value_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
        }
        past_key_value = (
            main_key_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
            main_value_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
        )

        # get seq_length of main stream only
        main_sequence_length = sequence_length // (1 + self.ngram)
        sequence_length = ngram_sequence_length // (1 + self.ngram)

        # MAIN-STREAM
        # main attn weights
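Note: in the n-gram self-attention the hidden states now arrive batch-first and are split along dim=1 into one main stream plus `ngram` predict streams; only the main-stream keys/values are concatenated with the cache and returned as the new `past_key_value`. A toy sketch of the chunking, with invented sizes:

import torch

ngram, batch, seq_len, hidden = 2, 1, 4, 8
# Batch-first layout: main stream and ngram predict streams are
# concatenated along the sequence axis (dim=1).
hidden_states = torch.randn(batch, (1 + ngram) * seq_len, hidden)

main, *predict = hidden_states.chunk(1 + ngram, dim=1)
assert main.shape == (batch, seq_len, hidden)
assert len(predict) == ngram  # one stream per predicted future token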
@@ -871,18 +880,21 @@ class ProphetNetNgramSelfAttention(nn.Module):
        ).type_as(main_attn_weights)

        main_attn_probs = F.dropout(main_attn_probs, p=self.attention_dropout, training=self.training)

        # project to attn_output
        main_attn_output = torch.bmm(main_attn_probs, main_value_states)

        # reshape so that num_heads dim is merged into last `head_dim` axis
        main_attn_output = (
            main_attn_output.transpose(0, 1).contiguous().view(1, main_sequence_length, batch_size, hidden_size)
            main_attn_output.view(batch_size, self.num_attn_heads, sequence_length, self.head_dim)
            .transpose(1, 2)
            .reshape(batch_size, 1, sequence_length, hidden_size)
        )
        main_attn_output = self.out_proj(main_attn_output)

        # PREDICT-STREAM
        # [ngram, B*head, T, c]
        predict_query_states = torch.cat(predict_query_states_list, 0).view(
            self.ngram, -1, main_sequence_length, self.head_dim
            self.ngram, -1, sequence_length, self.head_dim
        )
        # [ngram, B*head, 2*T, c]
        predict_key_states = torch.cat(

@@ -891,7 +903,7 @@ class ProphetNetNgramSelfAttention(nn.Module):

        # [ngram, T, B, C]
        predict_hidden_states = torch.cat(hidden_states_predict_list, 0).view(
            self.ngram, main_sequence_length, batch_size, hidden_size
            self.ngram, sequence_length, batch_size, hidden_size
        )

        # [ngram, B*head, 2*T, c]

@@ -911,7 +923,9 @@ class ProphetNetNgramSelfAttention(nn.Module):
        predict_attn_weights = predict_attn_weights + predict_relative_pos_embeddings

        if extended_predict_attention_mask is not None:
            predict_attn_weights = predict_attn_weights + extended_predict_attention_mask
            predict_attn_weights = predict_attn_weights + extended_predict_attention_mask.to(
                predict_attn_weights.dtype
            )

        predict_attn_probs = softmax(
            predict_attn_weights,

@@ -919,35 +933,36 @@ class ProphetNetNgramSelfAttention(nn.Module):
            onnx_trace=self.onnx_trace,
        ).type_as(predict_attn_weights)
        predict_attn_probs = F.dropout(predict_attn_probs, p=self.attention_dropout, training=self.training)

        # project to attention output
        # [ngram, B*head, T, c]
        predict_attn_output = torch.einsum("nbts,nbsc->nbtc", (predict_attn_probs, predict_value_states))
        # [ngram, T, B, C]

        # reshape so that num_heads dim is merged into last `head_dim` axis
        # [ngram, B, T, C]
        predict_attn_output = (
            predict_attn_output.transpose(1, 2)
            .contiguous()
            .view(self.ngram, main_sequence_length, batch_size, hidden_size)
            predict_attn_output.view(self.ngram, batch_size, self.num_attn_heads, sequence_length, self.head_dim)
            .permute(1, 0, 3, 2, 4)
            .reshape(batch_size, self.ngram, sequence_length, hidden_size)
        )
        predict_attn_output = self.out_proj(predict_attn_output)

        # concat to single attn output
        # [1+ngram*T, B, C]
        attn_output = torch.cat([main_attn_output, predict_attn_output], 0).view(-1, batch_size, hidden_size)

        # [B, 1+ngram*T, C]
        attn_output = torch.cat([main_attn_output, predict_attn_output], 1).view(batch_size, -1, hidden_size)
        # reshape into better form for `config.output_attentions`
        main_attn_probs = main_attn_probs.view(batch_size, self.num_attn_heads, main_sequence_length, -1)
        main_attn_probs = main_attn_probs.view(batch_size, self.num_attn_heads, sequence_length, -1)
        predict_attn_probs = predict_attn_probs.view(
            self.ngram, batch_size, self.num_attn_heads, main_sequence_length, -1
            self.ngram, batch_size, self.num_attn_heads, sequence_length, -1
        ).transpose(0, 1)

        attn_output = F.dropout(attn_output, p=self.dropout, training=self.training)
        return attn_output, main_attn_probs, predict_attn_probs

        return attn_output, main_attn_probs, predict_attn_probs, past_key_value

    def get_main_relative_pos_embeddings(
        self, hidden_states, attn_weights, position_ids, main_relative_position_buckets
    ):
        # input hidden_states [T,B,C], input attn_weights [T*head,T,S], input position_ids [B,T] or [1,1]
        # input hidden_states [B,T,C], input attn_weights [T*head,T,S], input position_ids [B,T] or [1,1]

        if main_relative_position_buckets is None:
            batch_size, sequence_length = hidden_states.shape[:2]

@@ -965,7 +980,6 @@ class ProphetNetNgramSelfAttention(nn.Module):
                self.num_buckets, self.relative_max_distance, relative_positions, False
            )

        hidden_states = hidden_states.transpose(0, 1)  # [B,T,C]
        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)  # [B,T,Buckets*head]
        rel_pos_embeddings = rel_pos_embeddings.view(
            rel_pos_embeddings.shape[:2] + (self.num_buckets, self.num_attn_heads)

@@ -991,7 +1005,6 @@ class ProphetNetNgramSelfAttention(nn.Module):
        self, hidden_states, attn_weights, position_ids, predict_relative_position_buckets
    ):
        # input hidden_states [ngram, T,B,C], input attn_weights [ngram, B*head,T,S], input position_ids [B,T] or [1,1], input predict_relative_position_buckets [B,T, 2*T] or None

        sequence_length, batch_size = hidden_states.shape[1:3]

        if predict_relative_position_buckets is None:

@@ -1053,18 +1066,25 @@ class ProphetNetEncoderLayer(nn.Module):
        self.feed_forward = ProphetNetFeedForward(config, config.encoder_ffn_dim)
        self.feed_forward_layer_norm = LayerNorm(config.hidden_size)

    def forward(self, hidden_states, attention_mask):
    def forward(self, hidden_states, attention_mask, output_attentions: bool = False):
        # 1st residual block
        attention_output, attn_weights = self.self_attn(
        attention_output, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = self.self_attn_layer_norm(attention_output + hidden_states)

        # 2nd residual block
        feed_forward_output = self.feed_forward(hidden_states)
        hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states)
        return hidden_states, attn_weights

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class ProphetNetDecoderLayer(nn.Module):
@@ -1090,21 +1110,23 @@ class ProphetNetDecoderLayer(nn.Module):
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attn_mask=None,
        layer_state=None,
        attention_mask=None,
        extended_predict_attention_mask=None,
        main_relative_position_buckets=None,
        predict_relative_position_buckets=None,
        position_ids=None,
        past_key_value=None,
        use_cache: bool = True,
        output_attentions: bool = False,
    ):
        layer_state = layer_state if layer_state is not None else {}

        # 1st residual block
        ngram_attention_output, self_attn_weights, self_attn_weights_ngram = self.self_attn(
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        ngram_attention_output, self_attn_weights, self_attn_weights_ngram, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            layer_state=layer_state,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            extended_predict_attention_mask=extended_predict_attention_mask,
            main_relative_position_buckets=main_relative_position_buckets,

@@ -1113,28 +1135,36 @@ class ProphetNetDecoderLayer(nn.Module):
        )
        hidden_states = self.self_attn_layer_norm(hidden_states + ngram_attention_output)

        # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
        cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            # 2nd residual block
            attention_output, cross_attn_weights = self.cross_attn(
            attention_output, cross_attn_weights, cross_attn_present_key_value = self.cross_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attn_mask,
                layer_state=layer_state,  # mutates layer state
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
            )
            hidden_states = self.cross_attn_layer_norm(attention_output + hidden_states)

            # add cross-attn to positions 3,4 of present_key_value tuple
            present_key_value = present_key_value + cross_attn_present_key_value

        # 3rd residual block
        feed_forward_output = self.feed_forward(hidden_states)
        hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states)

        return (
            hidden_states,
            self_attn_weights,
            self_attn_weights_ngram,
            cross_attn_weights,
            layer_state,
        )  # just self_attn weights for now, following t5, layer_state = cache for decoding
        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, self_attn_weights_ngram, cross_attn_weights)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


@add_start_docstrings(
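Note: the decoder layer now consumes and produces the Bart-style per-layer cache: the first two entries hold the uni-directional self-attention key/value states, the last two the cross-attention key/value states, and the combined tuple is appended to `outputs` only when `use_cache` is set. A sketch of how such a tuple is sliced, with placeholder tensors:

import torch

k = lambda: torch.zeros(2, 16, 5, 64)  # (batch, heads, seq_len, head_dim), made-up sizes
layer_past = (k(), k(), k(), k())      # self-attn k, v  +  cross-attn k, v

self_attn_past_key_value = layer_past[:2]    # first two entries
cross_attn_past_key_value = layer_past[-2:]  # last two entries
assert len(self_attn_past_key_value) == len(cross_attn_past_key_value) == 2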
@@ -1223,21 +1253,37 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
        hidden_states = inputs_embeds + position_embeddings
        hidden_states = self.embeddings_layer_norm(hidden_states)
        hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
        hidden_states = hidden_states.transpose(0, 1)  # B x T x C -> T x B x C

        encoder_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for encoder_layer in self.layers:
            if output_hidden_states:
                hidden_states = hidden_states.transpose(0, 1)
                encoder_hidden_states = encoder_hidden_states + (hidden_states,)
                hidden_states = hidden_states.transpose(0, 1)
            hidden_states, attn_probs = encoder_layer(hidden_states, attention_mask=extended_attention_mask)
            if output_attentions:
                all_attentions = all_attentions + (attn_probs,)

        hidden_states = hidden_states.transpose(0, 1)
            if getattr(self.config, "gradient_checkpointing", False) and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    extended_attention_mask,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states, attention_mask=extended_attention_mask, output_attentions=output_attentions
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_hidden_states = encoder_hidden_states + (hidden_states,)
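Note: the encoder (and, below, the decoder) now routes each layer through torch.utils.checkpoint.checkpoint when `config.gradient_checkpointing` is enabled during training. Because checkpoint() only forwards positional tensors, non-tensor flags such as `output_attentions` are closed over in a small wrapper. A standalone sketch of that pattern, using a stand-in layer rather than the real one:

import torch
import torch.nn as nn
import torch.utils.checkpoint

layer = nn.Linear(8, 8)  # stand-in for an encoder layer

def create_custom_forward(module):
    # closes over non-tensor flags so that checkpoint() only sees tensors
    def custom_forward(*inputs):
        return module(*inputs)  # a real layer would also receive output_attentions

    return custom_forward

hidden_states = torch.randn(2, 4, 8, requires_grad=True)
out = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), hidden_states)
out.sum().backward()  # activations are recomputed during the backward pass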
@@ -1370,26 +1416,24 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):

        # add position embeddings
        hidden_states = inputs_embeds + main_stream_pos_embed
        hidden_states = hidden_states.transpose(0, 1)

        ngram_embeddings = self.ngram_embeddings.weight

        # prepare attention mask
        if past_key_values is not None:
            assert (
                hidden_states.size(0) == 1
                hidden_states.size(1) == 1
            ), "At the moment `use_cache` is only supported for `decoder_input_ids` of length 1"

            ngram_hidden_states = [
                (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed).transpose(0, 1).repeat(1, batch_size, 1)
                (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed).repeat(batch_size, 1, 1)
                for ngram in range(self.ngram)
            ]
            extended_attention_mask = None
            extended_predict_attention_mask = None
        else:
            ngram_hidden_states = [
                (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed).transpose(0, 1)
                for ngram in range(self.ngram)
                (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed) for ngram in range(self.ngram)
            ]
            extended_attention_mask = self.prepare_attention_mask(hidden_states, attention_mask)
            extended_predict_attention_mask = self.prepare_predict_attention_mask(hidden_states, attention_mask)

@@ -1403,16 +1447,13 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
        else:
            extended_encoder_attention_mask = None

        hidden_states = torch.cat([hidden_states] + ngram_hidden_states, 0)
        hidden_states = torch.cat([hidden_states] + ngram_hidden_states, 1)

        if self.embeddings_layer_norm:
            hidden_states = self.embeddings_layer_norm(hidden_states)

        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)

        if encoder_hidden_states is not None:
            encoder_hidden_states = encoder_hidden_states.transpose(0, 1)

        # init attentions, hidden_states and cache with empty tuples
        all_main_stream_hidden_states = () if output_hidden_states else None
        all_ngram_stream_hidden_states = () if output_hidden_states and self.config.ngram > 0 else None

@@ -1425,47 +1466,75 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                # grad cannot be kept because tensor is sliced
                all_main_stream_hidden_states += (hidden_states[:sequence_length].transpose(0, 1),)
                all_main_stream_hidden_states += (hidden_states[:, :sequence_length],)
                if self.config.ngram > 0:
                    all_ngram_stream_hidden_states += (hidden_states[sequence_length:].transpose(0, 1),)
                    all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and self.training:

                if use_cache:
                    logger.warn(
                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
                        "`use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    extended_attention_mask,
                    encoder_hidden_states,
                    extended_encoder_attention_mask,
                    extended_predict_attention_mask,
                    main_relative_position_buckets,
                    predict_relative_position_buckets,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attn_mask=extended_encoder_attention_mask,
                    extended_predict_attention_mask=extended_predict_attention_mask,
                    main_relative_position_buckets=main_relative_position_buckets,
                    predict_relative_position_buckets=predict_relative_position_buckets,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            layer_state = past_key_values[idx] if past_key_values is not None else None
            (
                hidden_states,
                layer_self_attn,
                layer_self_predict_attn_output,
                layer_cross_attn,
                layer_past,
            ) = decoder_layer(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attn_mask=extended_encoder_attention_mask,
                layer_state=layer_state,
                attention_mask=extended_attention_mask,
                extended_predict_attention_mask=extended_predict_attention_mask,
                main_relative_position_buckets=main_relative_position_buckets,
                predict_relative_position_buckets=predict_relative_position_buckets,
                position_ids=position_ids,
            )
            if use_cache:
                present_key_values += (layer_past,)
                present_key_values += (layer_outputs[4 if output_attentions else 1],)

            if output_attentions:
                all_main_stream_attns += (layer_self_attn,)
                all_ngram_stream_attns += (layer_self_predict_attn_output,)
                all_main_stream_attns += (layer_outputs[1],)
                all_ngram_stream_attns += (layer_outputs[2],)

                if self.config.add_cross_attention:
                    all_cross_attns += (layer_cross_attn,)
                    all_cross_attns += (layer_outputs[3],)

        if output_hidden_states:
            all_main_stream_hidden_states += (hidden_states[:sequence_length].transpose(0, 1),)
            all_main_stream_hidden_states += (hidden_states[:, :sequence_length],)
            if self.config.ngram > 0:
                all_ngram_stream_hidden_states += (hidden_states[sequence_length:].transpose(0, 1),)
                all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],)

        # split last_hidden_state for return
        last_hidden_state = hidden_states[:sequence_length].transpose(0, 1)
        last_hidden_state_ngram = hidden_states[sequence_length:].transpose(0, 1) if self.config.ngram > 0 else None
        encoder_hidden_states = encoder_hidden_states.transpose(0, 1) if encoder_hidden_states is not None else None
        last_hidden_state = hidden_states[:, :sequence_length]
        last_hidden_state_ngram = hidden_states[:, sequence_length:] if self.config.ngram > 0 else None

        if not return_dict:
            return tuple(
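Note: after the layer loop the decoder splits the concatenated states back into the main stream (`[:, :sequence_length]`) and the n-gram stream (`[:, sequence_length:]`), now by slicing the batch-first tensor instead of slicing dim 0 and transposing. A toy sketch of that split, with invented sizes:

import torch

batch, seq_len, ngram, hidden = 2, 3, 2, 8
hidden_states = torch.randn(batch, (1 + ngram) * seq_len, hidden)

last_hidden_state = hidden_states[:, :seq_len]        # main stream
last_hidden_state_ngram = hidden_states[:, seq_len:]  # predict streams
assert last_hidden_state.shape == (batch, seq_len, hidden)
assert last_hidden_state_ngram.shape == (batch, ngram * seq_len, hidden)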
@@ -1516,7 +1585,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
        return main_relative_buckets, predict_relative_buckets

    def prepare_attention_mask(self, hidden_states, attention_mask):
        seq_length, batch_size = hidden_states.shape[:2]
        batch_size, seq_length = hidden_states.shape[:2]

        # get causal mask
        causal_mask = hidden_states.new(seq_length, seq_length).float().fill_(-float("inf"))

@@ -1534,7 +1603,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
        return extended_attention_mask.repeat(self.config.num_decoder_attention_heads, 1, 1).to(hidden_states.dtype)

    def prepare_predict_attention_mask(self, hidden_states, attention_mask):
        seq_length, batch_size = hidden_states.shape[:2]
        batch_size, seq_length = hidden_states.shape[:2]

        # get causal mask
        predict_causal_mask = ngram_attention_bias(

@@ -1656,7 +1725,7 @@ class ProphetNetModel(ProphetNetPreTrainedModel):
                return_dict=return_dict,
            )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,

@@ -1856,21 +1925,14 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
        return self._shift_right(labels)

    @staticmethod
    # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration._reorder_cache
    def _reorder_cache(past, beam_idx):
        # this function reorders the cache for beam search
        def _reorder_cache(cache_dict, beam_idx):
            for k, key_value_states in cache_dict.items():
                if key_value_states is not None:
                    cache_dict[k] = key_value_states.index_select(0, beam_idx)
            return cache_dict

        reordered_past = []
        reordered_past = ()
        for layer_past in past:
            # get the correct batch idx from decoder layer's batch dim for cross and self-attn
            layer_past_new = {
                attn_key: _reorder_cache(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
            }
            reordered_past.append(layer_past_new)
            # cached cross_attention states don't have to be reordered -> they are always the same
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
            )
        return reordered_past

    def get_encoder(self):
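Note: `_reorder_cache` is now copied from Bart: for each layer it index-selects the self-attention key/value states along the batch (beam) dimension and leaves the cross-attention states untouched, since they are identical for every beam. A self-contained sketch of the reordering with made-up shapes:

import torch

# one layer: (self k, self v, cross k, cross v), batch/beam dim first
layer_past = tuple(torch.randn(4, 16, 5, 64) for _ in range(4))
past, beam_idx = (layer_past,), torch.tensor([2, 2, 0, 1])

reordered_past = ()
for layer_past in past:
    reordered_past += (
        tuple(s.index_select(0, beam_idx) for s in layer_past[:2]) + layer_past[2:],
    )
assert torch.equal(reordered_past[0][0][0], past[0][0][2])  # beam 0 now holds old beam 2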
@@ -1995,7 +2057,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
        outputs = self.prophetnet.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,

@@ -2080,21 +2142,11 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
        }

    @staticmethod
    # Copied from transformers.models.bart.modeling_bart.BartForCausalLM._reorder_cache
    def _reorder_cache(past, beam_idx):
        # this function reorders the cache for beam search
        def _reorder_cache(cache_dict, beam_idx):
            for k, key_value_states in cache_dict.items():
                if key_value_states is not None:
                    cache_dict[k] = key_value_states.index_select(0, beam_idx)
            return cache_dict

        reordered_past = []
        reordered_past = ()
        for layer_past in past:
            # get the correct batch idx from decoder layer's batch dim for cross and self-attn
            layer_past_new = {
                attn_key: _reorder_cache(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
            }
            reordered_past.append(layer_past_new)
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past
@@ -243,7 +243,7 @@ class ProphetNetModelTester:
        # There should be `num_layers` key value embeddings stored in decoder_past
        self.parent.assertEqual(len(decoder_past), config.num_decoder_layers)
        # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple
        self.parent.assertEqual(len(decoder_past[0]), 2)  # cross-attention + uni-directional self-attention
        self.parent.assertEqual(len(decoder_past[0]), 4)  # cross-attention + uni-directional self-attention

    def create_and_check_with_lm_head(
        self,
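Note: the test change reflects the new per-layer cache size: each element of `past_key_values` is now a 4-tuple (self-attention key/value plus cross-attention key/value) rather than a 2-entry dict. As an illustrative sketch only (checkpoint name and inputs are examples, and weights must be downloaded), the same property can be probed on a real model:

import torch
from transformers import ProphetNetForConditionalGeneration, ProphetNetTokenizer

tok = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
model = ProphetNetForConditionalGeneration.from_pretrained("microsoft/prophetnet-large-uncased")

enc = tok("hello world", return_tensors="pt")
out = model(
    input_ids=enc.input_ids,
    attention_mask=enc.attention_mask,
    decoder_input_ids=torch.tensor([[0]]),
    use_cache=True,
)

assert len(out.past_key_values) == model.config.num_decoder_layers
assert len(out.past_key_values[0]) == 4  # self-attn k, v + cross-attn k, v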