added beam_search generation for tf 2.0

patrickvonplaten 2020-03-04 00:32:07 +01:00 committed by Patrick von Platen
parent 34de670dbe
commit 61fef6e957
2 changed files with 307 additions and 20 deletions


@@ -142,7 +142,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
# # initialize all new embeddings (in particular added tokens)
# self._init_weights(new_embeddings)
# # Copy word embeddings from the previous weights
# # Copy token embeddings from the previous weights
# num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
# new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
@@ -557,6 +557,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
else:
assert len(shape_list(input_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."
if do_sample is False:
if num_beams == 1:
# no_beam_search greedy generation conditions
assert (
num_return_sequences == 1
), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
else:
# beam_search greedy generation conditions
assert (
num_beams >= num_return_sequences
), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
if pad_token_id is None and eos_token_ids is not None:
logger.warning(
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
@@ -567,7 +580,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
cur_len = shape_list(input_ids)[1]
vocab_size = self.config.vocab_size
if num_return_sequences != 1:
if num_return_sequences != 1 and do_sample:
# Expand input to num return sequences
input_ids = tf.broadcast_to(tf.expand_dims(input_ids, 1), (batch_size, num_return_sequences, cur_len))
effective_batch_size = batch_size * num_return_sequences
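As a standalone sketch (not part of this commit, toy values), the expand/broadcast/reshape pattern above simply repeats each prompt num_return_sequences times along the batch dimension so every copy can be decoded independently:
import tensorflow as tf

input_ids = tf.constant([[5, 6, 7], [8, 9, 10]])  # (batch_size=2, cur_len=3)
num_return_sequences = 3
batch_size, cur_len = input_ids.shape
expanded = tf.broadcast_to(tf.expand_dims(input_ids, 1), (batch_size, num_return_sequences, cur_len))
flat = tf.reshape(expanded, (batch_size * num_return_sequences, cur_len))
print(flat.numpy())  # rows: [5 6 7] three times, then [8 9 10] three times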
@@ -588,6 +601,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
pad_token_id,
eos_token_ids,
effective_batch_size,
num_return_sequences,
length_penalty,
num_beams,
vocab_size,
@@ -627,19 +641,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
All returned sequences are generated independently.
"""
def _create_next_token_logits_penalties(input_ids, logits):
# create logit penalties for already seen input_ids
token_penalties = np.ones(shape_list(logits))
prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
for i, prev_input_id in enumerate(prev_input_ids):
logit_penalized = logits[i].numpy()[prev_input_id]
# if previous logit score is < 0 then multiply repetition penalty else divide
logit_penalized[logit_penalized < 0] = repetition_penalty
logit_penalized[logit_penalized > 0] = 1 / repetition_penalty
np.put(token_penalties[i], prev_input_id, logit_penalized)
return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
# current position / max lengths / length of generated sentences / unfinished sentences
# length of generated sentences / unfinished sentences
unfinished_sents = tf.ones_like(input_ids[:, 0])
sent_lengths = tf.ones_like(input_ids[:, 0]) * max_length
@@ -656,7 +658,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits)
next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits, repetition_penalty)
next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
if do_sample:
@@ -738,11 +740,228 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
pad_token_id,
eos_token_ids,
batch_size,
num_return_sequences,
length_penalty,
num_beams,
vocab_size,
):
pass
""" Generate sequences for each example with beam search.
"""
# Expand input to num beams
input_ids = tf.broadcast_to(tf.expand_dims(input_ids, 1), (batch_size, num_beams, cur_len))
input_ids = tf.reshape(input_ids, (batch_size * num_beams, cur_len)) # (batch_size * num_beams, cur_len)
# generated hypotheses
generated_hyps = [
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
]
# scores for each sentence in the beam
beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
# for greedy beam search, give beams 1..num_beams-1 a very low score at the first step so the
# (still identical) duplicate beams do not all select the same token; sampling keeps all-zero scores
beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) if do_sample else tf.ones((batch_size, num_beams - 1), dtype=tf.float32) * (-1e9)
beam_scores = tf.reshape(tf.concat([beam_scores_begin, beam_scores_end], -1), (batch_size * num_beams,))
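A standalone toy example (not part of the diff, numbers made up) of why the greedy path starts beams 1..num_beams-1 at -1e9: at the first step all beams hold the same prompt, so without the offset top_k would pick the same best token once per duplicate beam.
import tensorflow as tf

num_beams, vocab_size = 2, 4
log_probs = tf.math.log(tf.constant([[0.1, 0.6, 0.2, 0.1],
                                     [0.1, 0.6, 0.2, 0.1]]))  # identical rows, one per beam
beam_scores = tf.constant([0.0, -1e9])                        # greedy initialization
next_scores = log_probs + beam_scores[:, None]                # (num_beams, vocab_size)
flat = tf.reshape(next_scores, (1, num_beams * vocab_size))
top_scores, top_idx = tf.math.top_k(flat, 2 * num_beams)
print((top_idx // vocab_size).numpy())  # [[0 0 0 0]] -> every candidate comes from beam 0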
# cache compute states
past = None
# done sentences
done = [False for _ in range(batch_size)]
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
# if model has past, then set the past variable to speed up decoding
if self._do_output_past(outputs):
past = outputs[1]
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits, repetition_penalty)
next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
if temperature != 1.0:
next_token_logits = next_token_logits / temperature
# Top-p/top-k filtering
next_token_logits = tf_top_k_top_p_filtering(
next_token_logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
) # (batch_size * num_beams, vocab_size)
# Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
next_tokens = tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=2) # (batch_size * num_beams, 2)
# Compute next scores
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
_scores = tf.gather(scores, next_tokens, batch_dims=1) # (batch_size * num_beams, 2)
next_scores = _scores + tf.broadcast_to(beam_scores[:, None], (batch_size * num_beams, 2)) # (batch_size * num_beams, 2)
# Match shape of greedy beam search
next_tokens = tf.reshape(next_tokens, (batch_size, 2 * num_beams)) # (batch_size, 2 * num_beams)
next_scores = tf.reshape(next_scores, (batch_size, 2 * num_beams)) # (batch_size, 2 * num_beams)
else:
# do greedy beam search
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
assert shape_list(scores) == [batch_size * num_beams, vocab_size]
# Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
next_scores = scores + tf.broadcast_to(beam_scores[:, None], (batch_size * num_beams, vocab_size)) # (batch_size * num_beams, vocab_size)
# re-organize to group the beams together (we are keeping the top hypotheses across beams)
next_scores = tf.reshape(next_scores, (batch_size, num_beams * vocab_size)) # (batch_size, num_beams * vocab_size)
next_scores, next_tokens = tf.math.top_k(next_scores, 2 * num_beams, sorted=True)
assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
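Standalone worked example (toy numbers, not part of the diff) of the flattened indices returned by top_k above: each index encodes a (beam, token) pair that the loop below recovers with integer division and modulo.
import tensorflow as tf

num_beams, vocab_size = 2, 5
# scores for one batch element, beams 0 and 1 flattened into a single row
next_scores = tf.constant([[0.0, -1.0, -2.0, -0.5, -3.0,     # beam 0, tokens 0..4
                            -0.1, -4.0, -0.2, -5.0, -6.0]])  # beam 1, tokens 0..4
scores, idx = tf.math.top_k(next_scores, 2 * num_beams)
print(idx.numpy())                  # [[0 5 7 3]]
print((idx // vocab_size).numpy())  # beam ids:  [[0 1 1 0]]
print((idx % vocab_size).numpy())   # token ids: [[0 0 2 3]]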
# next batch beam content
# list of (batch_size * num_beams) tuple(next hypothesis score, next token, current position in the batch)
next_batch_beam = []
# for each sentence
for batch_idx in range(batch_size):
# if we are done with this sentence
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
tf.reduce_max(next_scores[batch_idx]).numpy()
)
if done[batch_idx]:
assert (
len(generated_hyps[batch_idx]) >= num_beams
), "Batch can only be done if at least {} beams have been generated".format(num_beams)
assert (
eos_token_ids is not None and pad_token_id is not None
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
continue
# next sentence beam content
next_sent_beam = []
# next tokens for this sentence
for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
# get beam and token IDs
beam_id = idx // vocab_size
token_id = idx % vocab_size
# add to generated hypotheses if end of sentence or last iteration
if eos_token_ids is not None and token_id.numpy() in eos_token_ids:
generated_hyps[batch_idx].add(
tf.identity(input_ids[batch_idx * num_beams + beam_id, :cur_len]), score.numpy()
)
else:
# add next predicted token if it is not eos_token
next_sent_beam.append((score, token_id, batch_idx * num_beams + beam_id))
# the beam for next step is full
if len(next_sent_beam) == num_beams:
break
# update next beam content
assert len(next_sent_beam) == num_beams, "Beam should always be full"
next_batch_beam.extend(next_sent_beam)
assert len(next_batch_beam) == num_beams * (batch_idx + 1)
# sanity check / prepare next batch
assert len(next_batch_beam) == batch_size * num_beams
beam_scores = tf.convert_to_tensor([x[0] for x in next_batch_beam], dtype=tf.float32)
beam_tokens = tf.convert_to_tensor([x[1] for x in next_batch_beam], dtype=tf.int32)
beam_idx = tf.convert_to_tensor([x[2] for x in next_batch_beam], dtype=tf.int32)
# re-order batch
input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1)
# re-order internal states
if past:
past = self._reorder_cache(past, beam_idx)
# update current length
cur_len = cur_len + 1
# stop when we are done with each sentence
if all(done):
break
for batch_idx in range(batch_size):
# Add all open beam hypotheses to generated_hyps
if not done[batch_idx]:
for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
# get beam and token IDs
beam_id = idx // vocab_size
token_id = idx % vocab_size
generated_hyps[batch_idx].add(
tf.identity(input_ids[batch_idx * num_beams + beam_id, :cur_len]), score.numpy()
)
# define output_batch_size and output_num_return_sequences_per_batch depending on whether greedy generation is used or not
output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
# select the best hypotheses
sent_lengths_list = []
best = []
# retrieve best hypotheses
for i, hypotheses in enumerate(generated_hyps):
sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
for j in range(output_num_return_sequences_per_batch):
best_hyp = sorted_hyps.pop()[1]
sent_lengths_list.append(len(best_hyp))
best.append(best_hyp)
assert output_batch_size == len(best), "Output batch size {} must match output beam hypotheses {}".format(output_batch_size, len(best))
sent_lengths = tf.convert_to_tensor(sent_lengths_list, dtype=tf.int32)
# shorter batches are filled with pad_token
if tf.reduce_min(sent_lengths).numpy() != tf.reduce_max(sent_lengths).numpy():
assert pad_token_id is not None, "`pad_token_id` has to be defined"
sent_max_len = min(tf.reduce_max(sent_lengths).numpy() + 1, max_length)
decoded_list = []
# fill with hypothesis and eos_token_id if necessary
for i, hypo in enumerate(best):
padding = tf.ones((sent_max_len - shape_list(hypo)[0],), dtype=tf.int32) * pad_token_id
decoded_hypo = tf.concat([hypo, padding], axis=0)
if sent_lengths[i] < max_length:
decoded_hypo = tf.where(tf.range(sent_max_len, dtype=tf.int32) == sent_lengths[i], eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32), decoded_hypo)
decoded_list.append(decoded_hypo)
decoded = tf.stack(decoded_list)
else:
# none of the hypotheses have an eos_token
assert all(len(hypo) == max_length for hypo in best)
decoded = tf.stack(best)
return decoded
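Rough end-to-end usage sketch of the new path (model, tokenizer, prompt and hyperparameters are illustrative; only the generate arguments are taken from the tests below):
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer.encode("The weather today is", return_tensors="tf")
# greedy beam search: 3 beams, return the 3 best hypotheses per prompt
output_ids = model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3, max_length=20)
for sequence in output_ids.numpy().tolist():
    print(tokenizer.decode(sequence))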
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = []
for layer_past in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` and `mems` is at 2nd position
reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[i], 0)) for i in beam_idx]
# TODO: check whether it is an error that TF past.shape != Torch past.shape
reordered_layer_past = tf.concat(reordered_layer_past, axis=0)
# check that shape matches
assert shape_list(reordered_layer_past) == shape_list(layer_past)
reordered_past.append(reordered_layer_past)
past = tuple(reordered_past)
return past
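Standalone toy check (not part of the diff): the per-layer loop above amounts to gathering beam_idx along the axis the layer past is indexed on here (axis 0).
import tensorflow as tf

layer_past = tf.reshape(tf.range(8, dtype=tf.float32), (4, 2))  # 4 rows = batch_size * num_beams
beam_idx = tf.constant([2, 2, 0, 3])
looped = tf.concat([tf.expand_dims(layer_past[i], 0) for i in beam_idx], axis=0)
gathered = tf.gather(layer_past, beam_idx, axis=0)
print(tf.reduce_all(tf.equal(looped, gathered)).numpy())  # True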
def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
# create logit penalties for already seen input_ids
token_penalties = np.ones(shape_list(logits))
prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
for i, prev_input_id in enumerate(prev_input_ids):
logit_penalized = logits[i].numpy()[prev_input_id]
logit_penalties = np.zeros(logit_penalized.shape)
# if the previous logit score is < 0 then multiply by the repetition penalty, else divide by it
logit_penalties[logit_penalized < 0] = repetition_penalty
logit_penalties[logit_penalized > 0] = 1 / repetition_penalty
np.put(token_penalties[i], prev_input_id, logit_penalties)
return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
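Toy usage sketch (assumes the helpers in this module, e.g. shape_list, are importable; numbers are made up): tokens that already occur in input_ids become less likely, with positive logits divided by the penalty and negative logits multiplied by it.
import tensorflow as tf

logits = tf.constant([[0.0, 1.0, -1.0, 2.0, 0.5, -0.5]])
input_ids = tf.constant([[3, 5, 3]])  # tokens 3 and 5 were generated before
penalties = _create_next_token_logits_penalties(input_ids, logits, repetition_penalty=1.3)
penalized = tf.math.multiply(logits, penalties)
# token 3: 2.0 -> 2.0 / 1.3 ~ 1.54, token 5: -0.5 -> -0.5 * 1.3 = -0.65, all other logits unchanged
print(penalized.numpy())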
def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
@@ -811,6 +1030,56 @@ def set_tensor_by_indices_to_value(tensor, indices, value):
return tf.where(indices, value_tensor, tensor)
class BeamHypotheses(object):
def __init__(self, num_beams, max_length, length_penalty, early_stopping):
"""
Initialize n-best list of hypotheses.
"""
self.max_length = max_length - 1 # ignoring bos_token
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.num_beams = num_beams
self.beams = []
self.worst_score = 1e9
def __len__(self):
"""
Number of hypotheses in the list.
"""
return len(self.beams)
def add(self, hyp, sum_logprobs):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / len(hyp) ** self.length_penalty
if len(self) < self.num_beams or score > self.worst_score:
self.beams.append((score, hyp))
if len(self) > self.num_beams:
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
del self.beams[sorted_scores[0][1]]
self.worst_score = sorted_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs, cur_len=None):
"""
If there are enough hypotheses and none of the hypotheses being generated
can become better than the worst one in the heap, then we are done with this sentence.
"""
if len(self) < self.num_beams:
return False
elif self.early_stopping:
return True
else:
if cur_len is None:
cur_len = self.max_length
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
ret = self.worst_score >= cur_score
return ret
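Minimal sketch (toy scores, plain lists standing in for token tensors) of how the class above is used during decoding:
hyps = BeamHypotheses(num_beams=2, max_length=10, length_penalty=1.0, early_stopping=False)
hyps.add([0, 11, 12], sum_logprobs=-1.2)  # length-normalized score: -1.2 / 3 = -0.40
hyps.add([0, 11, 13], sum_logprobs=-2.4)  # score -0.80, currently the worst kept hypothesis
hyps.add([0, 11, 14], sum_logprobs=-0.9)  # score -0.30 pushes the -0.80 hypothesis out
print(len(hyps))                          # 2 -- never more than num_beams hypotheses are kept
print(hyps.is_done(best_sum_logprobs=-5.0))  # True: -5.0 / 9 ~ -0.56 cannot beat the worst kept score (-0.40)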
class TFConv1D(tf.keras.layers.Layer):
def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
""" TFConv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
@@ -849,7 +1118,7 @@ class TFSharedEmbeddings(tf.keras.layers.Layer):
self.initializer_range = hidden_size ** -0.5 if initializer_range is None else initializer_range
def build(self, input_shape):
"""Build shared word embedding layer
"""Build shared token embedding layer
Shared weights logic adapted from
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""


@@ -381,7 +381,6 @@ class TFModelTesterMixin:
) # TODO (PVP): ugly workaround to make code work for t5 for the moment - has to be changed when t5 is fixed.
for model_class in self.all_generative_model_classes:
# TODO (PVP): add beam search tests when beam search is implemented
model = model_class(config)
if config.bos_token_id is None:
@@ -389,15 +388,34 @@ class TFModelTesterMixin:
model.generate(max_length=5)
# batch_size = 1
self._check_generated_tokens(model.generate(input_ids))
# batch_size = 1, num_beams > 1
self._check_generated_tokens(model.generate(input_ids, num_beams=3))
else:
# batch_size = 1
self._check_generated_tokens(model.generate(max_length=5))
# batch_size = 1, num_beams > 1
self._check_generated_tokens(model.generate(max_length=5, num_beams=3))
with self.assertRaises(AssertionError):
# generating multiple sequences with greedy (no beam) generation is not allowed,
# as it would always generate the same sequences
model.generate(input_ids, do_sample=False, num_return_sequences=2)
with self.assertRaises(AssertionError):
# generating more sequences than there are beams is not possible
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
# batch_size > 1, sample
self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3))
# batch_size > 1, greedy
self._check_generated_tokens(model.generate(input_ids, do_sample=False, num_return_sequences=3))
self._check_generated_tokens(model.generate(input_ids, do_sample=False))
# batch_size > 1, num_beams > 1, sample
self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,))
# batch_size > 1, num_beams > 1, greedy
self._check_generated_tokens(
model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3)
)
def _check_generated_tokens(self, output_ids):
for token_id in output_ids[0].numpy().tolist():