[ProphetNet] Bart-like Refactor (#10501)

* first step to refactor

* make all fast tests pass

* make all slow tests pass

* save intermediate

* correct cache

* finish PR

* make fp16 work
Patrick von Platen 2021-03-04 23:27:12 +03:00 committed by GitHub
parent 6290169eb3
commit c503a1c15e
3 changed files with 245 additions and 187 deletions

View File

@ -92,6 +92,8 @@ class ProphetNetConfig(PretrainedConfig):
smoothing is performed.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models).
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
"""
model_type = "prophetnet"
keys_to_ignore_at_inference = ["past_key_values"]
@ -119,6 +121,7 @@ class ProphetNetConfig(PretrainedConfig):
num_buckets=32,
relative_max_distance=128,
disable_ngram_loss=False,
gradient_checkpointing=False,
eps=0.0,
use_cache=True,
pad_token_id=0,
@ -161,6 +164,9 @@ class ProphetNetConfig(PretrainedConfig):
self.use_cache = use_cache
# 4 Training Args (should be removed soon)
self.gradient_checkpointing = gradient_checkpointing
@property
def num_attention_heads(self) -> int:
return self.num_encoder_attention_heads
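Reviewer note (not part of the diff): a minimal usage sketch of the new `gradient_checkpointing` flag added to `ProphetNetConfig`. The tiny sizes and the config argument names other than `gradient_checkpointing` are assumptions for illustration only.

    # Minimal sketch: enable gradient checkpointing via the config flag added above.
    # Trades extra backward-pass compute for lower activation memory during training.
    import torch
    from transformers import ProphetNetConfig, ProphetNetForConditionalGeneration

    config = ProphetNetConfig(
        vocab_size=128, hidden_size=16, encoder_ffn_dim=32, decoder_ffn_dim=32,
        num_encoder_layers=2, num_decoder_layers=2,
        num_encoder_attention_heads=4, num_decoder_attention_heads=4,
        gradient_checkpointing=True,  # the new flag, defaults to False
    )
    model = ProphetNetForConditionalGeneration(config)
    model.train()  # checkpointing only kicks in when `self.training` is True (see the encoder/decoder loops below)

    input_ids = torch.tensor([[2, 5, 7, 1]])
    decoder_input_ids = torch.tensor([[0, 2, 5, 7]])
    labels = torch.tensor([[2, 5, 7, 1]])
    loss = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, labels=labels).loss
    loss.backward()  # checkpointed layers are recomputed here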

View File

@ -18,7 +18,7 @@ import copy
import math
import warnings
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
@ -567,6 +567,7 @@ class ProphetNetPositionalEmbeddings(nn.Embedding):
"""
def __init__(self, config: ProphetNetConfig):
self.max_length = config.max_position_embeddings
super().__init__(config.max_position_embeddings, config.hidden_size, config.pad_token_id)
def forward(self, inputs_shape, device, attention_mask=None, past_key_values=None, position_ids=None):
@ -578,7 +579,7 @@ class ProphetNetPositionalEmbeddings(nn.Embedding):
if past_key_values is not None:
# position_ids is the same for every token when decoding a single step
# Without the int() cast, it doesn't work in some cases when exporting to ONNX
prev_num_input_ids = past_key_values[0]["self"]["prev_key_states"].shape[2]
prev_num_input_ids = past_key_values[0][0].shape[2]
num_input_ids = inputs_shape[1] + prev_num_input_ids
position_ids = torch.ones((1, 1), dtype=torch.long, device=device) * (
int(self.padding_idx + num_input_ids)
@ -592,6 +593,9 @@ class ProphetNetPositionalEmbeddings(nn.Embedding):
torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
).long() + self.padding_idx
# make sure position_ids are not bigger than max_length
position_ids = position_ids.clamp(0, self.max_length - 1)
return super().forward(position_ids), position_ids
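Reviewer note (not part of the diff): a small worked example of the cumsum/clamp logic above; `padding_idx` follows the default `pad_token_id=0`, and `max_length` is illustrative.

    import torch

    attention_mask = torch.tensor([[1, 1, 1, 0, 0]])  # last two tokens are padding
    padding_idx, max_length = 0, 512                  # padding_idx comes from config.pad_token_id

    position_ids = (
        torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
    ).long() + padding_idx
    position_ids = position_ids.clamp(0, max_length - 1)  # never index past the embedding table
    # position_ids is now tensor([[1, 2, 3, 0, 0]]): padded positions collapse to padding_idx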
def _forward(self, position_ids):
@ -624,66 +628,65 @@ class ProphetNetAttention(nn.Module):
self.out_proj = nn.Linear(hidden_size, hidden_size)
def _reshape(self, tensor, first_dim, batch_size):
return tensor.reshape(first_dim, batch_size * self.num_attn_heads, self.head_dim).transpose(0, 1)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states,
key_value_states: Optional[Tensor] = None,
attention_mask: Optional[Tensor] = None,
layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
past_key_value: Optional[Tuple[Tensor]] = None,
output_attentions: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
sequence_length, batch_size, hidden_size = hidden_states.size()
batch_size, tgt_len, hidden_size = hidden_states.size()
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
cache_key = "cross_attention" if is_cross_attention else "self"
assert list(hidden_states.size()) == [
sequence_length,
batch_size,
tgt_len,
hidden_size,
], f"Size of hidden states should be {sequence_length, batch_size, hidden_size}, but is {hidden_states.size()}"
], f"Size of hidden states should be {batch_size, tgt_len, hidden_size}, but is {hidden_states.size()}"
# previous time steps are cached - no need to recompute key and value if they are static
if layer_state is not None:
saved_state = layer_state.get(cache_key, None)
query_states = self.query_proj(hidden_states) / (self.head_dim ** 0.5)
query_states = self._reshape(query_states, sequence_length, batch_size)
if not is_cross_attention:
# self-attention
key_states = self.key_proj(hidden_states)
key_states = self._reshape(key_states, -1, batch_size)
value_states = self.value_proj(hidden_states)
value_states = self._reshape(value_states, -1, batch_size)
elif saved_state is None:
# cross-attention without layer state
key_states = self.key_proj(key_value_states)
key_states = self._reshape(key_states, -1, batch_size)
value_states = self.value_proj(key_value_states)
value_states = self._reshape(value_states, -1, batch_size)
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.key_proj(key_value_states), -1, batch_size)
value_states = self._shape(self.value_proj(key_value_states), -1, batch_size)
else:
key_states = saved_state["prev_key_states"].view(batch_size * self.num_attn_heads, -1, self.head_dim)
value_states = saved_state["prev_value_states"].view(batch_size * self.num_attn_heads, -1, self.head_dim)
# self_attention
key_states = self._shape(self.key_proj(hidden_states), -1, batch_size)
value_states = self._shape(self.value_proj(hidden_states), -1, batch_size)
# Update cache
if is_cross_attention:
layer_state[cache_key] = {
"prev_key_states": key_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
"prev_value_states": value_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
}
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
key_sequence_length = key_states.size(1)
# project states into the correct shape
proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, batch_size).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
assert attn_weights.size() == (
batch_size * self.num_attn_heads,
sequence_length,
key_sequence_length,
), f"`attn_weights` should be of size {batch_size * self.num_attn_heads, sequence_length, key_sequence_length}, but is of size {attn_weights.shape}"
tgt_len,
src_len,
), f"`attn_weights` should be of size {batch_size * self.num_attn_heads, tgt_len, src_len}, but is of size {attn_weights.shape}"
# This is part of a workaround to get around fork/join parallelism not supporting Optional types.
if attention_mask is not None and attention_mask.dim() == 0:
@ -691,19 +694,21 @@ class ProphetNetAttention(nn.Module):
assert attention_mask is None or attention_mask.size() == (
self.num_attn_heads * batch_size,
1,
key_sequence_length,
), f"`attention_mask` should be `None` or of shape attention_mask.size() == {batch_size * self.num_attn_heads, 1, key_sequence_length}, but is {attention_mask.shape}"
src_len,
), f"`attention_mask` should be `None` or of shape attention_mask.size() == {batch_size * self.num_attn_heads, 1, src_len}, but is {attention_mask.shape}"
if attention_mask is not None: # don't attend to padding symbols
attn_weights = attn_weights + attention_mask
# need two reshapes to keep gradient at attention weights
attn_weights_reshaped = attn_weights.view(
batch_size, self.num_attn_heads, sequence_length, key_sequence_length
)
attn_weights = attn_weights_reshaped.view(
batch_size * self.num_attn_heads, sequence_length, key_sequence_length
)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(batch_size, self.num_attn_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(batch_size * self.num_attn_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_weights = F.softmax(attn_weights, dim=-1)
attn_probs = F.dropout(
@ -715,15 +720,20 @@ class ProphetNetAttention(nn.Module):
attn_output = torch.bmm(attn_probs, value_states)
assert attn_output.size() == (
batch_size * self.num_attn_heads,
sequence_length,
tgt_len,
self.head_dim,
), "`attn_output` should be of shape {batch_size * self.num_attn_heads, sequence_length, self.head_dim}, but is of shape {attn_output.size()}"
attn_output = attn_output.transpose(0, 1).contiguous().view(sequence_length, batch_size, hidden_size)
), "`attn_output` should be of shape {batch_size * self.num_attn_heads, tgt_len, self.head_dim}, but is of shape {attn_output.size()}"
attn_output = (
attn_output.view(batch_size, self.num_attn_heads, tgt_len, self.head_dim)
.transpose(1, 2)
.reshape(batch_size, tgt_len, hidden_size)
)
attn_output = self.out_proj(attn_output)
attn_output = F.dropout(attn_output, p=self.dropout, training=self.training)
return attn_output, attn_weights_reshaped
return attn_output, attn_weights_reshaped, past_key_value
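Reviewer note (not part of the diff): a self-contained sketch of the Bart-style `past_key_value` convention the rewritten attention follows, i.e. project the encoder states once, cache the `(key, value)` pair, and reuse it on later decoding steps; the toy projections and shapes are assumptions for illustration.

    import torch
    import torch.nn as nn

    batch, heads, src_len, head_dim = 2, 4, 7, 8
    hidden = heads * head_dim
    key_proj, value_proj = nn.Linear(hidden, hidden), nn.Linear(hidden, hidden)
    encoder_states = torch.randn(batch, src_len, hidden)

    def shape(t):  # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        return t.view(batch, -1, heads, head_dim).transpose(1, 2).contiguous()

    # first decoding step: no cache yet -> project the encoder states once and cache them
    past_key_value = (shape(key_proj(encoder_states)), shape(value_proj(encoder_states)))

    # later decoding steps: cache is present -> reuse it, no re-projection of encoder states
    key_states, value_states = past_key_value
    assert key_states.shape == (batch, heads, src_len, head_dim)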
class ProphetNetFeedForward(nn.Module):
@ -779,8 +789,8 @@ class ProphetNetNgramSelfAttention(nn.Module):
# for onnx runtime
self.onnx_trace = False
def _reshape(self, tensor, first_dim, batch_size):
return tensor.reshape(first_dim, batch_size * self.num_attn_heads, self.head_dim).transpose(0, 1)
def _shape(self, tensor, seq_len, batch_size):
return tensor.view(batch_size, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous()
def prepare_for_onnx_export_(self):
self.onnx_trace = True
@ -788,23 +798,20 @@ class ProphetNetNgramSelfAttention(nn.Module):
def forward(
self,
hidden_states,
layer_state=None,
past_key_value: Optional[Tuple[Tensor]] = None,
attention_mask=None,
extended_predict_attention_mask=None,
main_relative_position_buckets=None,
predict_relative_position_buckets=None,
position_ids=None,
):
sequence_length, batch_size, hidden_size = hidden_states.size()
batch_size, ngram_sequence_length, hidden_size = hidden_states.size()
assert list(hidden_states.size()) == [
sequence_length,
batch_size,
ngram_sequence_length,
hidden_size,
], f"`hidden_states` should be of shape {sequence_length, batch_size, hidden_size}, but is of shape {hidden_states.shape}"
# key and value of previous time steps are cached
saved_state = layer_state.get("self", None)
], f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape {hidden_states.shape}"
# project
query_states = self.query_proj(hidden_states)
@ -815,12 +822,18 @@ class ProphetNetNgramSelfAttention(nn.Module):
query_states = query_states / (self.head_dim ** 0.5)
# reshape
query_states = self._reshape(query_states, sequence_length, batch_size)
key_states = self._reshape(key_states, -1, batch_size)
value_states = self._reshape(value_states, -1, batch_size)
query_states = self._shape(query_states, ngram_sequence_length, batch_size)
key_states = self._shape(key_states, -1, batch_size)
value_states = self._shape(value_states, -1, batch_size)
proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim)
query_states = query_states.view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
# chunk into main stream and predict stream
hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=0)
hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=1)
query_states_list = query_states.chunk(1 + self.ngram, dim=1)
key_states_list = key_states.chunk(1 + self.ngram, dim=1)
@ -832,24 +845,20 @@ class ProphetNetNgramSelfAttention(nn.Module):
main_value_states, predict_value_states_list = value_states_list[0], value_states_list[1:]
# saved states are stored with shape (batch_size, num_attn_heads, seq_len, head_dim)
if saved_state is not None:
prev_main_key_states = saved_state["prev_key_states"].view(
batch_size * self.num_attn_heads, -1, self.head_dim
)
if past_key_value is not None:
prev_main_key_states = past_key_value[0].view(batch_size * self.num_attn_heads, -1, self.head_dim)
main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=1)
prev_main_value_states = saved_state["prev_value_states"].view(
batch_size * self.num_attn_heads, -1, self.head_dim
)
prev_main_value_states = past_key_value[1].view(batch_size * self.num_attn_heads, -1, self.head_dim)
main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=1)
# Update cache
layer_state["self"] = {
"prev_key_states": main_key_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
"prev_value_states": main_value_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
}
past_key_value = (
main_key_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
main_value_states.view(batch_size, self.num_attn_heads, -1, self.head_dim),
)
# get seq_length of main stream only
main_sequence_length = sequence_length // (1 + self.ngram)
sequence_length = ngram_sequence_length // (1 + self.ngram)
# MAIN-STREAM
# main attn weights
@ -871,18 +880,21 @@ class ProphetNetNgramSelfAttention(nn.Module):
).type_as(main_attn_weights)
main_attn_probs = F.dropout(main_attn_probs, p=self.attention_dropout, training=self.training)
# project to attn_output
main_attn_output = torch.bmm(main_attn_probs, main_value_states)
# reshape so that num_heads dim is merged into last `head_dim` axis
main_attn_output = (
main_attn_output.transpose(0, 1).contiguous().view(1, main_sequence_length, batch_size, hidden_size)
main_attn_output.view(batch_size, self.num_attn_heads, sequence_length, self.head_dim)
.transpose(1, 2)
.reshape(batch_size, 1, sequence_length, hidden_size)
)
main_attn_output = self.out_proj(main_attn_output)
# PREDICT-STREAM
# [ngram, B*head, T, c]
predict_query_states = torch.cat(predict_query_states_list, 0).view(
self.ngram, -1, main_sequence_length, self.head_dim
self.ngram, -1, sequence_length, self.head_dim
)
# [ngram, B*head, 2*T, c]
predict_key_states = torch.cat(
@ -891,7 +903,7 @@ class ProphetNetNgramSelfAttention(nn.Module):
# [ngram, T, B, C]
predict_hidden_states = torch.cat(hidden_states_predict_list, 0).view(
self.ngram, main_sequence_length, batch_size, hidden_size
self.ngram, sequence_length, batch_size, hidden_size
)
# [ngram, B*head, 2*T, c]
@ -911,7 +923,9 @@ class ProphetNetNgramSelfAttention(nn.Module):
predict_attn_weights = predict_attn_weights + predict_relative_pos_embeddings
if extended_predict_attention_mask is not None:
predict_attn_weights = predict_attn_weights + extended_predict_attention_mask
predict_attn_weights = predict_attn_weights + extended_predict_attention_mask.to(
predict_attn_weights.dtype
)
predict_attn_probs = softmax(
predict_attn_weights,
@ -919,35 +933,36 @@ class ProphetNetNgramSelfAttention(nn.Module):
onnx_trace=self.onnx_trace,
).type_as(predict_attn_weights)
predict_attn_probs = F.dropout(predict_attn_probs, p=self.attention_dropout, training=self.training)
# project to attention output
# [ngram, B*head, T, c]
predict_attn_output = torch.einsum("nbts,nbsc->nbtc", (predict_attn_probs, predict_value_states))
# [ngram, T, B, C]
# reshape so that num_heads dim is merged into last `head_dim` axis
# [ngram, B, T, C]
predict_attn_output = (
predict_attn_output.transpose(1, 2)
.contiguous()
.view(self.ngram, main_sequence_length, batch_size, hidden_size)
predict_attn_output.view(self.ngram, batch_size, self.num_attn_heads, sequence_length, self.head_dim)
.permute(1, 0, 3, 2, 4)
.reshape(batch_size, self.ngram, sequence_length, hidden_size)
)
predict_attn_output = self.out_proj(predict_attn_output)
# concat to single attn output
# [1+ngram*T, B, C]
attn_output = torch.cat([main_attn_output, predict_attn_output], 0).view(-1, batch_size, hidden_size)
# [B, 1+ngram*T, C]
attn_output = torch.cat([main_attn_output, predict_attn_output], 1).view(batch_size, -1, hidden_size)
# reshape into better form for `config.output_attentions`
main_attn_probs = main_attn_probs.view(batch_size, self.num_attn_heads, main_sequence_length, -1)
main_attn_probs = main_attn_probs.view(batch_size, self.num_attn_heads, sequence_length, -1)
predict_attn_probs = predict_attn_probs.view(
self.ngram, batch_size, self.num_attn_heads, main_sequence_length, -1
self.ngram, batch_size, self.num_attn_heads, sequence_length, -1
).transpose(0, 1)
attn_output = F.dropout(attn_output, p=self.dropout, training=self.training)
return attn_output, main_attn_probs, predict_attn_probs
return attn_output, main_attn_probs, predict_attn_probs, past_key_value
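Reviewer note (not part of the diff): a sketch of the new batch-first stream layout that the `chunk(1 + self.ngram, dim=1)` calls above rely on; sizes are illustrative.

    import torch

    batch_size, seq_len, hidden, ngram = 2, 5, 16, 2
    # the main stream and the `ngram` predict streams are concatenated along the sequence axis
    hidden_states = torch.randn(batch_size, (1 + ngram) * seq_len, hidden)  # [B, (1+ngram)*T, C]

    streams = hidden_states.chunk(1 + ngram, dim=1)
    main_stream, predict_streams = streams[0], streams[1:]
    assert main_stream.shape == (batch_size, seq_len, hidden)
    assert len(predict_streams) == ngram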
def get_main_relative_pos_embeddings(
self, hidden_states, attn_weights, position_ids, main_relative_position_buckets
):
# input hidden_states [T,B,C], input attn_weights [T*head,T,S], input position_ids [B,T] or [1,1]
# input hidden_states [B,T,C], input attn_weights [T*head,T,S], input position_ids [B,T] or [1,1]
if main_relative_position_buckets is None:
batch_size, sequence_length = hidden_states.shape[:2]
@ -965,7 +980,6 @@ class ProphetNetNgramSelfAttention(nn.Module):
self.num_buckets, self.relative_max_distance, relative_positions, False
)
hidden_states = hidden_states.transpose(0, 1) # [B,T,C]
rel_pos_embeddings = self.relative_pos_embeddings(hidden_states) # [B,T,Buckets*head]
rel_pos_embeddings = rel_pos_embeddings.view(
rel_pos_embeddings.shape[:2] + (self.num_buckets, self.num_attn_heads)
@ -991,7 +1005,6 @@ class ProphetNetNgramSelfAttention(nn.Module):
self, hidden_states, attn_weights, position_ids, predict_relative_position_buckets
):
# input hidden_states [ngram, T,B,C], input attn_weights [ngram, B*head,T,S], input position_ids [B,T] or [1,1], input predict_relative_position_buckets [B,T, 2*T] or None
sequence_length, batch_size = hidden_states.shape[1:3]
if predict_relative_position_buckets is None:
@ -1053,18 +1066,25 @@ class ProphetNetEncoderLayer(nn.Module):
self.feed_forward = ProphetNetFeedForward(config, config.encoder_ffn_dim)
self.feed_forward_layer_norm = LayerNorm(config.hidden_size)
def forward(self, hidden_states, attention_mask):
def forward(self, hidden_states, attention_mask, output_attentions: bool = False):
# 1st residual block
attention_output, attn_weights = self.self_attn(
attention_output, attn_weights, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
hidden_states = self.self_attn_layer_norm(attention_output + hidden_states)
# 2nd residual block
feed_forward_output = self.feed_forward(hidden_states)
hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states)
return hidden_states, attn_weights
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class ProphetNetDecoderLayer(nn.Module):
@ -1090,21 +1110,23 @@ class ProphetNetDecoderLayer(nn.Module):
def forward(
self,
hidden_states,
attention_mask=None,
encoder_hidden_states=None,
encoder_attn_mask=None,
layer_state=None,
attention_mask=None,
extended_predict_attention_mask=None,
main_relative_position_buckets=None,
predict_relative_position_buckets=None,
position_ids=None,
past_key_value=None,
use_cache: bool = True,
output_attentions: bool = False,
):
layer_state = layer_state if layer_state is not None else {}
# 1st residual block
ngram_attention_output, self_attn_weights, self_attn_weights_ngram = self.self_attn(
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
ngram_attention_output, self_attn_weights, self_attn_weights_ngram, present_key_value = self.self_attn(
hidden_states=hidden_states,
layer_state=layer_state,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
extended_predict_attention_mask=extended_predict_attention_mask,
main_relative_position_buckets=main_relative_position_buckets,
@ -1113,28 +1135,36 @@ class ProphetNetDecoderLayer(nn.Module):
)
hidden_states = self.self_attn_layer_norm(hidden_states + ngram_attention_output)
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attn_weights = None
if encoder_hidden_states is not None:
# 2nd residual block
attention_output, cross_attn_weights = self.cross_attn(
attention_output, cross_attn_weights, cross_attn_present_key_value = self.cross_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attn_mask,
layer_state=layer_state, # mutates layer state
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
)
hidden_states = self.cross_attn_layer_norm(attention_output + hidden_states)
# add cross-attn to positions 3,4 of present_key_value tuple
present_key_value = present_key_value + cross_attn_present_key_value
# 3rd residual block
feed_forward_output = self.feed_forward(hidden_states)
hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states)
return (
hidden_states,
self_attn_weights,
self_attn_weights_ngram,
cross_attn_weights,
layer_state,
) # just self_attn weights for now, following t5, layer_state = cache for decoding
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, self_attn_weights_ngram, cross_attn_weights)
if use_cache:
outputs += (present_key_value,)
return outputs
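Reviewer note (not part of the diff): the flat per-layer cache produced by the new decoder layer, replacing the old nested `{"self": ..., "cross_attention": ...}` dicts; the slicing mirrors the `[:2]` / `[-2:]` indexing in the forward above, and the tensor shapes are illustrative.

    import torch

    batch, heads, past_len, src_len, head_dim = 2, 4, 3, 7, 8
    present_key_value = (
        torch.randn(batch, heads, past_len, head_dim),  # positions 1,2: uni-directional self-attn key/value
        torch.randn(batch, heads, past_len, head_dim),
        torch.randn(batch, heads, src_len, head_dim),   # positions 3,4: cross-attn key/value
        torch.randn(batch, heads, src_len, head_dim),
    )
    self_attn_past_key_value = present_key_value[:2]    # consumed by ProphetNetNgramSelfAttention
    cross_attn_past_key_value = present_key_value[-2:]  # consumed by the cross ProphetNetAttention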
@add_start_docstrings(
@ -1223,21 +1253,37 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
hidden_states = inputs_embeds + position_embeddings
hidden_states = self.embeddings_layer_norm(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
hidden_states = hidden_states.transpose(0, 1) # B x T x C -> T x B x C
encoder_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for encoder_layer in self.layers:
if output_hidden_states:
hidden_states = hidden_states.transpose(0, 1)
encoder_hidden_states = encoder_hidden_states + (hidden_states,)
hidden_states = hidden_states.transpose(0, 1)
hidden_states, attn_probs = encoder_layer(hidden_states, attention_mask=extended_attention_mask)
if output_attentions:
all_attentions = all_attentions + (attn_probs,)
hidden_states = hidden_states.transpose(0, 1)
if getattr(self.config, "gradient_checkpointing", False) and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states,
extended_attention_mask,
)
else:
layer_outputs = encoder_layer(
hidden_states, attention_mask=extended_attention_mask, output_attentions=output_attentions
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_hidden_states = encoder_hidden_states + (hidden_states,)
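Reviewer note (not part of the diff): a standalone sketch of the `create_custom_forward` pattern used above. `torch.utils.checkpoint.checkpoint` only passes positional tensors through, so non-tensor flags such as `output_attentions` are closed over in a small wrapper; the module below is a stand-in for an encoder layer.

    import torch
    import torch.utils.checkpoint

    layer = torch.nn.Linear(16, 16)  # stand-in for an encoder layer

    def create_custom_forward(module):
        def custom_forward(*inputs):
            return module(*inputs)  # the real code appends the closed-over flags here
        return custom_forward

    hidden_states = torch.randn(2, 5, 16, requires_grad=True)
    output = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), hidden_states)
    output.sum().backward()  # the layer's forward is recomputed during this backward pass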
@ -1370,26 +1416,24 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
# add position embeddings
hidden_states = inputs_embeds + main_stream_pos_embed
hidden_states = hidden_states.transpose(0, 1)
ngram_embeddings = self.ngram_embeddings.weight
# prepare attention mask
if past_key_values is not None:
assert (
hidden_states.size(0) == 1
hidden_states.size(1) == 1
), "At the moment `use_cache` is only supported for `decoder_input_ids` of length 1"
ngram_hidden_states = [
(ngram_embeddings[ngram - 1] + predicting_stream_pos_embed).transpose(0, 1).repeat(1, batch_size, 1)
(ngram_embeddings[ngram - 1] + predicting_stream_pos_embed).repeat(batch_size, 1, 1)
for ngram in range(self.ngram)
]
extended_attention_mask = None
extended_predict_attention_mask = None
else:
ngram_hidden_states = [
(ngram_embeddings[ngram - 1] + predicting_stream_pos_embed).transpose(0, 1)
for ngram in range(self.ngram)
(ngram_embeddings[ngram - 1] + predicting_stream_pos_embed) for ngram in range(self.ngram)
]
extended_attention_mask = self.prepare_attention_mask(hidden_states, attention_mask)
extended_predict_attention_mask = self.prepare_predict_attention_mask(hidden_states, attention_mask)
@ -1403,16 +1447,13 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
else:
extended_encoder_attention_mask = None
hidden_states = torch.cat([hidden_states] + ngram_hidden_states, 0)
hidden_states = torch.cat([hidden_states] + ngram_hidden_states, 1)
if self.embeddings_layer_norm:
hidden_states = self.embeddings_layer_norm(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
if encoder_hidden_states is not None:
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
# init attentions, hidden_states and cache with empty tuples
all_main_stream_hidden_states = () if output_hidden_states else None
all_ngram_stream_hidden_states = () if output_hidden_states and self.config.ngram > 0 else None
@ -1425,47 +1466,75 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
# grad cannot be kept because tensor is sliced
all_main_stream_hidden_states += (hidden_states[:sequence_length].transpose(0, 1),)
all_main_stream_hidden_states += (hidden_states[:, :sequence_length],)
if self.config.ngram > 0:
all_ngram_stream_hidden_states += (hidden_states[sequence_length:].transpose(0, 1),)
all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
logger.warn(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
extended_attention_mask,
encoder_hidden_states,
extended_encoder_attention_mask,
extended_predict_attention_mask,
main_relative_position_buckets,
predict_relative_position_buckets,
position_ids,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=extended_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attn_mask=extended_encoder_attention_mask,
extended_predict_attention_mask=extended_predict_attention_mask,
main_relative_position_buckets=main_relative_position_buckets,
predict_relative_position_buckets=predict_relative_position_buckets,
position_ids=position_ids,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
layer_state = past_key_values[idx] if past_key_values is not None else None
(
hidden_states,
layer_self_attn,
layer_self_predict_attn_output,
layer_cross_attn,
layer_past,
) = decoder_layer(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_attn_mask=extended_encoder_attention_mask,
layer_state=layer_state,
attention_mask=extended_attention_mask,
extended_predict_attention_mask=extended_predict_attention_mask,
main_relative_position_buckets=main_relative_position_buckets,
predict_relative_position_buckets=predict_relative_position_buckets,
position_ids=position_ids,
)
if use_cache:
present_key_values += (layer_past,)
present_key_values += (layer_outputs[4 if output_attentions else 1],)
if output_attentions:
all_main_stream_attns += (layer_self_attn,)
all_ngram_stream_attns += (layer_self_predict_attn_output,)
all_main_stream_attns += (layer_outputs[1],)
all_ngram_stream_attns += (layer_outputs[2],)
if self.config.add_cross_attention:
all_cross_attns += (layer_cross_attn,)
all_cross_attns += (layer_outputs[3],)
if output_hidden_states:
all_main_stream_hidden_states += (hidden_states[:sequence_length].transpose(0, 1),)
all_main_stream_hidden_states += (hidden_states[:, :sequence_length],)
if self.config.ngram > 0:
all_ngram_stream_hidden_states += (hidden_states[sequence_length:].transpose(0, 1),)
all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],)
# split last_hidden_state for return
last_hidden_state = hidden_states[:sequence_length].transpose(0, 1)
last_hidden_state_ngram = hidden_states[sequence_length:].transpose(0, 1) if self.config.ngram > 0 else None
encoder_hidden_states = encoder_hidden_states.transpose(0, 1) if encoder_hidden_states is not None else None
last_hidden_state = hidden_states[:, :sequence_length]
last_hidden_state_ngram = hidden_states[:, sequence_length:] if self.config.ngram > 0 else None
if not return_dict:
return tuple(
@ -1516,7 +1585,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
return main_relative_buckets, predict_relative_buckets
def prepare_attention_mask(self, hidden_states, attention_mask):
seq_length, batch_size = hidden_states.shape[:2]
batch_size, seq_length = hidden_states.shape[:2]
# get causal mask
causal_mask = hidden_states.new(seq_length, seq_length).float().fill_(-float("inf"))
@ -1534,7 +1603,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
return extended_attention_mask.repeat(self.config.num_decoder_attention_heads, 1, 1).to(hidden_states.dtype)
def prepare_predict_attention_mask(self, hidden_states, attention_mask):
seq_length, batch_size = hidden_states.shape[:2]
batch_size, seq_length = hidden_states.shape[:2]
# get causal mask
predict_causal_mask = ngram_attention_bias(
@ -1656,7 +1725,7 @@ class ProphetNetModel(ProphetNetPreTrainedModel):
return_dict=return_dict,
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
# decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
@ -1856,21 +1925,14 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
return self._shift_right(labels)
@staticmethod
# Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration._reorder_cache
def _reorder_cache(past, beam_idx):
# this function reorders the cache for beam search
def _reorder_cache(cache_dict, beam_idx):
for k, key_value_states in cache_dict.items():
if key_value_states is not None:
cache_dict[k] = key_value_states.index_select(0, beam_idx)
return cache_dict
reordered_past = []
reordered_past = ()
for layer_past in past:
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
layer_past_new = {
attn_key: _reorder_cache(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
}
reordered_past.append(layer_past_new)
# cached cross_attention states don't have to be reordered -> they are always the same
reordered_past += (
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
)
return reordered_past
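Reviewer note (not part of the diff): what the new `_reorder_cache` does for a single layer during beam search; only the self-attention states are re-indexed with the selected beam indices, while the cached cross-attention states are identical across beams and carried over unchanged. Shapes are illustrative.

    import torch

    num_beams, heads, head_dim = 3, 4, 8
    layer_past = (
        torch.randn(num_beams, heads, 2, head_dim),  # self-attn key
        torch.randn(num_beams, heads, 2, head_dim),  # self-attn value
        torch.randn(num_beams, heads, 7, head_dim),  # cross-attn key (same for every beam)
        torch.randn(num_beams, heads, 7, head_dim),  # cross-attn value (same for every beam)
    )
    beam_idx = torch.tensor([0, 0, 2])  # beams 0 and 1 both continue from old beam 0

    reordered_layer_past = (
        tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:]
    )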
def get_encoder(self):
@ -1995,7 +2057,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
# decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
outputs = self.prophetnet.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
@ -2080,21 +2142,11 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
}
@staticmethod
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM._reorder_cache
def _reorder_cache(past, beam_idx):
# this function reorders the cache for beam search
def _reorder_cache(cache_dict, beam_idx):
for k, key_value_states in cache_dict.items():
if key_value_states is not None:
cache_dict[k] = key_value_states.index_select(0, beam_idx)
return cache_dict
reordered_past = []
reordered_past = ()
for layer_past in past:
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
layer_past_new = {
attn_key: _reorder_cache(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
}
reordered_past.append(layer_past_new)
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past

View File

@ -243,7 +243,7 @@ class ProphetNetModelTester:
# There should be `num_layers` key value embeddings stored in decoder_past
self.parent.assertEqual(len(decoder_past), config.num_decoder_layers)
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple
self.parent.assertEqual(len(decoder_past[0]), 2) # cross-attention + uni-directional self-attention
self.parent.assertEqual(len(decoder_past[0]), 4) # cross-attention + uni-directional self-attention
def create_and_check_with_lm_head(
self,