fixed prefix_allowed_tokens_fn docstring in generate() (#10862)

This commit is contained in:
RafaelWO 2021-03-23 18:48:22 +01:00 committed by GitHub
parent 7ef40120a0
commit d4d4447d53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 3 deletions

View File

@ -776,9 +776,9 @@ class GenerationMixin:
enabled.
prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`):
If provided, this function constraints the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments :obj:`inputs_ids` and the batch ID
:obj:`batch_id`. It has to return a list with the allowed tokens for the next generation step
conditioned on the previously generated tokens :obj:`inputs_ids` and the batch ID :obj:`batch_id`. This
provided no constraint is applied. This function takes 2 arguments: the batch ID :obj:`batch_id` and
:obj:`input_ids`. It has to return a list with the allowed tokens for the next generation step
conditioned on the batch ID :obj:`batch_id` and the previously generated tokens :obj:`inputs_ids`. This
argument is useful for constrained generation conditioned on the prefix, as described in
`Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
output_attentions (:obj:`bool`, `optional`, defaults to `False`):