diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py
index 43abdd9499..bb1856308a 100644
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -142,7 +142,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):

         # # initialize all new embeddings (in particular added tokens)
         # self._init_weights(new_embeddings)

-        # # Copy word embeddings from the previous weights
+        # # Copy token embeddings from the previous weights
         # num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
         # new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
@@ -557,6 +557,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         else:
             assert len(shape_list(input_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."

+        if do_sample is False:
+            if num_beams == 1:
+                # no_beam_search greedy generation conditions
+                assert (
+                    num_return_sequences == 1
+                ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
+
+            else:
+                # beam_search greedy generation conditions
+                assert (
+                    num_beams >= num_return_sequences
+                ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
+
         if pad_token_id is None and eos_token_ids is not None:
             logger.warning(
                 "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
@@ -567,7 +580,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         cur_len = shape_list(input_ids)[1]
         vocab_size = self.config.vocab_size

-        if num_return_sequences != 1:
+        if num_return_sequences != 1 and do_sample:
             # Expand input to num return sequences
             input_ids = tf.broadcast_to(tf.expand_dims(input_ids, 1), (batch_size, num_return_sequences, cur_len))
             effective_batch_size = batch_size * num_return_sequences
@@ -588,6 +601,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 pad_token_id,
                 eos_token_ids,
                 effective_batch_size,
+                num_return_sequences,
                 length_penalty,
                 num_beams,
                 vocab_size,
@@ -627,19 +641,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             All returned sequence are generated independantly.
         """

-        def _create_next_token_logits_penalties(input_ids, logits):
-            # create logit penalties for already seen input_ids
-            token_penalties = np.ones(shape_list(logits))
-            prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
-            for i, prev_input_id in enumerate(prev_input_ids):
-                logit_penalized = logits[i].numpy()[prev_input_id]
-                # if previous logit score is < 0 then multiply repetition penalty else divide
-                logit_penalized[logit_penalized < 0] = repetition_penalty
-                logit_penalized[logit_penalized > 0] = 1 / repetition_penalty
-                np.put(token_penalties[i], prev_input_id, logit_penalized)
-            return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
-
-        # current position / max lengths / length of generated sentences / unfinished sentences
+        # length of generated sentences / unfinished sentences
         unfinished_sents = tf.ones_like(input_ids[:, 0])
         sent_lengths = tf.ones_like(input_ids[:, 0]) * max_length

@@ -656,7 +658,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):

             # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
-                next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits)
+                next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits, repetition_penalty)
                 next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)

             if do_sample:
@@ -738,11 +740,228 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         pad_token_id,
         eos_token_ids,
         batch_size,
+        num_return_sequences,
         length_penalty,
         num_beams,
         vocab_size,
     ):
-        pass
+        """ Generate sequences for each example with beam search.
+        """
+
+        # Expand input to num beams
+        input_ids = tf.broadcast_to(tf.expand_dims(input_ids, 1), (batch_size, num_beams, cur_len))
+        input_ids = tf.reshape(input_ids, (batch_size * num_beams, cur_len))  # (batch_size * num_beams, cur_len)
+
+        # generated hypotheses
+        generated_hyps = [
+            BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
+        ]
+
+        # scores for each sentence in the beam
+        beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
+        beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
+        beam_scores = tf.reshape(tf.concat([beam_scores_begin, beam_scores_end], -1), (batch_size * num_beams,))
+
+        # cache compute states
+        past = None
+
+        # done sentences
+        done = [False for _ in range(batch_size)]
+
+        while cur_len < max_length:
+            model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
+            outputs = self(**model_inputs)  # (batch_size * num_beams, cur_len, vocab_size)
+            next_token_logits = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)
+
+            # if model has past, then set the past variable to speed up decoding
+            if self._do_output_past(outputs):
+                past = outputs[1]
+
+            # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
+            if repetition_penalty != 1.0:
+                next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits, repetition_penalty)
+                next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
+
+            if do_sample:
+                # Temperature (higher temperature => more likely to sample low probability tokens)
+                if temperature != 1.0:
+                    next_token_logits = next_token_logits / temperature
+                # Top-p/top-k filtering
+                next_token_logits = tf_top_k_top_p_filtering(
+                    next_token_logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
+                )  # (batch_size * num_beams, vocab_size)
+                # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
+                next_tokens = tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=2)  # (batch_size * num_beams, 2)
+                # Compute next scores
+                scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
+                _scores = tf.gather(scores, next_tokens, batch_dims=1)  # (batch_size * num_beams, 2)
+                next_scores = _scores + tf.broadcast_to(beam_scores[:, None], (batch_size * num_beams, 2))  # (batch_size * num_beams, 2)
+                # Match shape of greedy beam search
+                next_tokens = tf.reshape(next_tokens, (batch_size, 2 * num_beams))  # (batch_size, 2 * num_beams)
+                next_scores = tf.reshape(next_scores, (batch_size, 2 * num_beams))  # (batch_size, 2 * num_beams)
+            else:
+                # do greedy beam search
+                scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
+                assert shape_list(scores) == [batch_size * num_beams, vocab_size]
+                # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
+                next_scores = scores + tf.broadcast_to(beam_scores[:, None], (batch_size * num_beams, vocab_size))  # (batch_size * num_beams, vocab_size)
+
+                # re-organize to group the beam together (we are keeping top hypothesis across beams)
+                next_scores = tf.reshape(next_scores, (batch_size, num_beams * vocab_size))  # (batch_size, num_beams * vocab_size)
+                next_scores, next_tokens = tf.math.top_k(next_scores, 2 * num_beams, sorted=True)
+
+            assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
+
+            # next batch beam content
+            # list of (batch_size * num_beams) tuple(next hypothesis score, next token, current position in the batch)
+            next_batch_beam = []
+
+            # for each sentence
+            for batch_idx in range(batch_size):
+
+                # if we are done with this sentence
+                done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
+                    tf.reduce_max(next_scores[batch_idx]).numpy()
+                )
+                if done[batch_idx]:
+                    assert (
+                        len(generated_hyps[batch_idx]) >= num_beams
+                    ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
+                    assert (
+                        eos_token_ids is not None and pad_token_id is not None
+                    ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
+                    next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
+                    continue
+
+                # next sentence beam content
+                next_sent_beam = []
+
+                # next tokens for this sentence
+                for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
+
+                    # get beam and token IDs
+                    beam_id = idx // vocab_size
+                    token_id = idx % vocab_size
+
+                    # add to generated hypotheses if end of sentence or last iteration
+                    if eos_token_ids is not None and token_id.numpy() in eos_token_ids:
+                        generated_hyps[batch_idx].add(
+                            tf.identity(input_ids[batch_idx * num_beams + beam_id, :cur_len]), score.numpy()
+                        )
+                    else:
+                        # add next predicted token if it is not eos_token
+                        next_sent_beam.append((score, token_id, batch_idx * num_beams + beam_id))
+
+                    # the beam for next step is full
+                    if len(next_sent_beam) == num_beams:
+                        break
+
+                # update next beam content
+                assert len(next_sent_beam) == num_beams, "Beam should always be full"
+                next_batch_beam.extend(next_sent_beam)
+                assert len(next_batch_beam) == num_beams * (batch_idx + 1)
+
+            # sanity check / prepare next batch
+            assert len(next_batch_beam) == batch_size * num_beams
+            beam_scores = tf.convert_to_tensor([x[0] for x in next_batch_beam], dtype=tf.float32)
+            beam_tokens = tf.convert_to_tensor([x[1] for x in next_batch_beam], dtype=tf.int32)
+            beam_idx = tf.convert_to_tensor([x[2] for x in next_batch_beam], dtype=tf.int32)
+
+            # re-order batch
+            input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
+            input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1)
+
+            # re-order internal states
+            if past:
+                past = self._reorder_cache(past, beam_idx)
+
+            # update current length
+            cur_len = cur_len + 1
+
+            # stop when we are done with each sentence
+            if all(done):
+                break
+
+        for batch_idx in range(batch_size):
+            # Add all open beam hypotheses to generated_hyps
+            if not done[batch_idx]:
+                for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
+
+                    # get beam and token IDs
+                    beam_id = idx // vocab_size
+                    token_id = idx % vocab_size
+                    generated_hyps[batch_idx].add(
+                        tf.identity(input_ids[batch_idx * num_beams + beam_id, :cur_len]), score.numpy()
+                    )
+
+        # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
+        output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
+        output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
+
+        # select the best hypotheses
+        sent_lengths_list = []
+        best = []
+
+        # retrieve best hypotheses
+        for i, hypotheses in enumerate(generated_hyps):
+            sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
+            for j in range(output_num_return_sequences_per_batch):
+                best_hyp = sorted_hyps.pop()[1]
+                sent_lengths_list.append(len(best_hyp))
+                best.append(best_hyp)
+        assert output_batch_size == len(best), "Output batch size {} must match output beam hypotheses {}".format(output_batch_size, len(best))
+
+        sent_lengths = tf.convert_to_tensor(sent_lengths_list, dtype=tf.int32)
+
+        # shorter batches are filled with pad_token
+        if tf.reduce_min(sent_lengths).numpy() != tf.reduce_max(sent_lengths).numpy():
+            assert pad_token_id is not None, "`pad_token_id` has to be defined"
+            sent_max_len = min(tf.reduce_max(sent_lengths).numpy() + 1, max_length)
+            decoded_list = []
+
+            # fill with hypothesis and eos_token_id if necessary
+            for i, hypo in enumerate(best):
+                padding = tf.ones((sent_max_len - shape_list(hypo)[0],), dtype=tf.int32) * pad_token_id
+                decoded_hypo = tf.concat([hypo, padding], axis=0)
+
+                if sent_lengths[i] < max_length:
+                    decoded_hypo = tf.where(tf.range(sent_max_len) == sent_lengths[i], eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32), decoded_hypo)
+                decoded_list.append(decoded_hypo)
+            decoded = tf.stack(decoded_list)
+        else:
+            # none of the hypotheses have an eos_token
+            assert all(len(hypo) == max_length for hypo in best)
+            decoded = tf.stack(best)
+
+        return decoded
+
+    @staticmethod
+    def _reorder_cache(past, beam_idx):
+        reordered_past = []
+        for layer_past in past:
+            # get the correct batch idx from layer past batch dim
+            # batch dim of `past` and `mems` is at 2nd position
+            reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[i], 0)) for i in beam_idx]
+            # TODO: check whether it is an error that TF past.shape != Torch past.shape
+            reordered_layer_past = tf.concat(reordered_layer_past, axis=0)
+            # check that shape matches
+            assert shape_list(reordered_layer_past) == shape_list(layer_past)
+            reordered_past.append(reordered_layer_past)
+        past = tuple(reordered_past)
+        return past
+
+
+def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
+    # create logit penalties for already seen input_ids
+    token_penalties = np.ones(shape_list(logits))
+    prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
+    for i, prev_input_id in enumerate(prev_input_ids):
+        logit_penalized = logits[i].numpy()[prev_input_id]
+        # if previous logit score is < 0 then multiply repetition penalty else divide
+        logit_penalized[logit_penalized < 0] = repetition_penalty
+        logit_penalized[logit_penalized > 0] = 1 / repetition_penalty
+        np.put(token_penalties[i], prev_input_id, logit_penalized)
+    return tf.convert_to_tensor(token_penalties, dtype=tf.float32)


 def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
@@ -811,6 +1030,56 @@ def set_tensor_by_indices_to_value(tensor, indices, value):
     return tf.where(indices, value_tensor, tensor)


+class BeamHypotheses(object):
+    def __init__(self, num_beams, max_length, length_penalty, early_stopping):
+        """
+        Initialize n-best list of hypotheses.
+        """
+        self.max_length = max_length - 1  # ignoring bos_token
+        self.length_penalty = length_penalty
+        self.early_stopping = early_stopping
+        self.num_beams = num_beams
+        self.beams = []
+        self.worst_score = 1e9
+
+    def __len__(self):
+        """
+        Number of hypotheses in the list.
+        """
+        return len(self.beams)
+
+    def add(self, hyp, sum_logprobs):
+        """
+        Add a new hypothesis to the list.
+        """
+        score = sum_logprobs / len(hyp) ** self.length_penalty
+        if len(self) < self.num_beams or score > self.worst_score:
+            self.beams.append((score, hyp))
+            if len(self) > self.num_beams:
+                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
+                del self.beams[sorted_scores[0][1]]
+                self.worst_score = sorted_scores[1][0]
+            else:
+                self.worst_score = min(score, self.worst_score)
+
+    def is_done(self, best_sum_logprobs, cur_len=None):
+        """
+        If there are enough hypotheses and none of the hypotheses being generated
+        can become better than the worst one in the heap, then we are done with this sentence.
+        """
+
+        if len(self) < self.num_beams:
+            return False
+        elif self.early_stopping:
+            return True
+        else:
+            if cur_len is None:
+                cur_len = self.max_length
+            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
+            ret = self.worst_score >= cur_score
+            return ret
+
+
 class TFConv1D(tf.keras.layers.Layer):
     def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
         """ TFConv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
@@ -849,7 +1118,7 @@ class TFSharedEmbeddings(tf.keras.layers.Layer):
         self.initializer_range = hidden_size ** -0.5 if initializer_range is None else initializer_range

     def build(self, input_shape):
-        """Build shared word embedding layer
+        """Build shared token embedding layer
         Shared weights logic adapted from
             https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
         """
diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py
index 8cd53dfe19..a6d2e8e32f 100644
--- a/tests/test_modeling_tf_common.py
+++ b/tests/test_modeling_tf_common.py
@@ -381,7 +381,6 @@ class TFModelTesterMixin:
         )
         # TODO (PVP): ugly workaround to make code work for t5 for the moment - has to changed when t5 is fixed.
         for model_class in self.all_generative_model_classes:
-            # TODO (PVP): add beam search tests when beam search is implemented
             model = model_class(config)

             if config.bos_token_id is None:
@@ -389,15 +388,34 @@ class TFModelTesterMixin:
                     model.generate(max_length=5)
                 # batch_size = 1
                 self._check_generated_tokens(model.generate(input_ids))
+                # batch_size = 1, num_beams > 1
+                self._check_generated_tokens(model.generate(input_ids, num_beams=3))
             else:
                 # batch_size = 1
                 self._check_generated_tokens(model.generate(max_length=5))
+                # batch_size = 1, num_beams > 1
+                self._check_generated_tokens(model.generate(max_length=5, num_beams=3))
+
+            with self.assertRaises(AssertionError):
+                # generating multiple sequences with greedy no-beam generation
+                # is not allowed as it would always generate the same sequences
+                model.generate(input_ids, do_sample=False, num_return_sequences=2)
+
+            with self.assertRaises(AssertionError):
+                # generating more sequences than there are beams is not possible
+                model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)

             # batch_size > 1, sample
             self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3))
             # batch_size > 1, greedy
-            self._check_generated_tokens(model.generate(input_ids, do_sample=False, num_return_sequences=3))
+            self._check_generated_tokens(model.generate(input_ids, do_sample=False))
+
+            # batch_size > 1, num_beams > 1, sample
+            self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,))
+            # batch_size > 1, num_beams > 1, greedy
+            self._check_generated_tokens(
+                model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3)
+            )

     def _check_generated_tokens(self, output_ids):
         for token_id in output_ids[0].numpy().tolist():
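For reference, a minimal usage sketch of the beam-search generation path added by this diff (not part of the patch itself; the "gpt2" checkpoint, the prompt, and max_length are illustrative, while the keyword arguments follow the `generate` signature touched above):

# Minimal sketch, assuming a transformers build that includes this patch and TensorFlow 2.x;
# the "gpt2" checkpoint and the prompt are illustrative only.
import tensorflow as tf
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")

# shape (batch_size, sequence length), as required by the input checks in generate()
input_ids = tf.constant([tokenizer.encode("The dog")])

# Greedy beam search: the new asserts require num_return_sequences <= num_beams here.
output_ids = model.generate(input_ids, max_length=20, do_sample=False, num_beams=3, num_return_sequences=3)

for sequence in output_ids.numpy():
    print(tokenizer.decode(sequence.tolist()))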