From 02d0e0355c663921360dcd8160791d40081c61d2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 9 Dec 2020 15:00:37 +0100 Subject: [PATCH] Diverse beam search 2 (#9006) * diverse beam search * bug fixes * bug fixes * bug fix * separate out diverse_beam_search function * separate out diverse_beam_search function * bug fix * improve code quality * bug fix * bug fix * separate out diverse beam search scorer * code format * code format * code format * code format * add test * code format * documentation changes * code quality * add slow integration tests * more general name * refactor into logits processor * add test * avoid too much copy paste * refactor * add to docs * fix-copies * bug fix * Revert "bug fix" This reverts commit c99eb5a8dc57a7b0d33a8ac06d8c6a32a7812ad4. * improve comment * implement sylvains feedback Co-authored-by: Ayush Jain Co-authored-by: ayushtiku5 <40797286+ayushtiku5@users.noreply.github.com> --- docs/source/internal/generation_utils.rst | 6 + src/transformers/__init__.py | 2 + src/transformers/configuration_utils.py | 8 + src/transformers/generation_beam_search.py | 38 ++- src/transformers/generation_logits_process.py | 76 ++++- src/transformers/generation_utils.py | 280 +++++++++++++++++- src/transformers/models/rag/modeling_rag.py | 12 + src/transformers/utils/dummy_pt_objects.py | 10 + tests/test_generation_logits_process.py | 28 ++ tests/test_generation_utils.py | 150 +++++++++- 10 files changed, 590 insertions(+), 20 deletions(-) diff --git a/docs/source/internal/generation_utils.rst b/docs/source/internal/generation_utils.rst index fc3c561120..6851cb7299 100644 --- a/docs/source/internal/generation_utils.rst +++ b/docs/source/internal/generation_utils.rst @@ -52,6 +52,12 @@ generation. .. autoclass:: transformers.NoBadWordsLogitsProcessor :members: __call__ +.. autoclass:: transformers.PrefixConstrainedLogitsProcessor + :members: __call__ + +.. autoclass:: transformers.HammingDiversityLogitsProcessor + :members: __call__ + BeamSearch ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e99fae4349..617a1298bc 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -319,12 +319,14 @@ if is_torch_available(): ) from .generation_beam_search import BeamScorer, BeamSearchScorer from .generation_logits_process import ( + HammingDiversityLogitsProcessor, LogitsProcessor, LogitsProcessorList, LogitsWarper, MinLengthLogitsProcessor, NoBadWordsLogitsProcessor, NoRepeatNGramLogitsProcessor, + PrefixConstrainedLogitsProcessor, RepetitionPenaltyLogitsProcessor, TemperatureLogitsWarper, TopKLogitsWarper, diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 04712a7b88..2825e7efa5 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -95,6 +95,12 @@ class PretrainedConfig(object): sentences are finished per batch or not. - **num_beams** (:obj:`int`, `optional`, defaults to 1) -- Number of beams for beam search that will be used by default in the :obj:`generate` method of the model. 1 means no beam search. + - **num_beam_groups** (:obj:`int`, `optional`, defaults to 1) -- Number of groups to divide :obj:`num_beams` + into in order to ensure diversity among different groups of beams that will be used by default in the + :obj:`generate` method of the model. 1 means no group beam search. + - **diversity_penalty** (:obj:`float`, `optional`, defaults to 0.0) -- Value to control diversity for group + beam search. that will be used by default in the :obj:`generate` method of the model. 0 means no diversity + penalty. The higher the penalty, the more diverse are the outputs. - **temperature** (:obj:`float`, `optional`, defaults to 1) -- The value used to module the next token probabilities that will be used by default in the :obj:`generate` method of the model. Must be strictly positive. @@ -185,6 +191,8 @@ class PretrainedConfig(object): self.do_sample = kwargs.pop("do_sample", False) self.early_stopping = kwargs.pop("early_stopping", False) self.num_beams = kwargs.pop("num_beams", 1) + self.num_beam_groups = kwargs.pop("num_beam_groups", 1) + self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0) self.temperature = kwargs.pop("temperature", 1.0) self.top_k = kwargs.pop("top_k", 50) self.top_p = kwargs.pop("top_p", 1.0) diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py index 135227895d..b04c93d567 100644 --- a/src/transformers/generation_beam_search.py +++ b/src/transformers/generation_beam_search.py @@ -122,9 +122,12 @@ class BeamSearchScorer(BeamScorer): Adapted in part from `Facebook's XLM beam search code `__. + Reference for the diverse beam search algorithm and implementation `Ashwin Kalyan's DBS implementation + `__ + Args: batch_size (:obj:`int`): - Batch Size of :obj:`input_ids` for which beam search decoding is run in parallel. + Batch Size of :obj:`input_ids` for which standard beam search decoding is run in parallel. max_length (:obj:`int`): The maximum length of the sequence to be generated. num_beams (:obj:`int`): @@ -141,6 +144,9 @@ class BeamSearchScorer(BeamScorer): num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1): The number of beam hypotheses that shall be returned upon calling :meth:`~transformer.BeamSearchScorer.finalize`. + num_beam_groups (:obj:`int`): + Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of + beams. See `this paper `__ for more details. """ def __init__( @@ -152,6 +158,7 @@ class BeamSearchScorer(BeamScorer): length_penalty: Optional[float] = 1.0, do_early_stopping: Optional[bool] = False, num_beam_hyps_to_keep: Optional[int] = 1, + num_beam_groups: Optional[int] = 1, ): self.max_length = max_length self.num_beams = num_beams @@ -159,6 +166,8 @@ class BeamSearchScorer(BeamScorer): self.length_penalty = length_penalty self.do_early_stopping = do_early_stopping self.num_beam_hyps_to_keep = num_beam_hyps_to_keep + self.num_beam_groups = num_beam_groups + self.group_size = self.num_beams // self.num_beam_groups self._is_init = False self._beam_hyps = [ @@ -177,6 +186,12 @@ class BeamSearchScorer(BeamScorer): f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead." ) + if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0): + raise ValueError( + f"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` " + f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}." + ) + @property def is_done(self) -> bool: return self._done.all() @@ -192,12 +207,12 @@ class BeamSearchScorer(BeamScorer): ) -> Tuple[torch.Tensor]: cur_len = input_ids.shape[-1] batch_size = len(self._beam_hyps) - assert batch_size == (input_ids.shape[0] // self.num_beams) + assert batch_size == (input_ids.shape[0] // self.group_size) device = input_ids.device - next_beam_scores = torch.zeros((batch_size, self.num_beams), dtype=next_scores.dtype, device=device) - next_beam_tokens = torch.zeros((batch_size, self.num_beams), dtype=next_tokens.dtype, device=device) - next_beam_indices = torch.zeros((batch_size, self.num_beams), dtype=next_indices.dtype, device=device) + next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device) + next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device) + next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device) for batch_idx, beam_hyp in enumerate(self._beam_hyps): if self._done[batch_idx]: @@ -218,11 +233,11 @@ class BeamSearchScorer(BeamScorer): for beam_token_rank, (next_token, next_score, next_index) in enumerate( zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx]) ): - batch_beam_idx = batch_idx * self.num_beams + next_index + batch_beam_idx = batch_idx * self.group_size + next_index # add to generated hypotheses if end of sentence if (eos_token_id is not None) and (next_token.item() == eos_token_id): # if beam_token does not belong to top num_beams tokens, it should not be added - is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams + is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size if is_beam_token_worse_than_top_num_beams: continue beam_hyp.add( @@ -237,12 +252,12 @@ class BeamSearchScorer(BeamScorer): beam_idx += 1 # once the beam for next step is full, don't add more tokens to it. - if beam_idx == self.num_beams: + if beam_idx == self.group_size: break - if beam_idx < self.num_beams: + if beam_idx < self.group_size: raise ValueError( - f"At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected." + f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected." ) # Check if we are done so that we can save a pad step if all(done) @@ -274,7 +289,8 @@ class BeamSearchScorer(BeamScorer): if self._done[batch_idx]: continue - # need to add best num_beams hypotheses to generated hyps + # all open beam hypotheses are added to the beam hypothesis + # beam hypothesis class automatically keeps the best beams for beam_id in range(self.num_beams): batch_beam_idx = batch_idx * self.num_beams + beam_id final_score = final_beam_scores[batch_beam_idx].item() diff --git a/src/transformers/generation_logits_process.py b/src/transformers/generation_logits_process.py index 0a841e8955..8ffab62386 100644 --- a/src/transformers/generation_logits_process.py +++ b/src/transformers/generation_logits_process.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import math from abc import ABC from typing import Callable, Iterable, List @@ -37,6 +38,8 @@ LOGITS_PROCESSOR_INPUTS_DOCSTRING = r""" scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`): Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax or scores for each vocabulary token after SoftMax. + kwargs: + Additional logits processor specific kwargs. Return: :obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores. @@ -75,9 +78,16 @@ class LogitsProcessorList(list): """ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor: for processor in self: - scores = processor(input_ids, scores) + function_args = inspect.signature(processor.__call__).parameters + if len(function_args) > 2: + assert all( + arg in kwargs for arg in list(function_args.keys())[2:] + ), f"Make sure that all the required parameters: {list(function_args.keys())} for {processor.__class__} are passed to the logits processor." + scores = processor(input_ids, scores, **kwargs) + else: + scores = processor(input_ids, scores) return scores @@ -400,3 +410,65 @@ class PrefixConstrainedLogitsProcessor(LogitsProcessor): mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0 return scores + mask + + +class HammingDiversityLogitsProcessor(LogitsProcessor): + r""" + :class:`transformers.LogitsProcessor` that enforces diverse beam search. Note that this logits processor is only + effective for `group_beam_search`. See `Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence Models + `__ for more details. + + Args: + diversity_penalty (:obj:`float`): + This value is subtracted from a beam's score if it generates a token same as any beam from other group at a + particular time. Note that :obj:`diversity_penalty` is only effective if ``group beam search`` is enabled. + num_beams (:obj:`int`): + Number of beams used for group beam search. See `this paper `__ for + more details. + num_beam_groups (:obj:`int`): + Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of + beams. See `this paper `__ for more details. + """ + + def __init__(self, diversity_penalty: float, num_beams: int, num_beam_groups: int): + if not isinstance(diversity_penalty, float) or (not diversity_penalty > 0.0): + raise ValueError("`diversity_penalty` should be a float strictly larger than 0.") + self._diversity_penalty = diversity_penalty + if not isinstance(num_beams, int) or num_beams < 2: + raise ValueError("`num_beams` should be an integer strictly larger than 1.") + self._num_beams = num_beams + if not isinstance(num_beam_groups, int) or num_beam_groups < 2: + raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.") + if num_beam_groups > num_beams: + raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`.") + if num_beam_groups > num_beams: + raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`") + self._num_sub_beams = num_beams // num_beam_groups + + def __call__( + self, + input_ids: torch.LongTensor, + scores: torch.FloatTensor, + current_tokens: torch.LongTensor, + beam_group_idx: int, + ) -> torch.FloatTensor: + # hamming diversity: penalise using same token in current group which was used in previous groups at + # the same time step + batch_size = current_tokens.shape[0] // self._num_beams + group_start_idx = beam_group_idx * self._num_sub_beams + group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams) + group_size = group_end_idx - group_start_idx + vocab_size = scores.shape[-1] + + if group_start_idx == 0: + return scores + + for batch_idx in range(batch_size): + # predicted tokens of last time step of previous groups + previous_group_tokens = current_tokens[ + batch_idx * self._num_beams : batch_idx * self._num_beams + group_start_idx + ] + token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size).to(scores.device) + scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= self._diversity_penalty * token_frequency + + return scores diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index eb6999e868..91cc97c95c 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -22,6 +22,7 @@ from torch.nn import functional as F from .file_utils import ModelOutput from .generation_beam_search import BeamScorer, BeamSearchScorer from .generation_logits_process import ( + HammingDiversityLogitsProcessor, LogitsProcessorList, MinLengthLogitsProcessor, NoBadWordsLogitsProcessor, @@ -261,6 +262,8 @@ class GenerationMixin: eos_token_id: int, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int, + num_beam_groups: int, + diversity_penalty: float, ) -> LogitsProcessorList: """ This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant @@ -275,11 +278,18 @@ class GenerationMixin: bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids min_length = min_length if min_length is not None else self.config.min_length eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty # instantiate processors list processors = LogitsProcessorList() # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files # all samplers can be found in `generation_utils_samplers.py` + if diversity_penalty is not None and diversity_penalty > 0.0: + processors.append( + HammingDiversityLogitsProcessor( + diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups + ) + ) if repetition_penalty is not None and repetition_penalty != 1.0: processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)) if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0: @@ -314,6 +324,8 @@ class GenerationMixin: num_return_sequences: Optional[int] = None, decoder_start_token_id: Optional[int] = None, use_cache: Optional[bool] = None, + num_beam_groups: Optional[int] = None, + diversity_penalty: Optional[float] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, **model_kwargs ) -> torch.LongTensor: @@ -381,6 +393,13 @@ class GenerationMixin: use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding. + num_beam_groups (:obj:`int`, `optional`, defaults to 1): + Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of + beams. `this paper `__ for more details. + diversity_penalty (:obj:`float`, `optional`, defaults to 0.0): + This value is subtracted from a beam's score if it generates a token same as any beam from other group + at a particular time. Note that :obj:`diversity_penalty` is only effective if ``group beam search`` is + enabled. prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`): If provided, this function constraints the beam search to allowed tokens only at each step. If not provided no constraint is applied. This function takes 2 arguments :obj:`inputs_ids` and the batch ID @@ -453,6 +472,7 @@ class GenerationMixin: # set init values num_beams = num_beams if num_beams is not None else self.config.num_beams + num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups max_length = max_length if max_length is not None else self.config.max_length do_sample = do_sample if do_sample is not None else self.config.do_sample num_return_sequences = ( @@ -491,10 +511,17 @@ class GenerationMixin: raise ValueError("Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`.") # determine generation mode - is_greedy_gen_mode = (num_beams == 1) and do_sample is False - is_sample_gen_mode = (num_beams == 1) and do_sample is True - is_beam_gen_mode = (num_beams > 1) and do_sample is False - is_beam_sample_gen_mode = (num_beams > 1) and do_sample is True + is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False + is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True + is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False + is_beam_sample_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is True + is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) + if num_beam_groups > num_beams: + raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") + if is_group_beam_gen_mode and do_sample is True: + raise ValueError( + "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." + ) # set model_kwargs model_kwargs["use_cache"] = use_cache @@ -508,6 +535,8 @@ class GenerationMixin: eos_token_id=eos_token_id, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, num_beams=num_beams, + num_beam_groups=num_beam_groups, + diversity_penalty=diversity_penalty, ) if is_greedy_gen_mode: @@ -619,6 +648,42 @@ class GenerationMixin: **model_kwargs, ) + elif is_group_beam_gen_mode: + batch_size = input_ids.shape[0] + + length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty + early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping + + if num_return_sequences > num_beams: + raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") + + if num_beams % num_beam_groups != 0: + raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.") + + diverse_beam_scorer = BeamSearchScorer( + batch_size=batch_size, + max_length=max_length, + num_beams=num_beams, + device=self.device, + length_penalty=length_penalty, + do_early_stopping=early_stopping, + num_beam_hyps_to_keep=num_return_sequences, + num_beam_groups=num_beam_groups, + ) + # interleave with `num_beams` + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs + ) + return self.group_beam_search( + input_ids, + diverse_beam_scorer, + logits_processor=logits_processor, + max_length=max_length, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + **model_kwargs, + ) + def greedy_search( self, input_ids: torch.LongTensor, @@ -1208,6 +1273,213 @@ class GenerationMixin: return decoded + def group_beam_search( + self, + input_ids: torch.LongTensor, + beam_scorer: BeamScorer, + logits_processor: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + **model_kwargs + ): + r""" + Generates sequences for models with a language modeling head using beam search decoding. + + Parameters: + + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty + :obj:`torch.LongTensor` of shape :obj:`(1,)`. + beam_scorer (:obj:`BeamScorer`): + An derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are + constructed, stored and sorted during generation. For more information, the documentation of + :class:`~transformers.BeamScorer` should be read. + logits_processor (:obj:`LogitsProcessorList`, `optional`): + An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from + :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling + head applied at each generation step. + max_length (:obj:`int`, `optional`, defaults to 20): + The maximum length of the sequence to be generated. + pad_token_id (:obj:`int`, `optional`): + The id of the `padding` token. + eos_token_id (:obj:`int`, `optional`): + The id of the `end-of-sequence` token. + model_kwargs: + Additional model specific kwargs that will be forwarded to the :obj:`forward` function of the model. If + model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`. + + Return: + :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated + sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all + batches finished early due to the :obj:`eos_token_id`. + + Examples:: + + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForSeq2SeqLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... HammingDiversityLogitsProcessor, + ... BeamSearchScorer, + ... ) + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") + >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") + + >>> encoder_input_str = "translate English to German: How old are you?" + >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids + + + >>> # lets run diverse beam search using 6 beams + >>> num_beams = 6 + >>> # define decoder start token ids + >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) + >>> input_ids = input_ids * model.config.decoder_start_token_id + + >>> # add encoder_outputs to model keyword arguments + >>> model_kwargs = { + ... "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True) + ... } + + >>> # instantiate beam scorer + >>> beam_scorer = BeamSearchScorer( + ... batch_size=1, + ... max_length=model.config.max_length, + ... num_beams=num_beams, + ... device=model.device, + ... num_beam_groups=3 + ... ) + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList([ + ... HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3), + ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), + ... ]) + + >>> outputs = model.group_beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) + + >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) + """ + + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + max_length = max_length if max_length is not None else self.config.max_length + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + + batch_size = len(beam_scorer._beam_hyps) + num_beams = beam_scorer.num_beams + num_beam_groups = beam_scorer.num_beam_groups + num_sub_beams = num_beams // num_beam_groups + device = input_ids.device + + batch_beam_size, cur_len = input_ids.shape + + assert ( + num_beams * batch_size == batch_beam_size + ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + + beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) + # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in + # the same group don't produce same tokens everytime. + beam_scores[:, ::num_sub_beams] = 0 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + while cur_len < max_length: + # predicted tokens in cur_len step + current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) + + # indices which will form the beams in the next time step + reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) + + # do one decoder step on all beams of all sentences in batch + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + outputs = self(**model_inputs, return_dict=True) + + for beam_group_idx in range(num_beam_groups): + group_start_idx = beam_group_idx * num_sub_beams + group_end_idx = min(group_start_idx + num_sub_beams, num_beams) + group_size = group_end_idx - group_start_idx + + # indices of beams of current group among all sentences in batch + batch_group_indices = [] + for batch_idx in range(batch_size): + batch_group_indices.extend( + [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] + ) + group_input_ids = input_ids[batch_group_indices] + + # select outputs of beams of current group only + next_token_logits = outputs.logits[batch_group_indices, -1, :] + + # adjust tokens for Bart, *e.g.* + next_token_logits = self.adjust_logits_during_generation( + next_token_logits, cur_len=cur_len, max_length=max_length + ) + + next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * group_size, vocab_size) + vocab_size = next_token_scores.shape[-1] + + next_token_scores = logits_processor( + group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx + ) + next_token_scores = next_token_scores + beam_scores[batch_group_indices].unsqueeze(-1).expand_as( + next_token_scores + ) + # reshape for beam search + + next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) + + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True + ) + + next_indices = next_tokens // vocab_size + next_tokens = next_tokens % vocab_size + + # stateless + beam_outputs = beam_scorer.process( + group_input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + ) + beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids[batch_group_indices] = group_input_ids[beam_idx] + group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + current_tokens[batch_group_indices] = group_input_ids[:, -1] + + # (beam_idx // group_size) -> batch_idx + # (beam_idx % group_size) -> offset of idx inside the group + reordering_indices[batch_group_indices] = ( + num_beams * (beam_idx // group_size) + group_start_idx + (beam_idx % group_size) + ) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + if model_kwargs["past"] is not None: + model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices) + + input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) + cur_len = cur_len + 1 + if beam_scorer.is_done: + break + + decoded = beam_scorer.finalize( + input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id + ) + + return decoded + def top_k_top_p_filtering( logits: torch.FloatTensor, diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 1192c70d9e..e8219c75fe 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -1226,6 +1226,8 @@ class RagTokenForGeneration(RagPreTrainedModel): early_stopping=None, use_cache=None, num_beams=None, + num_beam_groups=None, + diversity_penalty=None, bos_token_id=None, pad_token_id=None, eos_token_id=None, @@ -1302,6 +1304,13 @@ class RagTokenForGeneration(RagPreTrainedModel): should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`. num_beams (:obj:`int`, `optional`, defaults to 1): Number of beams for beam search. 1 means no beam search. + num_beam_groups (:obj:`int`, `optional`, defaults to 1): + Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of + beams. `this paper `__ for more details. + diversity_penalty (:obj:`float`, `optional`, defaults to 0.0): + This value is subtracted from a beam's score if it generates a token same as any beam from other group + at a particular time. Note that :obj:`diversity_penalty` is only effective if ``group beam search`` is + enabled. num_return_sequences(:obj:`int`, `optional`, defaults to 1): The number of independently computed returned sequences for each element in the batch. Note that this is not the value we pass to the ``generator``'s `:func:`~transformers.PreTrainedModel.generate` @@ -1326,6 +1335,7 @@ class RagTokenForGeneration(RagPreTrainedModel): # set default parameters n_docs = n_docs if n_docs is not None else self.config.n_docs num_beams = num_beams if num_beams is not None else self.config.num_beams + num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups max_length = max_length if max_length is not None else self.config.max_length num_return_sequences = ( num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences @@ -1412,6 +1422,8 @@ class RagTokenForGeneration(RagPreTrainedModel): eos_token_id=eos_token_id, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, num_beams=num_beams, + num_beam_groups=num_beam_groups, + diversity_penalty=diversity_penalty, ) if num_beams == 1: diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 3dd8acffac..38df436701 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -118,6 +118,11 @@ class BeamSearchScorer: requires_pytorch(self) +class HammingDiversityLogitsProcessor: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + class LogitsProcessor: def __init__(self, *args, **kwargs): requires_pytorch(self) @@ -148,6 +153,11 @@ class NoRepeatNGramLogitsProcessor: requires_pytorch(self) +class PrefixConstrainedLogitsProcessor: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + class RepetitionPenaltyLogitsProcessor: def __init__(self, *args, **kwargs): requires_pytorch(self) diff --git a/tests/test_generation_logits_process.py b/tests/test_generation_logits_process.py index 7dd0d05517..1aa2941047 100644 --- a/tests/test_generation_logits_process.py +++ b/tests/test_generation_logits_process.py @@ -27,6 +27,7 @@ if is_torch_available(): import torch.nn.functional as F from transformers.generation_logits_process import ( + HammingDiversityLogitsProcessor, LogitsProcessorList, MinLengthLogitsProcessor, NoBadWordsLogitsProcessor, @@ -302,3 +303,30 @@ class LogitsProcessorTest(unittest.TestCase): self.assertListEqual( torch.isinf(filtered_scores).tolist(), [[False, False, True, True, True], [True, True, False, False, True]] ) + + def test_hamming_diversity(self): + vocab_size = 4 + num_beams = 2 + num_beam_groups = 2 + + scores = self._get_uniform_logits(num_beams, vocab_size) + # batch_idx = 0 -> index batch_idx * num_beam_groups -> idx = 0 * 2 = 0 -> penalises tokens 1 + # batch_idx = 1 -> index batch_idx * num_beam_groups -> idx = 1 * 2 = 2 -> penalises tokens 1 + current_tokens = torch.tensor([0, 3, 1, 2], device=torch_device, dtype=torch.long) + + diversity_logits_processor = HammingDiversityLogitsProcessor( + diversity_penalty=1.0, num_beams=num_beams, num_beam_groups=num_beam_groups + ) + + processed_scores = diversity_logits_processor(None, scores, current_tokens, 1) + + self.assertTrue( + torch.allclose( + processed_scores[0], torch.tensor([-0.7500, 0.2500, 0.2500, 0.2500], device=torch_device), atol=1e-3 + ) + ) + self.assertTrue( + torch.allclose( + processed_scores[1], torch.tensor([0.2500, -0.7500, 0.2500, 0.2500], device=torch_device), atol=1e-3 + ) + ) diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py index dee7873fb9..ce0fe08fe0 100644 --- a/tests/test_generation_utils.py +++ b/tests/test_generation_utils.py @@ -17,15 +17,16 @@ import unittest from transformers import is_torch_available -from transformers.testing_utils import require_torch, torch_device +from transformers.testing_utils import require_torch, slow, torch_device if is_torch_available(): import torch - from transformers import top_k_top_p_filtering + from transformers import BartForConditionalGeneration, BartTokenizer, top_k_top_p_filtering from transformers.generation_beam_search import BeamSearchScorer from transformers.generation_logits_process import ( + HammingDiversityLogitsProcessor, LogitsProcessorList, MinLengthLogitsProcessor, NoBadWordsLogitsProcessor, @@ -61,7 +62,7 @@ class GenerationTesterMixin: return config, input_ids, attention_mask, max_length @staticmethod - def _get_logits_processor_and_kwargs(input_length, eos_token_id): + def _get_logits_processor_and_kwargs(input_length, eos_token_id, diversity_penalty=None): process_kwargs = { "min_length": input_length + 1, "bad_words_ids": [[1, 0]], @@ -70,6 +71,13 @@ class GenerationTesterMixin: } logits_processor = LogitsProcessorList( ( + [ + HammingDiversityLogitsProcessor(diversity_penalty, num_beams=2, num_beam_groups=2), + ] + if diversity_penalty is not None + else [] + ) + + ( [ MinLengthLogitsProcessor(process_kwargs["min_length"], eos_token_id), ] @@ -115,6 +123,28 @@ class GenerationTesterMixin: ) return beam_kwargs, beam_scorer + @staticmethod + def _get_diverse_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1): + beam_kwargs = { + "early_stopping": False, + "length_penalty": 2.0, + "num_beams": 2, + "num_return_sequences": num_return_sequences, + "num_beam_groups": 2, # one beam per group + "diversity_penalty": 2.0, + } + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + max_length=max_length, + num_beams=beam_kwargs["num_beams"], + device=torch_device, + length_penalty=beam_kwargs["length_penalty"], + do_early_stopping=beam_kwargs["early_stopping"], + num_beam_hyps_to_keep=num_return_sequences, + num_beam_groups=beam_kwargs["num_beam_groups"], + ) + return beam_kwargs, beam_scorer + @staticmethod def _get_encoder_outputs(model, input_ids, attention_mask, num_interleave=1): encoder = model.get_encoder() @@ -408,6 +438,92 @@ class GenerationTesterMixin: self.assertIsNotNone(output_ids_generate) + def test_group_beam_search_generate(self): + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], config.eos_token_id, diversity_penalty=2.0 + ) + + model = model_class(config).to(torch_device) + model.eval() + + # check `generate()` and `group_beam_search()` are equal + if model.config.is_encoder_decoder: + max_length = 4 + beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(input_ids.shape[0], max_length) + output_ids_generate = model.generate( + input_ids, + attention_mask=attention_mask, + do_sample=False, + max_length=max_length, + **beam_kwargs, + **logits_process_kwargs, + ) + + # group_beam_search does not automatically interleave `batch_size` dim for `num_beams` + kwargs = {} + if model.config.is_encoder_decoder: + encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs( + model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams + ) + kwargs["encoder_outputs"] = encoder_outputs + input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0) + else: + attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0) + input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0) + + with torch.no_grad(): + output_ids_group_beam_search = model.group_beam_search( + input_ids_clone, + beam_scorer, + max_length=max_length, + attention_mask=attention_mask_clone, + logits_processor=logits_processor, + **kwargs, + ) + self.assertListEqual(output_ids_generate.tolist(), output_ids_group_beam_search.tolist()) + + # check `generate()` and `group_beam_search()` are equal for `num_return_sequences` + num_return_sequences = 2 + if model.config.is_encoder_decoder: + max_length = 4 + beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs( + input_ids.shape[0], max_length, num_return_sequences=num_return_sequences + ) + + output_ids_generate = model.generate( + input_ids, + attention_mask=attention_mask, + do_sample=False, + max_length=max_length, + **beam_kwargs, + **logits_process_kwargs, + ) + # group_beam_search does not automatically interleave `batch_size` dim for `num_beams` + kwargs = {} + if model.config.is_encoder_decoder: + encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs( + model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams + ) + kwargs["encoder_outputs"] = encoder_outputs + input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0) + else: + attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0) + input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0) + + with torch.no_grad(): + output_ids_beam_search = model.group_beam_search( + input_ids_clone, + beam_scorer, + max_length=max_length, + attention_mask=attention_mask_clone, + logits_processor=logits_processor, + **kwargs, + ) + self.assertListEqual(output_ids_generate.tolist(), output_ids_beam_search.tolist()) + @require_torch class UtilsFunctionsTest(unittest.TestCase): @@ -512,3 +628,31 @@ class UtilsFunctionsTest(unittest.TestCase): self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12)) self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx))) + + +@require_torch +class GenerationIntegrationTests(unittest.TestCase): + @slow + def test_diverse_beam_search(self): + article = """Justin Timberlake and Jessica Biel, welcome to parenthood. + The celebrity couple announced the arrival of their son, Silas Randall Timberlake, in statements to People. + "Silas was the middle name of Timberlake's maternal grandfather Bill Bomar, who died in 2012, while Randall is the musician's own middle name, as well as his father's first," People reports. + The couple announced the pregnancy in January, with an Instagram post. It is the first baby for both.""" + + bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn") + bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(torch_device) + input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + + outputs = bart_model.generate( + input_ids, num_beams=4, num_return_sequences=2, num_beam_groups=4, diversity_penalty=2.0 + ) + + generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual( + generated_text, + [ + "The couple announced the birth of their son, Silas Randall Timberlake, in a statement. Silas was the middle name of Timberlake's maternal grandfather Bill Bomar. Randall is the musician's own middle name, as well as his father's first. It is the first baby for both of them.", + "Justin Timberlake and Jessica Biel have a son. The baby is named Silas Randall Timberlake. It is the first child for both. The couple announced the pregnancy in January. The name Silas is the middle name of Timberlake's maternal grandfather. It's also his own middle name.", + ], + )