# MIT License

# Copyright (c) 2019 Yang Liu and the HuggingFace team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import copy
import math

import numpy as np
import torch
from configuration_bertabs import BertAbsConfig
from torch import nn
from torch.nn.init import xavier_uniform_

from transformers import BertConfig, BertModel, PreTrainedModel


MAX_SIZE = 5000

class BertAbsPreTrainedModel(PreTrainedModel):
    config_class = BertAbsConfig
    load_tf_weights = False
    base_model_prefix = "bert"


class BertAbs(BertAbsPreTrainedModel):
    def __init__(self, args, checkpoint=None, bert_extractive_checkpoint=None):
        super().__init__(args)
        self.args = args
        self.bert = Bert()

        # If pre-trained weights are passed for Bert, load these.
        load_bert_pretrained_extractive = True if bert_extractive_checkpoint else False
        if load_bert_pretrained_extractive:
            self.bert.model.load_state_dict(
                {n[11:]: p for n, p in bert_extractive_checkpoint.items() if n.startswith("bert.model")},
                strict=True,
            )

        self.vocab_size = self.bert.model.config.vocab_size

        if args.max_pos > 512:
            my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][
                None, :
            ].repeat(args.max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings
        tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)

        tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight)

        self.decoder = TransformerDecoder(
            self.args.dec_layers,
            self.args.dec_hidden_size,
            heads=self.args.dec_heads,
            d_ff=self.args.dec_ff_size,
            dropout=self.args.dec_dropout,
            embeddings=tgt_embeddings,
            vocab_size=self.vocab_size,
        )

        gen_func = nn.LogSoftmax(dim=-1)
        self.generator = nn.Sequential(nn.Linear(args.dec_hidden_size, args.vocab_size), gen_func)
        self.generator[0].weight = self.decoder.embeddings.weight

        load_from_checkpoints = False if checkpoint is None else True
        if load_from_checkpoints:
            self.load_state_dict(checkpoint)

    def init_weights(self):
        for module in self.decoder.modules():
            if isinstance(module, (nn.Linear, nn.Embedding)):
                module.weight.data.normal_(mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        for p in self.generator.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
            else:
                p.data.zero_()

    def forward(
        self,
        encoder_input_ids,
        decoder_input_ids,
        token_type_ids,
        encoder_attention_mask,
        decoder_attention_mask,
    ):
        encoder_output = self.bert(
            input_ids=encoder_input_ids,
            token_type_ids=token_type_ids,
            attention_mask=encoder_attention_mask,
        )
        encoder_hidden_states = encoder_output[0]
        dec_state = self.decoder.init_decoder_state(encoder_input_ids, encoder_hidden_states)
        decoder_outputs, _ = self.decoder(decoder_input_ids[:, :-1], encoder_hidden_states, dec_state)
        return decoder_outputs

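# Usage sketch (illustrative assumptions, not part of the original file): `args`
# is expected to be a BertAbsConfig carrying the dec_* hyper-parameters read in
# __init__ above; the names and defaults below are assumptions, and the real
# entry point is the accompanying run_summarization.py script, e.g.
#
#     config = BertAbsConfig()        # defaults from configuration_bertabs.py
#     model = BertAbs(config)         # Bert() fetches the bert-base-uncased config
#     decoder_states = model(encoder_input_ids, decoder_input_ids, token_type_ids,
#                            encoder_attention_mask, decoder_attention_mask)
#
# where every input tensor is [batch_size, seq_len] and the output holds the
# decoder hidden states that `model.generator` maps to vocabulary log-probabilities.
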
class Bert(nn.Module):
    """This class is not really necessary and should probably disappear."""

    def __init__(self):
        super().__init__()
        config = BertConfig.from_pretrained("google-bert/bert-base-uncased")
        self.model = BertModel(config)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kwargs):
        self.eval()
        with torch.no_grad():
            encoder_outputs, _ = self.model(
                input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, **kwargs
            )
        return encoder_outputs

class TransformerDecoder(nn.Module):
    """
    The Transformer decoder from "Attention is All You Need".

    Args:
        num_layers (int): number of decoder layers.
        d_model (int): size of the model
        heads (int): number of heads
        d_ff (int): size of the inner FF layer
        dropout (float): dropout parameters
        embeddings (:obj:`onmt.modules.Embeddings`):
            embeddings to use, should have positional encodings
        attn_type (str): if using a separate copy attention
    """

    def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings, vocab_size):
        super().__init__()

        # Basic attributes.
        self.decoder_type = "transformer"
        self.num_layers = num_layers
        self.embeddings = embeddings
        self.pos_emb = PositionalEncoding(dropout, self.embeddings.embedding_dim)

        # Build TransformerDecoder.
        self.transformer_layers = nn.ModuleList(
            [TransformerDecoderLayer(d_model, heads, d_ff, dropout) for _ in range(num_layers)]
        )

        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    # forward(input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask)
    # def forward(self, input_ids, state, attention_mask=None, memory_lengths=None,
    # step=None, cache=None, encoder_attention_mask=None, encoder_hidden_states=None, memory_masks=None):
    def forward(
        self,
        input_ids,
        encoder_hidden_states=None,
        state=None,
        attention_mask=None,
        memory_lengths=None,
        step=None,
        cache=None,
        encoder_attention_mask=None,
    ):
        """
        See :obj:`onmt.modules.RNNDecoderBase.forward()`
        memory_bank = encoder_hidden_states
        """
        # Name conversion
        tgt = input_ids
        memory_bank = encoder_hidden_states
        memory_mask = encoder_attention_mask

        # src_words = state.src
        src_words = state.src
        src_batch, src_len = src_words.size()

        padding_idx = self.embeddings.padding_idx

        # Decoder padding mask
        tgt_words = tgt
        tgt_batch, tgt_len = tgt_words.size()
        tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1).expand(tgt_batch, tgt_len, tgt_len)

        # Encoder padding mask
        if memory_mask is not None:
            src_len = memory_mask.size(-1)
            src_pad_mask = memory_mask.expand(src_batch, tgt_len, src_len)
        else:
            src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1).expand(src_batch, tgt_len, src_len)

        # Pass through the embeddings
        emb = self.embeddings(input_ids)
        output = self.pos_emb(emb, step)
        assert emb.dim() == 3  # batch x len x embedding_dim

        if state.cache is None:
            saved_inputs = []

        for i in range(self.num_layers):
            prev_layer_input = None
            if state.cache is None:
                if state.previous_input is not None:
                    prev_layer_input = state.previous_layer_inputs[i]

            output, all_input = self.transformer_layers[i](
                output,
                memory_bank,
                src_pad_mask,
                tgt_pad_mask,
                previous_input=prev_layer_input,
                layer_cache=state.cache["layer_{}".format(i)] if state.cache is not None else None,
                step=step,
            )
            if state.cache is None:
                saved_inputs.append(all_input)

        if state.cache is None:
            saved_inputs = torch.stack(saved_inputs)

        output = self.layer_norm(output)

        if state.cache is None:
            state = state.update_state(tgt, saved_inputs)

        # Decoders in transformers return a tuple. Beam search will fail
        # if we don't follow this convention.
        return output, state

    def init_decoder_state(self, src, memory_bank, with_cache=False):
        """Init decoder state"""
        state = TransformerDecoderState(src)
        if with_cache:
            state._init_cache(memory_bank, self.num_layers)
        return state

class PositionalEncoding(nn.Module):
    def __init__(self, dropout, dim, max_len=5000):
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim)))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        pe = pe.unsqueeze(0)
        super().__init__()
        self.register_buffer("pe", pe)
        self.dropout = nn.Dropout(p=dropout)
        self.dim = dim

    def forward(self, emb, step=None):
        emb = emb * math.sqrt(self.dim)
        if step:
            emb = emb + self.pe[:, step][:, None, :]
        else:
            emb = emb + self.pe[:, : emb.size(1)]
        emb = self.dropout(emb)
        return emb

    def get_emb(self, emb):
        return self.pe[:, : emb.size(1)]

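# Minimal usage sketch (illustrative, not part of the original module): the `pe`
# buffer has shape [1, max_len, dim]; during training the full prefix is added
# to the scaled embeddings, while step-wise decoding only adds the encoding for
# the current `step`.
def _example_positional_encoding():
    pos_enc = PositionalEncoding(dropout=0.0, dim=8, max_len=16)
    emb = torch.zeros(2, 5, 8)  # [batch, seq_len, dim]
    out = pos_enc(emb)  # adds pe[:, :5] to the scaled embeddings
    assert out.shape == (2, 5, 8)
    return out
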
class TransformerDecoderLayer(nn.Module):
    """
    Args:
        d_model (int): the dimension of keys/values/queries in
            MultiHeadedAttention, also the input size of
            the first-layer of the PositionwiseFeedForward.
        heads (int): the number of heads for MultiHeadedAttention.
        d_ff (int): the second-layer of the PositionwiseFeedForward.
        dropout (float): dropout probability (0-1.0).
        self_attn_type (string): type of self-attention scaled-dot, average
    """

    def __init__(self, d_model, heads, d_ff, dropout):
        super().__init__()

        self.self_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)

        self.context_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
        self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
        self.drop = nn.Dropout(dropout)
        mask = self._get_attn_subsequent_mask(MAX_SIZE)
        # Register self.mask as a saved_state in TransformerDecoderLayer, so
        # it gets TransformerDecoderLayer's cuda behavior automatically.
        self.register_buffer("mask", mask)

    def forward(
        self,
        inputs,
        memory_bank,
        src_pad_mask,
        tgt_pad_mask,
        previous_input=None,
        layer_cache=None,
        step=None,
    ):
        """
        Args:
            inputs (`FloatTensor`): `[batch_size x 1 x model_dim]`
            memory_bank (`FloatTensor`): `[batch_size x src_len x model_dim]`
            src_pad_mask (`LongTensor`): `[batch_size x 1 x src_len]`
            tgt_pad_mask (`LongTensor`): `[batch_size x 1 x 1]`

        Returns:
            (`FloatTensor`, `FloatTensor`, `FloatTensor`):

            * output `[batch_size x 1 x model_dim]`
            * attn `[batch_size x 1 x src_len]`
            * all_input `[batch_size x current_step x model_dim]`
        """
        dec_mask = torch.gt(tgt_pad_mask + self.mask[:, : tgt_pad_mask.size(1), : tgt_pad_mask.size(1)], 0)
        input_norm = self.layer_norm_1(inputs)
        all_input = input_norm
        if previous_input is not None:
            all_input = torch.cat((previous_input, input_norm), dim=1)
            dec_mask = None

        query = self.self_attn(
            all_input,
            all_input,
            input_norm,
            mask=dec_mask,
            layer_cache=layer_cache,
            type="self",
        )

        query = self.drop(query) + inputs

        query_norm = self.layer_norm_2(query)
        mid = self.context_attn(
            memory_bank,
            memory_bank,
            query_norm,
            mask=src_pad_mask,
            layer_cache=layer_cache,
            type="context",
        )
        output = self.feed_forward(self.drop(mid) + query)

        return output, all_input
        # return output

    def _get_attn_subsequent_mask(self, size):
        """
        Get an attention mask to avoid using the subsequent info.

        Args:
            size: int

        Returns:
            (`LongTensor`):

            * subsequent_mask `[1 x size x size]`
        """
        attn_shape = (1, size, size)
        subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype("uint8")
        subsequent_mask = torch.from_numpy(subsequent_mask)
        return subsequent_mask

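# Small illustration (not part of the original code): the subsequent mask is an
# upper-triangular matrix that is zero on and below the diagonal, so position i
# may only attend to positions <= i once combined with the padding mask.
def _example_subsequent_mask():
    layer = TransformerDecoderLayer(d_model=8, heads=2, d_ff=16, dropout=0.0)
    mask = layer._get_attn_subsequent_mask(4)[0]
    # tensor([[0, 1, 1, 1],
    #         [0, 0, 1, 1],
    #         [0, 0, 0, 1],
    #         [0, 0, 0, 0]], dtype=torch.uint8)
    return mask
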
class MultiHeadedAttention(nn.Module):
    """
    Multi-Head Attention module from
    "Attention is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`.

    Similar to standard `dot` attention but uses
    multiple attention distributions simultaneously
    to select relevant items.

    .. mermaid::

        graph BT
            A[key]
            B[value]
            C[query]
            O[output]
            subgraph Attn
                D[Attn 1]
                E[Attn 2]
                F[Attn N]
            end
            A --> D
            C --> D
            A --> E
            C --> E
            A --> F
            C --> F
            D --> O
            E --> O
            F --> O
            B --> O

    Also includes several additional tricks.

    Args:
        head_count (int): number of parallel heads
        model_dim (int): the dimension of keys/values/queries,
            must be divisible by head_count
        dropout (float): dropout parameter
    """

    def __init__(self, head_count, model_dim, dropout=0.1, use_final_linear=True):
        assert model_dim % head_count == 0
        self.dim_per_head = model_dim // head_count
        self.model_dim = model_dim

        super().__init__()
        self.head_count = head_count

        self.linear_keys = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.linear_values = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.linear_query = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.use_final_linear = use_final_linear
        if self.use_final_linear:
            self.final_linear = nn.Linear(model_dim, model_dim)

    def forward(
        self,
        key,
        value,
        query,
        mask=None,
        layer_cache=None,
        type=None,
        predefined_graph_1=None,
    ):
        """
        Compute the context vector and the attention vectors.

        Args:
            key (`FloatTensor`): set of `key_len`
                key vectors `[batch, key_len, dim]`
            value (`FloatTensor`): set of `key_len`
                value vectors `[batch, key_len, dim]`
            query (`FloatTensor`): set of `query_len`
                query vectors `[batch, query_len, dim]`
            mask: binary mask indicating which keys have
                non-zero attention `[batch, query_len, key_len]`
        Returns:
            (`FloatTensor`, `FloatTensor`) :

            * output context vectors `[batch, query_len, dim]`
            * one of the attention vectors `[batch, query_len, key_len]`
        """
        batch_size = key.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count

        def shape(x):
            """projection"""
            return x.view(batch_size, -1, head_count, dim_per_head).transpose(1, 2)

        def unshape(x):
            """compute context"""
            return x.transpose(1, 2).contiguous().view(batch_size, -1, head_count * dim_per_head)

        # 1) Project key, value, and query.
        if layer_cache is not None:
            if type == "self":
                query, key, value = (
                    self.linear_query(query),
                    self.linear_keys(query),
                    self.linear_values(query),
                )

                key = shape(key)
                value = shape(value)

                if layer_cache is not None:
                    device = key.device
                    if layer_cache["self_keys"] is not None:
                        key = torch.cat((layer_cache["self_keys"].to(device), key), dim=2)
                    if layer_cache["self_values"] is not None:
                        value = torch.cat((layer_cache["self_values"].to(device), value), dim=2)
                    layer_cache["self_keys"] = key
                    layer_cache["self_values"] = value
            elif type == "context":
                query = self.linear_query(query)
                if layer_cache is not None:
                    if layer_cache["memory_keys"] is None:
                        key, value = self.linear_keys(key), self.linear_values(value)
                        key = shape(key)
                        value = shape(value)
                    else:
                        key, value = (
                            layer_cache["memory_keys"],
                            layer_cache["memory_values"],
                        )
                    layer_cache["memory_keys"] = key
                    layer_cache["memory_values"] = value
                else:
                    key, value = self.linear_keys(key), self.linear_values(value)
                    key = shape(key)
                    value = shape(value)
        else:
            key = self.linear_keys(key)
            value = self.linear_values(value)
            query = self.linear_query(query)
            key = shape(key)
            value = shape(value)

        query = shape(query)

        # 2) Calculate and scale scores.
        query = query / math.sqrt(dim_per_head)
        scores = torch.matmul(query, key.transpose(2, 3))

        if mask is not None:
            mask = mask.unsqueeze(1).expand_as(scores)
            scores = scores.masked_fill(mask, -1e18)

        # 3) Apply attention dropout and compute context vectors.

        attn = self.softmax(scores)

        if predefined_graph_1 is not None:
            attn_masked = attn[:, -1] * predefined_graph_1
            attn_masked = attn_masked / (torch.sum(attn_masked, 2).unsqueeze(2) + 1e-9)

            attn = torch.cat([attn[:, :-1], attn_masked.unsqueeze(1)], 1)

        drop_attn = self.dropout(attn)
        if self.use_final_linear:
            context = unshape(torch.matmul(drop_attn, value))
            output = self.final_linear(context)
            return output
        else:
            context = torch.matmul(drop_attn, value)
            return context

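# Minimal usage sketch (illustrative only): without a layer_cache the module
# projects key/value/query, splits them into heads, applies scaled dot-product
# attention, and returns a tensor with the same [batch, query_len, model_dim] shape.
def _example_multi_headed_attention():
    attn = MultiHeadedAttention(head_count=2, model_dim=8, dropout=0.0)
    key = value = query = torch.randn(2, 5, 8)
    output = attn(key, value, query)  # self-attention over a length-5 sequence
    assert output.shape == (2, 5, 8)
    return output
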
class DecoderState(object):
    """Interface for grouping together the current state of a recurrent
    decoder. In the simplest case just represents the hidden state of
    the model. But can also be used for implementing various forms of
    input_feeding and non-recurrent models.

    Modules need to implement this to utilize beam search decoding.
    """

    def detach(self):
        """Need to document this"""
        self.hidden = tuple([_.detach() for _ in self.hidden])
        self.input_feed = self.input_feed.detach()

    def beam_update(self, idx, positions, beam_size):
        """Need to document this"""
        for e in self._all:
            sizes = e.size()
            br = sizes[1]
            if len(sizes) == 3:
                sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2])[:, :, idx]
            else:
                sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2], sizes[3])[:, :, idx]

            sent_states.data.copy_(sent_states.data.index_select(1, positions))

    def map_batch_fn(self, fn):
        raise NotImplementedError()

class TransformerDecoderState(DecoderState):
    """Transformer Decoder state base class"""

    def __init__(self, src):
        """
        Args:
            src (FloatTensor): a sequence of source words tensors
                with optional feature tensors, of size (len x batch).
        """
        self.src = src
        self.previous_input = None
        self.previous_layer_inputs = None
        self.cache = None

    @property
    def _all(self):
        """
        Contains attributes that need to be updated in self.beam_update().
        """
        if self.previous_input is not None and self.previous_layer_inputs is not None:
            return (self.previous_input, self.previous_layer_inputs, self.src)
        else:
            return (self.src,)

    def detach(self):
        if self.previous_input is not None:
            self.previous_input = self.previous_input.detach()
        if self.previous_layer_inputs is not None:
            self.previous_layer_inputs = self.previous_layer_inputs.detach()
        self.src = self.src.detach()

    def update_state(self, new_input, previous_layer_inputs):
        state = TransformerDecoderState(self.src)
        state.previous_input = new_input
        state.previous_layer_inputs = previous_layer_inputs
        return state

    def _init_cache(self, memory_bank, num_layers):
        self.cache = {}

        for l in range(num_layers):
            layer_cache = {"memory_keys": None, "memory_values": None}
            layer_cache["self_keys"] = None
            layer_cache["self_values"] = None
            self.cache["layer_{}".format(l)] = layer_cache

    def repeat_beam_size_times(self, beam_size):
        """Repeat beam_size times along batch dimension."""
        self.src = self.src.data.repeat(1, beam_size, 1)

    def map_batch_fn(self, fn):
        def _recursive_map(struct, batch_dim=0):
            for k, v in struct.items():
                if v is not None:
                    if isinstance(v, dict):
                        _recursive_map(v)
                    else:
                        struct[k] = fn(v, batch_dim)

        self.src = fn(self.src, 0)
        if self.cache is not None:
            _recursive_map(self.cache)

def gelu(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

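# Note (illustrative, not part of the original file): this is the tanh
# approximation of GELU used by BERT/GPT-style models; a quick sanity check on
# a small tensor gives values close to the exact erf-based activation.
def _example_gelu():
    x = torch.tensor([-1.0, 0.0, 1.0])
    return gelu(x)  # approximately tensor([-0.1588, 0.0000, 0.8412])
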
class PositionwiseFeedForward(nn.Module):
    """A two-layer Feed-Forward-Network with residual layer norm.

    Args:
        d_model (int): the size of input for the first-layer of the FFN.
        d_ff (int): the hidden layer size of the second-layer
            of the FFN.
        dropout (float): dropout probability in :math:`[0, 1)`.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.actv = gelu
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        inter = self.dropout_1(self.actv(self.w_1(self.layer_norm(x))))
        output = self.dropout_2(self.w_2(inter))
        return output + x


#
# TRANSLATOR
# The following code is used to generate summaries using the
# pre-trained weights and beam search.
#

def build_predictor(args, tokenizer, symbols, model, logger=None):
    # we should be able to refactor the global scorer a lot
    scorer = GNMTGlobalScorer(args.alpha, length_penalty="wu")
    translator = Translator(args, model, tokenizer, symbols, global_scorer=scorer, logger=logger)
    return translator

class GNMTGlobalScorer(object):
    """
    NMT re-ranking score from
    "Google's Neural Machine Translation System" :cite:`wu2016google`

    Args:
        alpha (float): length parameter
        beta (float): coverage parameter
    """

    def __init__(self, alpha, length_penalty):
        self.alpha = alpha
        penalty_builder = PenaltyBuilder(length_penalty)
        self.length_penalty = penalty_builder.length_penalty()

    def score(self, beam, logprobs):
        """
        Rescores a prediction based on penalty functions
        """
        normalized_probs = self.length_penalty(beam, logprobs, self.alpha)
        return normalized_probs

class PenaltyBuilder(object):
    """
    Returns the Length and Coverage Penalty function for Beam Search.

    Args:
        length_pen (str): option name of length pen
        cov_pen (str): option name of cov pen
    """

    def __init__(self, length_pen):
        self.length_pen = length_pen

    def length_penalty(self):
        if self.length_pen == "wu":
            return self.length_wu
        elif self.length_pen == "avg":
            return self.length_average
        else:
            return self.length_none

    # Below are all the different penalty terms implemented so far.

    def length_wu(self, beam, logprobs, alpha=0.0):
        """
        NMT length re-ranking score from
        "Google's Neural Machine Translation System" :cite:`wu2016google`.
        """
        modifier = ((5 + len(beam.next_ys)) ** alpha) / ((5 + 1) ** alpha)
        return logprobs / modifier

    def length_average(self, beam, logprobs, alpha=0.0):
        """
        Returns the average probability of tokens in a sequence.
        """
        return logprobs / len(beam.next_ys)

    def length_none(self, beam, logprobs, alpha=0.0, beta=0.0):
        """
        Returns unmodified scores.
        """
        return logprobs

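# Worked example (illustrative): with alpha = 0.9 the Wu et al. length penalty
# divides the summed log-probabilities by ((5 + length) / 6) ** alpha, so a
# 10-token hypothesis is divided by (15 / 6) ** 0.9 ~= 2.28, which penalizes
# length far less aggressively than dividing by the raw token count would.
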
class Translator(object):
    """
    Uses a model to translate a batch of sentences.

    Args:
        model (:obj:`onmt.modules.NMTModel`):
            NMT model to use for translation
        fields (dict of Fields): data fields
        beam_size (int): size of beam to use
        n_best (int): number of translations produced
        max_length (int): maximum length output to produce
        global_scorer (:obj:`GlobalScorer`):
            object to rescore final translations
        copy_attn (bool): use copy attention during translation
        beam_trace (bool): trace beam search for debugging
        logger (logging.Logger): logger.
    """

    def __init__(self, args, model, vocab, symbols, global_scorer=None, logger=None):
        self.logger = logger

        self.args = args
        self.model = model
        self.generator = self.model.generator
        self.vocab = vocab
        self.symbols = symbols
        self.start_token = symbols["BOS"]
        self.end_token = symbols["EOS"]

        self.global_scorer = global_scorer
        self.beam_size = args.beam_size
        self.min_length = args.min_length
        self.max_length = args.max_length

    def translate(self, batch, step, attn_debug=False):
        """Generates summaries from one batch of data."""
        self.model.eval()
        with torch.no_grad():
            batch_data = self.translate_batch(batch)
            translations = self.from_batch(batch_data)
        return translations

    def translate_batch(self, batch, fast=False):
        """
        Translate a batch of sentences.

        Mostly a wrapper around :obj:`Beam`.

        Args:
            batch (:obj:`Batch`): a batch from a dataset object
            fast (bool): enables fast beam search (may not support all features)
        """
        with torch.no_grad():
            return self._fast_translate_batch(batch, self.max_length, min_length=self.min_length)

    # Where the beam search lives
    # I have no idea why it is being called from the method above
    def _fast_translate_batch(self, batch, max_length, min_length=0):
        """Beam Search using the encoder inputs contained in `batch`."""

        # The batch object is funny
        # Instead of just looking at the size of the arguments we encapsulate
        # a size argument.
        # Where is it defined?
        beam_size = self.beam_size
        batch_size = batch.batch_size
        src = batch.src
        segs = batch.segs
        mask_src = batch.mask_src

        src_features = self.model.bert(src, segs, mask_src)
        dec_states = self.model.decoder.init_decoder_state(src, src_features, with_cache=True)
        device = src_features.device

        # Tile states and memory beam_size times.
        dec_states.map_batch_fn(lambda state, dim: tile(state, beam_size, dim=dim))
        src_features = tile(src_features, beam_size, dim=0)
        batch_offset = torch.arange(batch_size, dtype=torch.long, device=device)
        beam_offset = torch.arange(0, batch_size * beam_size, step=beam_size, dtype=torch.long, device=device)
        alive_seq = torch.full([batch_size * beam_size, 1], self.start_token, dtype=torch.long, device=device)

        # Give full probability to the first beam on the first step.
        topk_log_probs = torch.tensor([0.0] + [float("-inf")] * (beam_size - 1), device=device).repeat(batch_size)

        # Structure that holds finished hypotheses.
        hypotheses = [[] for _ in range(batch_size)]  # noqa: F812

        results = {}
        results["predictions"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["scores"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["gold_score"] = [0] * batch_size
        results["batch"] = batch

        for step in range(max_length):
            decoder_input = alive_seq[:, -1].view(1, -1)

            # Decoder forward.
            decoder_input = decoder_input.transpose(0, 1)

            dec_out, dec_states = self.model.decoder(decoder_input, src_features, dec_states, step=step)

            # Generator forward.
            log_probs = self.generator(dec_out.transpose(0, 1).squeeze(0))
            vocab_size = log_probs.size(-1)

            if step < min_length:
                log_probs[:, self.end_token] = -1e20

            # Multiply probs by the beam probability.
            log_probs += topk_log_probs.view(-1).unsqueeze(1)

            alpha = self.global_scorer.alpha
            length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha

            # Flatten probs into a list of possibilities.
            curr_scores = log_probs / length_penalty

            if self.args.block_trigram:
                cur_len = alive_seq.size(1)
                if cur_len > 3:
                    for i in range(alive_seq.size(0)):
                        fail = False
                        words = [int(w) for w in alive_seq[i]]
                        words = [self.vocab.ids_to_tokens[w] for w in words]
                        words = " ".join(words).replace(" ##", "").split()
                        if len(words) <= 3:
                            continue
                        trigrams = [(words[i - 1], words[i], words[i + 1]) for i in range(1, len(words) - 1)]
                        trigram = tuple(trigrams[-1])
                        if trigram in trigrams[:-1]:
                            fail = True
                        if fail:
                            curr_scores[i] = -10e20

            curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
            topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

            # Recover log probs.
            topk_log_probs = topk_scores * length_penalty

            # Resolve beam origin and true word ids.
            # Floor division keeps the recovered beam index an integer tensor.
            topk_beam_index = topk_ids.div(vocab_size, rounding_mode="floor")
            topk_ids = topk_ids.fmod(vocab_size)

            # Map beam_index to batch_index in the flat representation.
            batch_index = topk_beam_index + beam_offset[: topk_beam_index.size(0)].unsqueeze(1)
            select_indices = batch_index.view(-1)

            # Append last prediction.
            alive_seq = torch.cat([alive_seq.index_select(0, select_indices), topk_ids.view(-1, 1)], -1)

            is_finished = topk_ids.eq(self.end_token)
            if step + 1 == max_length:
                is_finished.fill_(1)
            # End condition is top beam is finished.
            end_condition = is_finished[:, 0].eq(1)
            # Save finished hypotheses.
            if is_finished.any():
                predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
                for i in range(is_finished.size(0)):
                    b = batch_offset[i]
                    if end_condition[i]:
                        is_finished[i].fill_(1)
                    finished_hyp = is_finished[i].nonzero().view(-1)
                    # Store finished hypotheses for this batch.
                    for j in finished_hyp:
                        hypotheses[b].append((topk_scores[i, j], predictions[i, j, 1:]))
                    # If the batch reached the end, save the n_best hypotheses.
                    if end_condition[i]:
                        best_hyp = sorted(hypotheses[b], key=lambda x: x[0], reverse=True)
                        score, pred = best_hyp[0]

                        results["scores"][b].append(score)
                        results["predictions"][b].append(pred)
                non_finished = end_condition.eq(0).nonzero().view(-1)
                # If all sentences are translated, no need to go further.
                if len(non_finished) == 0:
                    break
                # Remove finished batches for the next step.
                topk_log_probs = topk_log_probs.index_select(0, non_finished)
                batch_index = batch_index.index_select(0, non_finished)
                batch_offset = batch_offset.index_select(0, non_finished)
                alive_seq = predictions.index_select(0, non_finished).view(-1, alive_seq.size(-1))
            # Reorder states.
            select_indices = batch_index.view(-1)
            src_features = src_features.index_select(0, select_indices)
            dec_states.map_batch_fn(lambda state, dim: state.index_select(dim, select_indices))

        return results

    def from_batch(self, translation_batch):
        batch = translation_batch["batch"]
        assert len(translation_batch["gold_score"]) == len(translation_batch["predictions"])
        batch_size = batch.batch_size

        preds, _, _, tgt_str, src = (
            translation_batch["predictions"],
            translation_batch["scores"],
            translation_batch["gold_score"],
            batch.tgt_str,
            batch.src,
        )

        translations = []
        for b in range(batch_size):
            pred_sents = self.vocab.convert_ids_to_tokens([int(n) for n in preds[b][0]])
            pred_sents = " ".join(pred_sents).replace(" ##", "")
            gold_sent = " ".join(tgt_str[b].split())
            raw_src = [self.vocab.ids_to_tokens[int(t)] for t in src[b]][:500]
            raw_src = " ".join(raw_src)
            translation = (pred_sents, gold_sent, raw_src)
            translations.append(translation)

        return translations

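# Illustrative helper (not used by the class above): the trigram-blocking rule
# in `_fast_translate_batch` marks a hypothesis as failed when its last trigram
# already occurred earlier in the sequence; this standalone sketch reproduces
# that check on a plain list of tokens.
def _repeats_last_trigram(words):
    if len(words) <= 3:
        return False
    trigrams = [(words[i - 1], words[i], words[i + 1]) for i in range(1, len(words) - 1)]
    return tuple(trigrams[-1]) in trigrams[:-1]
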
def tile(x, count, dim=0):
    """
    Tiles x on dimension dim count times.
    """
    perm = list(range(len(x.size())))
    if dim != 0:
        perm[0], perm[dim] = perm[dim], perm[0]
        x = x.permute(perm).contiguous()
    out_size = list(x.size())
    out_size[0] *= count
    batch = x.size(0)
    x = x.view(batch, -1).transpose(0, 1).repeat(count, 1).transpose(0, 1).contiguous().view(*out_size)
    if dim != 0:
        x = x.permute(perm).contiguous()
    return x

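# Usage sketch (illustrative): `tile` repeats each element along `dim` `count`
# times while keeping copies of the same entry adjacent, which is how encoder
# states are expanded to `batch_size * beam_size` rows for beam search above.
def _example_tile():
    x = torch.tensor([10, 20, 30])
    return tile(x, 2, dim=0)  # tensor([10, 10, 20, 20, 30, 30])
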
#
# Optimizer for training. We keep this here in case we want to add
# a finetuning script.
#


class BertSumOptimizer(object):
    """Specific optimizer for BertSum.

    As described in [1], the authors fine-tune BertSum for abstractive
    summarization using two Adam Optimizers with different warm-up steps and
    learning rate. They also use a custom learning rate scheduler.

    [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
        arXiv preprint arXiv:1908.08345 (2019).
    """

    def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-8):
        self.encoder = model.encoder
        self.decoder = model.decoder
        self.lr = lr
        self.warmup_steps = warmup_steps

        self.optimizers = {
            "encoder": torch.optim.Adam(
                model.encoder.parameters(),
                lr=lr["encoder"],
                betas=(beta_1, beta_2),
                eps=eps,
            ),
            "decoder": torch.optim.Adam(
                model.decoder.parameters(),
                lr=lr["decoder"],
                betas=(beta_1, beta_2),
                eps=eps,
            ),
        }

        self._step = 0
        self.current_learning_rates = {}

    def _update_rate(self, stack):
        return self.lr[stack] * min(self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-1.5))

    def zero_grad(self):
        # Zero the gradients of both the encoder and the decoder optimizers.
        for optimizer in self.optimizers.values():
            optimizer.zero_grad()

    def step(self):
        self._step += 1
        for stack, optimizer in self.optimizers.items():
            new_rate = self._update_rate(stack)
            for param_group in optimizer.param_groups:
                param_group["lr"] = new_rate
            optimizer.step()
            self.current_learning_rates[stack] = new_rate
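

# Worked example (illustrative values, not the paper's settings): with
# lr["decoder"] = 0.2 and warmup_steps["decoder"] = 10000, the schedule in
# `_update_rate` grows linearly as 0.2 * step * 10000 ** (-1.5) during warm-up,
# then decays as 0.2 * step ** (-0.5), peaking at 0.2 / sqrt(10000) = 0.002 at
# step 10000.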