Diverse beam search 2 (#9006)
* diverse beam search
* bug fixes
* bug fixes
* bug fix
* separate out diverse_beam_search function
* separate out diverse_beam_search function
* bug fix
* improve code quality
* bug fix
* bug fix
* separate out diverse beam search scorer
* code format
* code format
* code format
* code format
* add test
* code format
* documentation changes
* code quality
* add slow integration tests
* more general name
* refactor into logits processor
* add test
* avoid too much copy paste
* refactor
* add to docs
* fix-copies
* bug fix
* Revert "bug fix"
This reverts commit c99eb5a8dc
.
* improve comment
* implement sylvains feedback
Co-authored-by: Ayush Jain <a.jain@sprinklr.com>
Co-authored-by: ayushtiku5 <40797286+ayushtiku5@users.noreply.github.com>
This commit is contained in:
parent
67ff1c314a
commit
02d0e0355c
|
@ -52,6 +52,12 @@ generation.
|
|||
.. autoclass:: transformers.NoBadWordsLogitsProcessor
|
||||
:members: __call__
|
||||
|
||||
.. autoclass:: transformers.PrefixConstrainedLogitsProcessor
|
||||
:members: __call__
|
||||
|
||||
.. autoclass:: transformers.HammingDiversityLogitsProcessor
|
||||
:members: __call__
|
||||
|
||||
BeamSearch
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
|
|
@ -319,12 +319,14 @@ if is_torch_available():
|
|||
)
|
||||
from .generation_beam_search import BeamScorer, BeamSearchScorer
|
||||
from .generation_logits_process import (
|
||||
HammingDiversityLogitsProcessor,
|
||||
LogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
LogitsWarper,
|
||||
MinLengthLogitsProcessor,
|
||||
NoBadWordsLogitsProcessor,
|
||||
NoRepeatNGramLogitsProcessor,
|
||||
PrefixConstrainedLogitsProcessor,
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
|
|
|
@ -95,6 +95,12 @@ class PretrainedConfig(object):
|
|||
sentences are finished per batch or not.
|
||||
- **num_beams** (:obj:`int`, `optional`, defaults to 1) -- Number of beams for beam search that will be used by
|
||||
default in the :obj:`generate` method of the model. 1 means no beam search.
|
||||
- **num_beam_groups** (:obj:`int`, `optional`, defaults to 1) -- Number of groups to divide :obj:`num_beams`
|
||||
into in order to ensure diversity among different groups of beams that will be used by default in the
|
||||
:obj:`generate` method of the model. 1 means no group beam search.
|
||||
- **diversity_penalty** (:obj:`float`, `optional`, defaults to 0.0) -- Value to control diversity for group
|
||||
beam search. that will be used by default in the :obj:`generate` method of the model. 0 means no diversity
|
||||
penalty. The higher the penalty, the more diverse are the outputs.
|
||||
- **temperature** (:obj:`float`, `optional`, defaults to 1) -- The value used to module the next token
|
||||
probabilities that will be used by default in the :obj:`generate` method of the model. Must be strictly
|
||||
positive.
|
||||
|
@ -185,6 +191,8 @@ class PretrainedConfig(object):
|
|||
self.do_sample = kwargs.pop("do_sample", False)
|
||||
self.early_stopping = kwargs.pop("early_stopping", False)
|
||||
self.num_beams = kwargs.pop("num_beams", 1)
|
||||
self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
|
||||
self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
|
||||
self.temperature = kwargs.pop("temperature", 1.0)
|
||||
self.top_k = kwargs.pop("top_k", 50)
|
||||
self.top_p = kwargs.pop("top_p", 1.0)
|
||||
|
|
|
@ -122,9 +122,12 @@ class BeamSearchScorer(BeamScorer):
|
|||
Adapted in part from `Facebook's XLM beam search code
|
||||
<https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__.
|
||||
|
||||
Reference for the diverse beam search algorithm and implementation `Ashwin Kalyan's DBS implementation
|
||||
<https://github.com/ashwinkalyan/dbs/blob/master/dbs/beam_utils.lua>`__
|
||||
|
||||
Args:
|
||||
batch_size (:obj:`int`):
|
||||
Batch Size of :obj:`input_ids` for which beam search decoding is run in parallel.
|
||||
Batch Size of :obj:`input_ids` for which standard beam search decoding is run in parallel.
|
||||
max_length (:obj:`int`):
|
||||
The maximum length of the sequence to be generated.
|
||||
num_beams (:obj:`int`):
|
||||
|
@ -141,6 +144,9 @@ class BeamSearchScorer(BeamScorer):
|
|||
num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1):
|
||||
The number of beam hypotheses that shall be returned upon calling
|
||||
:meth:`~transformer.BeamSearchScorer.finalize`.
|
||||
num_beam_groups (:obj:`int`):
|
||||
Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of
|
||||
beams. See `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -152,6 +158,7 @@ class BeamSearchScorer(BeamScorer):
|
|||
length_penalty: Optional[float] = 1.0,
|
||||
do_early_stopping: Optional[bool] = False,
|
||||
num_beam_hyps_to_keep: Optional[int] = 1,
|
||||
num_beam_groups: Optional[int] = 1,
|
||||
):
|
||||
self.max_length = max_length
|
||||
self.num_beams = num_beams
|
||||
|
@ -159,6 +166,8 @@ class BeamSearchScorer(BeamScorer):
|
|||
self.length_penalty = length_penalty
|
||||
self.do_early_stopping = do_early_stopping
|
||||
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
|
||||
self.num_beam_groups = num_beam_groups
|
||||
self.group_size = self.num_beams // self.num_beam_groups
|
||||
|
||||
self._is_init = False
|
||||
self._beam_hyps = [
|
||||
|
@ -177,6 +186,12 @@ class BeamSearchScorer(BeamScorer):
|
|||
f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead."
|
||||
)
|
||||
|
||||
if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
|
||||
raise ValueError(
|
||||
f"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` "
|
||||
f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
|
||||
)
|
||||
|
||||
@property
|
||||
def is_done(self) -> bool:
|
||||
return self._done.all()
|
||||
|
@ -192,12 +207,12 @@ class BeamSearchScorer(BeamScorer):
|
|||
) -> Tuple[torch.Tensor]:
|
||||
cur_len = input_ids.shape[-1]
|
||||
batch_size = len(self._beam_hyps)
|
||||
assert batch_size == (input_ids.shape[0] // self.num_beams)
|
||||
assert batch_size == (input_ids.shape[0] // self.group_size)
|
||||
|
||||
device = input_ids.device
|
||||
next_beam_scores = torch.zeros((batch_size, self.num_beams), dtype=next_scores.dtype, device=device)
|
||||
next_beam_tokens = torch.zeros((batch_size, self.num_beams), dtype=next_tokens.dtype, device=device)
|
||||
next_beam_indices = torch.zeros((batch_size, self.num_beams), dtype=next_indices.dtype, device=device)
|
||||
next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
|
||||
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
|
||||
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
|
||||
|
||||
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
||||
if self._done[batch_idx]:
|
||||
|
@ -218,11 +233,11 @@ class BeamSearchScorer(BeamScorer):
|
|||
for beam_token_rank, (next_token, next_score, next_index) in enumerate(
|
||||
zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
|
||||
):
|
||||
batch_beam_idx = batch_idx * self.num_beams + next_index
|
||||
batch_beam_idx = batch_idx * self.group_size + next_index
|
||||
# add to generated hypotheses if end of sentence
|
||||
if (eos_token_id is not None) and (next_token.item() == eos_token_id):
|
||||
# if beam_token does not belong to top num_beams tokens, it should not be added
|
||||
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams
|
||||
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
|
||||
if is_beam_token_worse_than_top_num_beams:
|
||||
continue
|
||||
beam_hyp.add(
|
||||
|
@ -237,12 +252,12 @@ class BeamSearchScorer(BeamScorer):
|
|||
beam_idx += 1
|
||||
|
||||
# once the beam for next step is full, don't add more tokens to it.
|
||||
if beam_idx == self.num_beams:
|
||||
if beam_idx == self.group_size:
|
||||
break
|
||||
|
||||
if beam_idx < self.num_beams:
|
||||
if beam_idx < self.group_size:
|
||||
raise ValueError(
|
||||
f"At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
|
||||
f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
|
||||
)
|
||||
|
||||
# Check if we are done so that we can save a pad step if all(done)
|
||||
|
@ -274,7 +289,8 @@ class BeamSearchScorer(BeamScorer):
|
|||
if self._done[batch_idx]:
|
||||
continue
|
||||
|
||||
# need to add best num_beams hypotheses to generated hyps
|
||||
# all open beam hypotheses are added to the beam hypothesis
|
||||
# beam hypothesis class automatically keeps the best beams
|
||||
for beam_id in range(self.num_beams):
|
||||
batch_beam_idx = batch_idx * self.num_beams + beam_id
|
||||
final_score = final_beam_scores[batch_beam_idx].item()
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from abc import ABC
|
||||
from typing import Callable, Iterable, List
|
||||
|
@ -37,6 +38,8 @@ LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
|
|||
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`):
|
||||
Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
|
||||
or scores for each vocabulary token after SoftMax.
|
||||
kwargs:
|
||||
Additional logits processor specific kwargs.
|
||||
|
||||
Return:
|
||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores.
|
||||
|
@ -75,9 +78,16 @@ class LogitsProcessorList(list):
|
|||
"""
|
||||
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
||||
for processor in self:
|
||||
scores = processor(input_ids, scores)
|
||||
function_args = inspect.signature(processor.__call__).parameters
|
||||
if len(function_args) > 2:
|
||||
assert all(
|
||||
arg in kwargs for arg in list(function_args.keys())[2:]
|
||||
), f"Make sure that all the required parameters: {list(function_args.keys())} for {processor.__class__} are passed to the logits processor."
|
||||
scores = processor(input_ids, scores, **kwargs)
|
||||
else:
|
||||
scores = processor(input_ids, scores)
|
||||
return scores
|
||||
|
||||
|
||||
|
@ -400,3 +410,65 @@ class PrefixConstrainedLogitsProcessor(LogitsProcessor):
|
|||
mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0
|
||||
|
||||
return scores + mask
|
||||
|
||||
|
||||
class HammingDiversityLogitsProcessor(LogitsProcessor):
|
||||
r"""
|
||||
:class:`transformers.LogitsProcessor` that enforces diverse beam search. Note that this logits processor is only
|
||||
effective for `group_beam_search`. See `Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence Models
|
||||
<https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.
|
||||
|
||||
Args:
|
||||
diversity_penalty (:obj:`float`):
|
||||
This value is subtracted from a beam's score if it generates a token same as any beam from other group at a
|
||||
particular time. Note that :obj:`diversity_penalty` is only effective if ``group beam search`` is enabled.
|
||||
num_beams (:obj:`int`):
|
||||
Number of beams used for group beam search. See `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for
|
||||
more details.
|
||||
num_beam_groups (:obj:`int`):
|
||||
Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of
|
||||
beams. See `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.
|
||||
"""
|
||||
|
||||
def __init__(self, diversity_penalty: float, num_beams: int, num_beam_groups: int):
|
||||
if not isinstance(diversity_penalty, float) or (not diversity_penalty > 0.0):
|
||||
raise ValueError("`diversity_penalty` should be a float strictly larger than 0.")
|
||||
self._diversity_penalty = diversity_penalty
|
||||
if not isinstance(num_beams, int) or num_beams < 2:
|
||||
raise ValueError("`num_beams` should be an integer strictly larger than 1.")
|
||||
self._num_beams = num_beams
|
||||
if not isinstance(num_beam_groups, int) or num_beam_groups < 2:
|
||||
raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.")
|
||||
if num_beam_groups > num_beams:
|
||||
raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`.")
|
||||
if num_beam_groups > num_beams:
|
||||
raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`")
|
||||
self._num_sub_beams = num_beams // num_beam_groups
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
scores: torch.FloatTensor,
|
||||
current_tokens: torch.LongTensor,
|
||||
beam_group_idx: int,
|
||||
) -> torch.FloatTensor:
|
||||
# hamming diversity: penalise using same token in current group which was used in previous groups at
|
||||
# the same time step
|
||||
batch_size = current_tokens.shape[0] // self._num_beams
|
||||
group_start_idx = beam_group_idx * self._num_sub_beams
|
||||
group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams)
|
||||
group_size = group_end_idx - group_start_idx
|
||||
vocab_size = scores.shape[-1]
|
||||
|
||||
if group_start_idx == 0:
|
||||
return scores
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
# predicted tokens of last time step of previous groups
|
||||
previous_group_tokens = current_tokens[
|
||||
batch_idx * self._num_beams : batch_idx * self._num_beams + group_start_idx
|
||||
]
|
||||
token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size).to(scores.device)
|
||||
scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= self._diversity_penalty * token_frequency
|
||||
|
||||
return scores
|
||||
|
|
|
@ -22,6 +22,7 @@ from torch.nn import functional as F
|
|||
from .file_utils import ModelOutput
|
||||
from .generation_beam_search import BeamScorer, BeamSearchScorer
|
||||
from .generation_logits_process import (
|
||||
HammingDiversityLogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
NoBadWordsLogitsProcessor,
|
||||
|
@ -261,6 +262,8 @@ class GenerationMixin:
|
|||
eos_token_id: int,
|
||||
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
|
||||
num_beams: int,
|
||||
num_beam_groups: int,
|
||||
diversity_penalty: float,
|
||||
) -> LogitsProcessorList:
|
||||
"""
|
||||
This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
|
||||
|
@ -275,11 +278,18 @@ class GenerationMixin:
|
|||
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
|
||||
min_length = min_length if min_length is not None else self.config.min_length
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty
|
||||
# instantiate processors list
|
||||
processors = LogitsProcessorList()
|
||||
|
||||
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
|
||||
# all samplers can be found in `generation_utils_samplers.py`
|
||||
if diversity_penalty is not None and diversity_penalty > 0.0:
|
||||
processors.append(
|
||||
HammingDiversityLogitsProcessor(
|
||||
diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups
|
||||
)
|
||||
)
|
||||
if repetition_penalty is not None and repetition_penalty != 1.0:
|
||||
processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
|
||||
if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
|
||||
|
@ -314,6 +324,8 @@ class GenerationMixin:
|
|||
num_return_sequences: Optional[int] = None,
|
||||
decoder_start_token_id: Optional[int] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
num_beam_groups: Optional[int] = None,
|
||||
diversity_penalty: Optional[float] = None,
|
||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||
**model_kwargs
|
||||
) -> torch.LongTensor:
|
||||
|
@ -381,6 +393,13 @@ class GenerationMixin:
|
|||
use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
|
||||
speed up decoding.
|
||||
num_beam_groups (:obj:`int`, `optional`, defaults to 1):
|
||||
Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of
|
||||
beams. `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.
|
||||
diversity_penalty (:obj:`float`, `optional`, defaults to 0.0):
|
||||
This value is subtracted from a beam's score if it generates a token same as any beam from other group
|
||||
at a particular time. Note that :obj:`diversity_penalty` is only effective if ``group beam search`` is
|
||||
enabled.
|
||||
prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`):
|
||||
If provided, this function constraints the beam search to allowed tokens only at each step. If not
|
||||
provided no constraint is applied. This function takes 2 arguments :obj:`inputs_ids` and the batch ID
|
||||
|
@ -453,6 +472,7 @@ class GenerationMixin:
|
|||
|
||||
# set init values
|
||||
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
||||
num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
||||
num_return_sequences = (
|
||||
|
@ -491,10 +511,17 @@ class GenerationMixin:
|
|||
raise ValueError("Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`.")
|
||||
|
||||
# determine generation mode
|
||||
is_greedy_gen_mode = (num_beams == 1) and do_sample is False
|
||||
is_sample_gen_mode = (num_beams == 1) and do_sample is True
|
||||
is_beam_gen_mode = (num_beams > 1) and do_sample is False
|
||||
is_beam_sample_gen_mode = (num_beams > 1) and do_sample is True
|
||||
is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False
|
||||
is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True
|
||||
is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False
|
||||
is_beam_sample_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is True
|
||||
is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1)
|
||||
if num_beam_groups > num_beams:
|
||||
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
|
||||
if is_group_beam_gen_mode and do_sample is True:
|
||||
raise ValueError(
|
||||
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
|
||||
)
|
||||
|
||||
# set model_kwargs
|
||||
model_kwargs["use_cache"] = use_cache
|
||||
|
@ -508,6 +535,8 @@ class GenerationMixin:
|
|||
eos_token_id=eos_token_id,
|
||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||
num_beams=num_beams,
|
||||
num_beam_groups=num_beam_groups,
|
||||
diversity_penalty=diversity_penalty,
|
||||
)
|
||||
|
||||
if is_greedy_gen_mode:
|
||||
|
@ -619,6 +648,42 @@ class GenerationMixin:
|
|||
**model_kwargs,
|
||||
)
|
||||
|
||||
elif is_group_beam_gen_mode:
|
||||
batch_size = input_ids.shape[0]
|
||||
|
||||
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
|
||||
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
|
||||
|
||||
if num_return_sequences > num_beams:
|
||||
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
||||
|
||||
if num_beams % num_beam_groups != 0:
|
||||
raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")
|
||||
|
||||
diverse_beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=num_beams,
|
||||
device=self.device,
|
||||
length_penalty=length_penalty,
|
||||
do_early_stopping=early_stopping,
|
||||
num_beam_hyps_to_keep=num_return_sequences,
|
||||
num_beam_groups=num_beam_groups,
|
||||
)
|
||||
# interleave with `num_beams`
|
||||
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
||||
input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
|
||||
)
|
||||
return self.group_beam_search(
|
||||
input_ids,
|
||||
diverse_beam_scorer,
|
||||
logits_processor=logits_processor,
|
||||
max_length=max_length,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
def greedy_search(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
|
@ -1208,6 +1273,213 @@ class GenerationMixin:
|
|||
|
||||
return decoded
|
||||
|
||||
def group_beam_search(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
beam_scorer: BeamScorer,
|
||||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
**model_kwargs
|
||||
):
|
||||
r"""
|
||||
Generates sequences for models with a language modeling head using beam search decoding.
|
||||
|
||||
Parameters:
|
||||
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
|
||||
:obj:`torch.LongTensor` of shape :obj:`(1,)`.
|
||||
beam_scorer (:obj:`BeamScorer`):
|
||||
An derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are
|
||||
constructed, stored and sorted during generation. For more information, the documentation of
|
||||
:class:`~transformers.BeamScorer` should be read.
|
||||
logits_processor (:obj:`LogitsProcessorList`, `optional`):
|
||||
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
|
||||
:class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
|
||||
head applied at each generation step.
|
||||
max_length (:obj:`int`, `optional`, defaults to 20):
|
||||
The maximum length of the sequence to be generated.
|
||||
pad_token_id (:obj:`int`, `optional`):
|
||||
The id of the `padding` token.
|
||||
eos_token_id (:obj:`int`, `optional`):
|
||||
The id of the `end-of-sequence` token.
|
||||
model_kwargs:
|
||||
Additional model specific kwargs that will be forwarded to the :obj:`forward` function of the model. If
|
||||
model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.
|
||||
|
||||
Return:
|
||||
:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
|
||||
sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all
|
||||
batches finished early due to the :obj:`eos_token_id`.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import (
|
||||
... AutoTokenizer,
|
||||
... AutoModelForSeq2SeqLM,
|
||||
... LogitsProcessorList,
|
||||
... MinLengthLogitsProcessor,
|
||||
... HammingDiversityLogitsProcessor,
|
||||
... BeamSearchScorer,
|
||||
... )
|
||||
>>> import torch
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
|
||||
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
|
||||
|
||||
>>> encoder_input_str = "translate English to German: How old are you?"
|
||||
>>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
|
||||
|
||||
|
||||
>>> # lets run diverse beam search using 6 beams
|
||||
>>> num_beams = 6
|
||||
>>> # define decoder start token ids
|
||||
>>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
|
||||
>>> input_ids = input_ids * model.config.decoder_start_token_id
|
||||
|
||||
>>> # add encoder_outputs to model keyword arguments
|
||||
>>> model_kwargs = {
|
||||
... "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True)
|
||||
... }
|
||||
|
||||
>>> # instantiate beam scorer
|
||||
>>> beam_scorer = BeamSearchScorer(
|
||||
... batch_size=1,
|
||||
... max_length=model.config.max_length,
|
||||
... num_beams=num_beams,
|
||||
... device=model.device,
|
||||
... num_beam_groups=3
|
||||
... )
|
||||
|
||||
>>> # instantiate logits processors
|
||||
>>> logits_processor = LogitsProcessorList([
|
||||
... HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3),
|
||||
... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
|
||||
... ])
|
||||
|
||||
>>> outputs = model.group_beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
|
||||
|
||||
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
|
||||
"""
|
||||
|
||||
# init values
|
||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
|
||||
batch_size = len(beam_scorer._beam_hyps)
|
||||
num_beams = beam_scorer.num_beams
|
||||
num_beam_groups = beam_scorer.num_beam_groups
|
||||
num_sub_beams = num_beams // num_beam_groups
|
||||
device = input_ids.device
|
||||
|
||||
batch_beam_size, cur_len = input_ids.shape
|
||||
|
||||
assert (
|
||||
num_beams * batch_size == batch_beam_size
|
||||
), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
||||
|
||||
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
|
||||
# initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
|
||||
# the same group don't produce same tokens everytime.
|
||||
beam_scores[:, ::num_sub_beams] = 0
|
||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||
|
||||
while cur_len < max_length:
|
||||
# predicted tokens in cur_len step
|
||||
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
|
||||
|
||||
# indices which will form the beams in the next time step
|
||||
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
|
||||
|
||||
# do one decoder step on all beams of all sentences in batch
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
outputs = self(**model_inputs, return_dict=True)
|
||||
|
||||
for beam_group_idx in range(num_beam_groups):
|
||||
group_start_idx = beam_group_idx * num_sub_beams
|
||||
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
|
||||
group_size = group_end_idx - group_start_idx
|
||||
|
||||
# indices of beams of current group among all sentences in batch
|
||||
batch_group_indices = []
|
||||
for batch_idx in range(batch_size):
|
||||
batch_group_indices.extend(
|
||||
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
|
||||
)
|
||||
group_input_ids = input_ids[batch_group_indices]
|
||||
|
||||
# select outputs of beams of current group only
|
||||
next_token_logits = outputs.logits[batch_group_indices, -1, :]
|
||||
|
||||
# adjust tokens for Bart, *e.g.*
|
||||
next_token_logits = self.adjust_logits_during_generation(
|
||||
next_token_logits, cur_len=cur_len, max_length=max_length
|
||||
)
|
||||
|
||||
next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * group_size, vocab_size)
|
||||
vocab_size = next_token_scores.shape[-1]
|
||||
|
||||
next_token_scores = logits_processor(
|
||||
group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
|
||||
)
|
||||
next_token_scores = next_token_scores + beam_scores[batch_group_indices].unsqueeze(-1).expand_as(
|
||||
next_token_scores
|
||||
)
|
||||
# reshape for beam search
|
||||
|
||||
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
|
||||
|
||||
next_token_scores, next_tokens = torch.topk(
|
||||
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
|
||||
)
|
||||
|
||||
next_indices = next_tokens // vocab_size
|
||||
next_tokens = next_tokens % vocab_size
|
||||
|
||||
# stateless
|
||||
beam_outputs = beam_scorer.process(
|
||||
group_input_ids,
|
||||
next_token_scores,
|
||||
next_tokens,
|
||||
next_indices,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
)
|
||||
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
|
||||
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
||||
beam_idx = beam_outputs["next_beam_indices"]
|
||||
|
||||
input_ids[batch_group_indices] = group_input_ids[beam_idx]
|
||||
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
||||
current_tokens[batch_group_indices] = group_input_ids[:, -1]
|
||||
|
||||
# (beam_idx // group_size) -> batch_idx
|
||||
# (beam_idx % group_size) -> offset of idx inside the group
|
||||
reordering_indices[batch_group_indices] = (
|
||||
num_beams * (beam_idx // group_size) + group_start_idx + (beam_idx % group_size)
|
||||
)
|
||||
|
||||
model_kwargs = self._update_model_kwargs_for_generation(
|
||||
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||
)
|
||||
if model_kwargs["past"] is not None:
|
||||
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices)
|
||||
|
||||
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
|
||||
cur_len = cur_len + 1
|
||||
if beam_scorer.is_done:
|
||||
break
|
||||
|
||||
decoded = beam_scorer.finalize(
|
||||
input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
|
||||
)
|
||||
|
||||
return decoded
|
||||
|
||||
|
||||
def top_k_top_p_filtering(
|
||||
logits: torch.FloatTensor,
|
||||
|
|
|
@ -1226,6 +1226,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||
early_stopping=None,
|
||||
use_cache=None,
|
||||
num_beams=None,
|
||||
num_beam_groups=None,
|
||||
diversity_penalty=None,
|
||||
bos_token_id=None,
|
||||
pad_token_id=None,
|
||||
eos_token_id=None,
|
||||
|
@ -1302,6 +1304,13 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||
should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
|
||||
num_beams (:obj:`int`, `optional`, defaults to 1):
|
||||
Number of beams for beam search. 1 means no beam search.
|
||||
num_beam_groups (:obj:`int`, `optional`, defaults to 1):
|
||||
Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of
|
||||
beams. `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.
|
||||
diversity_penalty (:obj:`float`, `optional`, defaults to 0.0):
|
||||
This value is subtracted from a beam's score if it generates a token same as any beam from other group
|
||||
at a particular time. Note that :obj:`diversity_penalty` is only effective if ``group beam search`` is
|
||||
enabled.
|
||||
num_return_sequences(:obj:`int`, `optional`, defaults to 1):
|
||||
The number of independently computed returned sequences for each element in the batch. Note that this
|
||||
is not the value we pass to the ``generator``'s `:func:`~transformers.PreTrainedModel.generate`
|
||||
|
@ -1326,6 +1335,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||
# set default parameters
|
||||
n_docs = n_docs if n_docs is not None else self.config.n_docs
|
||||
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
||||
num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
num_return_sequences = (
|
||||
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
|
||||
|
@ -1412,6 +1422,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||
eos_token_id=eos_token_id,
|
||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||
num_beams=num_beams,
|
||||
num_beam_groups=num_beam_groups,
|
||||
diversity_penalty=diversity_penalty,
|
||||
)
|
||||
|
||||
if num_beams == 1:
|
||||
|
|
|
@ -118,6 +118,11 @@ class BeamSearchScorer:
|
|||
requires_pytorch(self)
|
||||
|
||||
|
||||
class HammingDiversityLogitsProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class LogitsProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
@ -148,6 +153,11 @@ class NoRepeatNGramLogitsProcessor:
|
|||
requires_pytorch(self)
|
||||
|
||||
|
||||
class PrefixConstrainedLogitsProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class RepetitionPenaltyLogitsProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
|
|
@ -27,6 +27,7 @@ if is_torch_available():
|
|||
import torch.nn.functional as F
|
||||
|
||||
from transformers.generation_logits_process import (
|
||||
HammingDiversityLogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
NoBadWordsLogitsProcessor,
|
||||
|
@ -302,3 +303,30 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||
self.assertListEqual(
|
||||
torch.isinf(filtered_scores).tolist(), [[False, False, True, True, True], [True, True, False, False, True]]
|
||||
)
|
||||
|
||||
def test_hamming_diversity(self):
|
||||
vocab_size = 4
|
||||
num_beams = 2
|
||||
num_beam_groups = 2
|
||||
|
||||
scores = self._get_uniform_logits(num_beams, vocab_size)
|
||||
# batch_idx = 0 -> index batch_idx * num_beam_groups -> idx = 0 * 2 = 0 -> penalises tokens 1
|
||||
# batch_idx = 1 -> index batch_idx * num_beam_groups -> idx = 1 * 2 = 2 -> penalises tokens 1
|
||||
current_tokens = torch.tensor([0, 3, 1, 2], device=torch_device, dtype=torch.long)
|
||||
|
||||
diversity_logits_processor = HammingDiversityLogitsProcessor(
|
||||
diversity_penalty=1.0, num_beams=num_beams, num_beam_groups=num_beam_groups
|
||||
)
|
||||
|
||||
processed_scores = diversity_logits_processor(None, scores, current_tokens, 1)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
processed_scores[0], torch.tensor([-0.7500, 0.2500, 0.2500, 0.2500], device=torch_device), atol=1e-3
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
processed_scores[1], torch.tensor([0.2500, -0.7500, 0.2500, 0.2500], device=torch_device), atol=1e-3
|
||||
)
|
||||
)
|
||||
|
|
|
@ -17,15 +17,16 @@
|
|||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, torch_device
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import top_k_top_p_filtering
|
||||
from transformers import BartForConditionalGeneration, BartTokenizer, top_k_top_p_filtering
|
||||
from transformers.generation_beam_search import BeamSearchScorer
|
||||
from transformers.generation_logits_process import (
|
||||
HammingDiversityLogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
NoBadWordsLogitsProcessor,
|
||||
|
@ -61,7 +62,7 @@ class GenerationTesterMixin:
|
|||
return config, input_ids, attention_mask, max_length
|
||||
|
||||
@staticmethod
|
||||
def _get_logits_processor_and_kwargs(input_length, eos_token_id):
|
||||
def _get_logits_processor_and_kwargs(input_length, eos_token_id, diversity_penalty=None):
|
||||
process_kwargs = {
|
||||
"min_length": input_length + 1,
|
||||
"bad_words_ids": [[1, 0]],
|
||||
|
@ -70,6 +71,13 @@ class GenerationTesterMixin:
|
|||
}
|
||||
logits_processor = LogitsProcessorList(
|
||||
(
|
||||
[
|
||||
HammingDiversityLogitsProcessor(diversity_penalty, num_beams=2, num_beam_groups=2),
|
||||
]
|
||||
if diversity_penalty is not None
|
||||
else []
|
||||
)
|
||||
+ (
|
||||
[
|
||||
MinLengthLogitsProcessor(process_kwargs["min_length"], eos_token_id),
|
||||
]
|
||||
|
@ -115,6 +123,28 @@ class GenerationTesterMixin:
|
|||
)
|
||||
return beam_kwargs, beam_scorer
|
||||
|
||||
@staticmethod
|
||||
def _get_diverse_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1):
|
||||
beam_kwargs = {
|
||||
"early_stopping": False,
|
||||
"length_penalty": 2.0,
|
||||
"num_beams": 2,
|
||||
"num_return_sequences": num_return_sequences,
|
||||
"num_beam_groups": 2, # one beam per group
|
||||
"diversity_penalty": 2.0,
|
||||
}
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=beam_kwargs["num_beams"],
|
||||
device=torch_device,
|
||||
length_penalty=beam_kwargs["length_penalty"],
|
||||
do_early_stopping=beam_kwargs["early_stopping"],
|
||||
num_beam_hyps_to_keep=num_return_sequences,
|
||||
num_beam_groups=beam_kwargs["num_beam_groups"],
|
||||
)
|
||||
return beam_kwargs, beam_scorer
|
||||
|
||||
@staticmethod
|
||||
def _get_encoder_outputs(model, input_ids, attention_mask, num_interleave=1):
|
||||
encoder = model.get_encoder()
|
||||
|
@ -408,6 +438,92 @@ class GenerationTesterMixin:
|
|||
|
||||
self.assertIsNotNone(output_ids_generate)
|
||||
|
||||
def test_group_beam_search_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1], config.eos_token_id, diversity_penalty=2.0
|
||||
)
|
||||
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# check `generate()` and `group_beam_search()` are equal
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
|
||||
output_ids_generate = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=False,
|
||||
max_length=max_length,
|
||||
**beam_kwargs,
|
||||
**logits_process_kwargs,
|
||||
)
|
||||
|
||||
# group_beam_search does not automatically interleave `batch_size` dim for `num_beams`
|
||||
kwargs = {}
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
|
||||
model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
else:
|
||||
attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
|
||||
with torch.no_grad():
|
||||
output_ids_group_beam_search = model.group_beam_search(
|
||||
input_ids_clone,
|
||||
beam_scorer,
|
||||
max_length=max_length,
|
||||
attention_mask=attention_mask_clone,
|
||||
logits_processor=logits_processor,
|
||||
**kwargs,
|
||||
)
|
||||
self.assertListEqual(output_ids_generate.tolist(), output_ids_group_beam_search.tolist())
|
||||
|
||||
# check `generate()` and `group_beam_search()` are equal for `num_return_sequences`
|
||||
num_return_sequences = 2
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(
|
||||
input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
|
||||
)
|
||||
|
||||
output_ids_generate = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=False,
|
||||
max_length=max_length,
|
||||
**beam_kwargs,
|
||||
**logits_process_kwargs,
|
||||
)
|
||||
# group_beam_search does not automatically interleave `batch_size` dim for `num_beams`
|
||||
kwargs = {}
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
|
||||
model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
else:
|
||||
attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
|
||||
with torch.no_grad():
|
||||
output_ids_beam_search = model.group_beam_search(
|
||||
input_ids_clone,
|
||||
beam_scorer,
|
||||
max_length=max_length,
|
||||
attention_mask=attention_mask_clone,
|
||||
logits_processor=logits_processor,
|
||||
**kwargs,
|
||||
)
|
||||
self.assertListEqual(output_ids_generate.tolist(), output_ids_beam_search.tolist())
|
||||
|
||||
|
||||
@require_torch
|
||||
class UtilsFunctionsTest(unittest.TestCase):
|
||||
|
@ -512,3 +628,31 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||
|
||||
self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
|
||||
self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))
|
||||
|
||||
|
||||
@require_torch
|
||||
class GenerationIntegrationTests(unittest.TestCase):
|
||||
@slow
|
||||
def test_diverse_beam_search(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood.
|
||||
The celebrity couple announced the arrival of their son, Silas Randall Timberlake, in statements to People.
|
||||
"Silas was the middle name of Timberlake's maternal grandfather Bill Bomar, who died in 2012, while Randall is the musician's own middle name, as well as his father's first," People reports.
|
||||
The couple announced the pregnancy in January, with an Instagram post. It is the first baby for both."""
|
||||
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(torch_device)
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
outputs = bart_model.generate(
|
||||
input_ids, num_beams=4, num_return_sequences=2, num_beam_groups=4, diversity_penalty=2.0
|
||||
)
|
||||
|
||||
generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(
|
||||
generated_text,
|
||||
[
|
||||
"The couple announced the birth of their son, Silas Randall Timberlake, in a statement. Silas was the middle name of Timberlake's maternal grandfather Bill Bomar. Randall is the musician's own middle name, as well as his father's first. It is the first baby for both of them.",
|
||||
"Justin Timberlake and Jessica Biel have a son. The baby is named Silas Randall Timberlake. It is the first child for both. The couple announced the pregnancy in January. The name Silas is the middle name of Timberlake's maternal grandfather. It's also his own middle name.",
|
||||
],
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue