added beam_search generation for tf 2.0
parent 34de670dbe
commit 61fef6e957
@@ -142,7 +142,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         # # initialize all new embeddings (in particular added tokens)
         # self._init_weights(new_embeddings)
 
-        # # Copy word embeddings from the previous weights
+        # # Copy token embeddings from the previous weights
         # num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
         # new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
 
@@ -557,6 +557,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         else:
             assert len(shape_list(input_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."
 
+        if do_sample is False:
+            if num_beams == 1:
+                # no_beam_search greedy generation conditions
+                assert (
+                    num_return_sequences == 1
+                ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
+
+            else:
+                # beam_search greedy generation conditions
+                assert (
+                    num_beams >= num_return_sequences
+                ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
+
         if pad_token_id is None and eos_token_ids is not None:
             logger.warning(
                 "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
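
For illustration (not part of this commit's diff), and assuming `model` is a TF 2.0 generative model and `input_ids` has shape (batch_size, sequence_length), these checks accept or reject hypothetical calls as follows:

model.generate(input_ids, do_sample=False, num_return_sequences=1)               # ok: greedy decoding, single sequence
model.generate(input_ids, do_sample=False, num_return_sequences=2)               # AssertionError: greedy without beams is deterministic
model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3)  # ok: beam is wide enough
model.generate(input_ids, do_sample=False, num_beams=2, num_return_sequences=3)  # AssertionError: more return sequences than beams
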
@@ -567,7 +580,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         cur_len = shape_list(input_ids)[1]
         vocab_size = self.config.vocab_size
 
-        if num_return_sequences != 1:
+        if num_return_sequences != 1 and do_sample:
             # Expand input to num return sequences
             input_ids = tf.broadcast_to(tf.expand_dims(input_ids, 1), (batch_size, num_return_sequences, cur_len))
             effective_batch_size = batch_size * num_return_sequences
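
A minimal standalone sketch of the expansion above (not part of the diff, toy values; the flattening reshape here is shown only for illustration):

import tensorflow as tf

batch_size, num_return_sequences, cur_len = 2, 2, 3
input_ids = tf.constant([[5, 6, 7], [8, 9, 10]])  # (batch_size, cur_len)

# repeat each prompt num_return_sequences times along a new axis
expanded = tf.broadcast_to(tf.expand_dims(input_ids, 1), (batch_size, num_return_sequences, cur_len))
effective_batch_size = batch_size * num_return_sequences
expanded = tf.reshape(expanded, (effective_batch_size, cur_len))
# expanded -> [[5, 6, 7], [5, 6, 7], [8, 9, 10], [8, 9, 10]]
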
@@ -588,6 +601,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 pad_token_id,
                 eos_token_ids,
                 effective_batch_size,
+                num_return_sequences,
                 length_penalty,
                 num_beams,
                 vocab_size,
@@ -627,19 +641,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         All returned sequence are generated independantly.
         """
 
-        def _create_next_token_logits_penalties(input_ids, logits):
-            # create logit penalties for already seen input_ids
-            token_penalties = np.ones(shape_list(logits))
-            prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
-            for i, prev_input_id in enumerate(prev_input_ids):
-                logit_penalized = logits[i].numpy()[prev_input_id]
-                # if previous logit score is < 0 then multiply repetition penalty else divide
-                logit_penalized[logit_penalized < 0] = repetition_penalty
-                logit_penalized[logit_penalized > 0] = 1 / repetition_penalty
-                np.put(token_penalties[i], prev_input_id, logit_penalized)
-            return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
-
-        # current position / max lengths / length of generated sentences / unfinished sentences
+        # length of generated sentences / unfinished sentences
         unfinished_sents = tf.ones_like(input_ids[:, 0])
         sent_lengths = tf.ones_like(input_ids[:, 0]) * max_length
 
@@ -656,7 +658,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
 
             # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
-                next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits)
+                next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits, repetition_penalty)
                 next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
 
             if do_sample:
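
A small numeric illustration of the CTRL-style repetition penalty applied above (standalone sketch, not part of the diff): logits of tokens that already appear in `input_ids` are divided by the penalty when positive and multiplied by it when negative, so repeated tokens become less likely.

import numpy as np
import tensorflow as tf

repetition_penalty = 1.3
logits = tf.constant([[2.0, -1.0, 0.5, 3.0]])  # (batch_size=1, vocab_size=4)
input_ids = tf.constant([[0, 3]])              # tokens 0 and 3 already appear in the input

penalties = np.ones(logits.shape.as_list())
for i, prev_tokens in enumerate([np.unique(row) for row in input_ids.numpy()]):
    seen = logits[i].numpy()[prev_tokens]
    seen[seen < 0] = repetition_penalty        # negative logits get multiplied by the penalty
    seen[seen > 0] = 1 / repetition_penalty    # positive logits get divided by the penalty
    np.put(penalties[i], prev_tokens, seen)

penalized = tf.math.multiply(logits, tf.convert_to_tensor(penalties, dtype=tf.float32))
# token 0: 2.0 / 1.3 ≈ 1.54, token 3: 3.0 / 1.3 ≈ 2.31, unseen tokens are unchanged
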
@@ -738,11 +740,228 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         pad_token_id,
         eos_token_ids,
         batch_size,
         num_return_sequences,
         length_penalty,
         num_beams,
         vocab_size,
     ):
-        pass
+        """ Generate sequences for each example with beam search.
+        """
+
+        # Expand input to num beams
+        input_ids = tf.broadcast_to(tf.expand_dims(input_ids, 1), (batch_size, num_beams, cur_len))
+        input_ids = tf.reshape(input_ids, (batch_size * num_beams, cur_len))  # (batch_size * num_beams, cur_len)
+
+        # generated hypotheses
+        generated_hyps = [
+            BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
+        ]
+
+        # scores for each sentence in the beam
+        beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
+        beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
+        beam_scores = tf.reshape(tf.concat([beam_scores_begin, beam_scores_end], -1), (batch_size * num_beams,))
+
+        # cache compute states
+        past = None
+
+        # done sentences
+        done = [False for _ in range(batch_size)]
+
+        while cur_len < max_length:
+            model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
+            outputs = self(**model_inputs)  # (batch_size * num_beams, cur_len, vocab_size)
+            next_token_logits = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)
+
+            # if model has past, then set the past variable to speed up decoding
+            if self._do_output_past(outputs):
+                past = outputs[1]
+
+            # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
+            if repetition_penalty != 1.0:
+                next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits, repetition_penalty)
+                next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
+
+            if do_sample:
+                # Temperature (higher temperature => more likely to sample low probability tokens)
+                if temperature != 1.0:
+                    next_token_logits = next_token_logits / temperature
+                # Top-p/top-k filtering
+                next_token_logits = tf_top_k_top_p_filtering(
+                    next_token_logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
+                )  # (batch_size * num_beams, vocab_size)
+                # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
+                next_tokens = tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=2)  # (batch_size * num_beams, vocab_size)
+                # Compute next scores
+                scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
+                _scores = tf.gather(scores, next_tokens, batch_dims=1)  # (batch_size * num_beams, 2)
+                next_scores = _scores + tf.broadcast_to(beam_scores[:, None], (batch_size * num_beams, 2))  # (batch_size * num_beams, 2)
+                # Match shape of greedy beam search
+                next_tokens = tf.reshape(next_tokens, (batch_size, 2 * num_beams))  # (batch_size, 2 * num_beams)
+                next_scores = tf.reshape(next_scores, (batch_size, 2 * num_beams))  # (batch_size, 2 * num_beams)
+            else:
+                # do greedy beam search
+                scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
+                assert shape_list(scores) == [batch_size * num_beams, vocab_size]
+                # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
+                next_scores = scores + tf.broadcast_to(beam_scores[:, None], (batch_size * num_beams, vocab_size))  # (batch_size * num_beams, vocab_size)
+
+                # re-organize to group the beam together (we are keeping top hypothesis accross beams)
+                next_scores = tf.reshape(next_scores, (batch_size, num_beams * vocab_size))  # (batch_size, num_beams * vocab_size)
+                next_scores, next_tokens = tf.math.top_k(next_scores, 2 * num_beams, sorted=True)
+
+            assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
+
+            # next batch beam content
+            # list of (batch_size * num_beams) tuple(next hypothesis score, next token, current position in the batch)
+            next_batch_beam = []
+
+            # for each sentence
+            for batch_idx in range(batch_size):
+
+                # if we are done with this sentence
+                done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
+                    tf.reduce_max(next_scores[batch_idx]).numpy()
+                )
+                if done[batch_idx]:
+                    assert (
+                        len(generated_hyps[batch_idx]) >= num_beams
+                    ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
+                    assert (
+                        eos_token_ids is not None and pad_token_id is not None
+                    ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
+                    next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
+                    continue
+
+                # next sentence beam content
+                next_sent_beam = []
+
+                # next tokens for this sentence
+                for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
+
+                    # get beam and token IDs
+                    beam_id = idx // vocab_size
+                    token_id = idx % vocab_size
+
+                    # add to generated hypotheses if end of sentence or last iteration
+                    if eos_token_ids is not None and token_id.numpy() in eos_token_ids:
+                        generated_hyps[batch_idx].add(
+                            tf.identity(input_ids[batch_idx * num_beams + beam_id, :cur_len]), score.numpy()
+                        )
+                    else:
+                        # add next predicted token if it is not eos_token
+                        next_sent_beam.append((score, token_id, batch_idx * num_beams + beam_id))
+
+                    # the beam for next step is full
+                    if len(next_sent_beam) == num_beams:
+                        break
+
+                # update next beam content
+                assert len(next_sent_beam) == num_beams, "Beam should always be full"
+                next_batch_beam.extend(next_sent_beam)
+                assert len(next_batch_beam) == num_beams * (batch_idx + 1)
+
+            # sanity check / prepare next batch
+            assert len(next_batch_beam) == batch_size * num_beams
+            beam_scores = tf.convert_to_tensor([x[0] for x in next_batch_beam], dtype=tf.float32)
+            beam_tokens = tf.convert_to_tensor([x[1] for x in next_batch_beam], dtype=tf.int32)
+            beam_idx = tf.convert_to_tensor([x[2] for x in next_batch_beam], dtype=tf.int32)
+
+            # re-order batch
+            input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
+            input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1)
+
+            # re-order internal states
+            if past:
+                past = self._reorder_cache(past, beam_idx)
+
+            # update current length
+            cur_len = cur_len + 1
+
+            # stop when we are done with each sentence
+            if all(done):
+                break
+
+        for batch_idx in range(batch_size):
+            # Add all open beam hypothesis to generated_hyps
+            if not done[batch_idx]:
+                for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
+
+                    # get beam and token IDs
+                    beam_id = idx // vocab_size
+                    token_id = idx % vocab_size
+                    generated_hyps[batch_idx].add(
+                        tf.identity(input_ids[batch_idx * num_beams + beam_id, :cur_len]), score.numpy()
+                    )
+
+        # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
+        output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
+        output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
+
+        # select the best hypotheses
+        sent_lengths_list = []
+        best = []
+
+        # retrieve best hypotheses
+        for i, hypotheses in enumerate(generated_hyps):
+            sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
+            for j in range(output_num_return_sequences_per_batch):
+                best_hyp = sorted_hyps.pop()[1]
+                sent_lengths_list.append(len(best_hyp))
+                best.append(best_hyp)
+        assert output_batch_size == len(best), "Output batch size {} must match output beam hypotheses {}".format(output_batch_size, len(best))
+
+        sent_lengths = tf.convert_to_tensor(sent_lengths_list, dtype=tf.int32)
+
+        # shorter batches are filled with pad_token
+        if tf.reduce_min(sent_lengths).numpy() != tf.reduce_max(sent_lengths).numpy():
+            assert pad_token_id is not None, "`Pad_token_id` has to be defined"
+            sent_max_len = min(tf.reduce_max(sent_lengths).numpy() + 1, max_length)
+            decoded_list = []
+
+            # fill with hypothesis and eos_token_id if necessary
+            for i, hypo in enumerate(best):
+                padding = tf.ones((sent_max_len - shape_list(hypo)[0],), dtype=tf.int32) * pad_token_id
+                decoded_hypo = tf.concat([hypo, padding], axis=0)
+
+                if sent_lengths[i] < max_length:
+                    decoded_hypo = tf.where(tf.range(max_length) == sent_lengths[i], eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32), decoded_hypo)
+                decoded_list.append(decoded_hypo)
+            decoded = tf.stack(decoded_list)
+        else:
+            # none of the hypotheses have an eos_token
+            assert (len(hypo) == max_length for hypo in best)
+            decoded = tf.stack(best)
+
+        return decoded
+
+    @staticmethod
+    def _reorder_cache(past, beam_idx):
+        reordered_past = []
+        for layer_past in past:
+            # get the correct batch idx from layer past batch dim
+            # batch dim of `past` and `mems` is at 2nd position
+            reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[i], 0)) for i in beam_idx]
+            # TODO: check whether it is an error that TF past.shape != Torch past.shape
+            reordered_layer_past = tf.concat(reordered_layer_past, axis=0)
+            # check that shape matches
+            assert shape_list(reordered_layer_past) == shape_list(layer_past)
+            reordered_past.append(reordered_layer_past)
+        past = tuple(reordered_past)
+        return past
+
+
+def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
+    # create logit penalties for already seen input_ids
+    token_penalties = np.ones(shape_list(logits))
+    prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
+    for i, prev_input_id in enumerate(prev_input_ids):
+        logit_penalized = logits[i].numpy()[prev_input_id]
+        # if previous logit score is < 0 then multiply repetition penalty else divide
+        logit_penalized[logit_penalized < 0] = repetition_penalty
+        logit_penalized[logit_penalized > 0] = 1 / repetition_penalty
+        np.put(token_penalties[i], prev_input_id, logit_penalized)
+    return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
+
+
 def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
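
A standalone toy-sized sketch of the greedy beam-search scoring step used above (not part of the diff; the beam-score initialization here is illustrative only): per-token log-probabilities are added to the running beam scores, all beams of a batch element are flattened into one axis, and the top 2 * num_beams candidates are kept so that finished (EOS) hypotheses still leave enough open beams. Beam and token indices are then recovered with // and %.

import tensorflow as tf

batch_size, num_beams, vocab_size = 1, 2, 5
next_token_logits = tf.random.normal((batch_size * num_beams, vocab_size))
# illustrative init: keep only the first beam "alive" at the first step
beam_scores = tf.constant([0.0, -1e9])

scores = tf.nn.log_softmax(next_token_logits, axis=-1)                        # (batch * beams, vocab)
next_scores = scores + tf.broadcast_to(beam_scores[:, None], (batch_size * num_beams, vocab_size))
next_scores = tf.reshape(next_scores, (batch_size, num_beams * vocab_size))   # flatten beams per batch element
next_scores, next_tokens = tf.math.top_k(next_scores, 2 * num_beams, sorted=True)

beam_id = next_tokens // vocab_size   # which beam each candidate continues
token_id = next_tokens % vocab_size   # which token it appends
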
@@ -811,6 +1030,56 @@ def set_tensor_by_indices_to_value(tensor, indices, value):
     return tf.where(indices, value_tensor, tensor)
 
 
+class BeamHypotheses(object):
+    def __init__(self, num_beams, max_length, length_penalty, early_stopping):
+        """
+        Initialize n-best list of hypotheses.
+        """
+        self.max_length = max_length - 1  # ignoring bos_token
+        self.length_penalty = length_penalty
+        self.early_stopping = early_stopping
+        self.num_beams = num_beams
+        self.beams = []
+        self.worst_score = 1e9
+
+    def __len__(self):
+        """
+        Number of hypotheses in the list.
+        """
+        return len(self.beams)
+
+    def add(self, hyp, sum_logprobs):
+        """
+        Add a new hypothesis to the list.
+        """
+        score = sum_logprobs / len(hyp) ** self.length_penalty
+        if len(self) < self.num_beams or score > self.worst_score:
+            self.beams.append((score, hyp))
+            if len(self) > self.num_beams:
+                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
+                del self.beams[sorted_scores[0][1]]
+                self.worst_score = sorted_scores[1][0]
+            else:
+                self.worst_score = min(score, self.worst_score)
+
+    def is_done(self, best_sum_logprobs, cur_len=None):
+        """
+        If there are enough hypotheses and that none of the hypotheses being generated
+        can become better than the worst one in the heap, then we are done with this sentence.
+        """
+
+        if len(self) < self.num_beams:
+            return False
+        elif self.early_stopping:
+            return True
+        else:
+            if cur_len is None:
+                cur_len = self.max_length
+            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
+            ret = self.worst_score >= cur_score
+            return ret
+
+
 class TFConv1D(tf.keras.layers.Layer):
     def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
         """ TFConv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
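
A short usage sketch for the BeamHypotheses container added above (not part of the diff, assuming the class as defined in this commit is in scope): hypotheses are ranked by length-normalized log-probability, sum_logprobs / len(hyp) ** length_penalty, only the best num_beams are kept, and is_done reports whether no open beam can still beat the worst finished hypothesis.

hyps = BeamHypotheses(num_beams=2, max_length=10, length_penalty=1.0, early_stopping=False)

hyps.add([5, 17, 2], sum_logprobs=-3.0)     # score = -3.0 / 3 ** 1.0 = -1.0
hyps.add([5, 9, 31, 2], sum_logprobs=-2.0)  # score = -2.0 / 4 ** 1.0 = -0.5
print(len(hyps))         # 2
print(hyps.worst_score)  # -1.0

# with cur_len unset, is_done scores best_sum_logprobs against max_length - 1 = 9
print(hyps.is_done(best_sum_logprobs=-12.0))  # True:  -12 / 9 < -1.0, no open beam can improve on the kept hypotheses
print(hyps.is_done(best_sum_logprobs=-5.0))   # False: -5 / 9 > -1.0, an open beam could still beat the worst one
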
@@ -849,7 +1118,7 @@ class TFSharedEmbeddings(tf.keras.layers.Layer):
         self.initializer_range = hidden_size ** -0.5 if initializer_range is None else initializer_range
 
     def build(self, input_shape):
-        """Build shared word embedding layer
+        """Build shared token embedding layer
         Shared weights logic adapted from
             https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
         """
@@ -381,7 +381,6 @@ class TFModelTesterMixin:
             )  # TODO (PVP): ugly workaround to make code work for t5 for the moment - has to changed when t5 is fixed.
 
         for model_class in self.all_generative_model_classes:
-            # TODO (PVP): add beam search tests when beam search is implemented
             model = model_class(config)
 
             if config.bos_token_id is None:
@@ -389,15 +388,34 @@ class TFModelTesterMixin:
                     model.generate(max_length=5)
                 # batch_size = 1
                 self._check_generated_tokens(model.generate(input_ids))
+                # batch_size = 1, num_beams > 1
+                self._check_generated_tokens(model.generate(input_ids, num_beams=3))
             else:
                 # batch_size = 1
                 self._check_generated_tokens(model.generate(max_length=5))
+                # batch_size = 1, num_beams > 1
+                self._check_generated_tokens(model.generate(max_length=5, num_beams=3))
 
+            with self.assertRaises(AssertionError):
+                # generating multiple sequences when greedy no beam generation
+                # is not allowed as it would always generate the same sequences
+                model.generate(input_ids, do_sample=False, num_return_sequences=2)
+
+            with self.assertRaises(AssertionError):
+                # generating more sequences than having beams leads is not possible
+                model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
+
             # batch_size > 1, sample
             self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3))
             # batch_size > 1, greedy
-            self._check_generated_tokens(model.generate(input_ids, do_sample=False, num_return_sequences=3))
+            self._check_generated_tokens(model.generate(input_ids, do_sample=False))
+
+            # batch_size > 1, num_beams > 1, sample
+            self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,))
+            # batch_size > 1, num_beams > 1, greedy
+            self._check_generated_tokens(
+                model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3)
+            )
 
     def _check_generated_tokens(self, output_ids):
         for token_id in output_ids[0].numpy().tolist():