fixed prefix_allowed_tokens_fn docstring in generate() (#10862)

This commit is contained in:
RafaelWO 2021-03-23 18:48:22 +01:00 committed by GitHub
parent 7ef40120a0
commit d4d4447d53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 3 deletions

View File

@ -776,9 +776,9 @@ class GenerationMixin:
enabled. enabled.
prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`): prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`):
If provided, this function constraints the beam search to allowed tokens only at each step. If not If provided, this function constraints the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments :obj:`inputs_ids` and the batch ID provided no constraint is applied. This function takes 2 arguments: the batch ID :obj:`batch_id` and
:obj:`batch_id`. It has to return a list with the allowed tokens for the next generation step :obj:`input_ids`. It has to return a list with the allowed tokens for the next generation step
conditioned on the previously generated tokens :obj:`inputs_ids` and the batch ID :obj:`batch_id`. This conditioned on the batch ID :obj:`batch_id` and the previously generated tokens :obj:`inputs_ids`. This
argument is useful for constrained generation conditioned on the prefix, as described in argument is useful for constrained generation conditioned on the prefix, as described in
`Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__. `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
output_attentions (:obj:`bool`, `optional`, defaults to `False`): output_attentions (:obj:`bool`, `optional`, defaults to `False`):