Docs: add missing `StoppingCriteria` autodocs (#30617)

* add missing docstrings to docs

* Update src/transformers/generation/stopping_criteria.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Joao Gante 2024-05-02 15:20:04 +01:00 committed by GitHub
parent aa55ff44a2
commit 66abe13951
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 25 additions and 1 deletions

View File

@ -310,6 +310,12 @@ A [`StoppingCriteria`] can be used to change when to stop generation (other than
[[autodoc]] MaxTimeCriteria
- __call__
[[autodoc]] StopStringCriteria
- __call__
[[autodoc]] EosTokenCriteria
- __call__
## Constraints
A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output. Please note that this is exclusively available to our PyTorch implementations.

View File

@ -1419,6 +1419,7 @@ else:
"DisjunctiveConstraint",
"EncoderNoRepeatNGramLogitsProcessor",
"EncoderRepetitionPenaltyLogitsProcessor",
"EosTokenCriteria",
"EpsilonLogitsWarper",
"EtaLogitsWarper",
"ExponentialDecayLengthPenalty",
@ -1444,6 +1445,7 @@ else:
"SequenceBiasLogitsProcessor",
"StoppingCriteria",
"StoppingCriteriaList",
"StopStringCriteria",
"SuppressTokensAtBeginLogitsProcessor",
"SuppressTokensLogitsProcessor",
"TemperatureLogitsWarper",
@ -6372,6 +6374,7 @@ if TYPE_CHECKING:
DisjunctiveConstraint,
EncoderNoRepeatNGramLogitsProcessor,
EncoderRepetitionPenaltyLogitsProcessor,
EosTokenCriteria,
EpsilonLogitsWarper,
EtaLogitsWarper,
ExponentialDecayLengthPenalty,
@ -6397,6 +6400,7 @@ if TYPE_CHECKING:
SequenceBiasLogitsProcessor,
StoppingCriteria,
StoppingCriteriaList,
StopStringCriteria,
SuppressTokensAtBeginLogitsProcessor,
SuppressTokensLogitsProcessor,
TemperatureLogitsWarper,

View File

@ -98,7 +98,7 @@ class MaxNewTokensCriteria(StoppingCriteria):
def __init__(self, start_length: int, max_new_tokens: int):
warnings.warn(
"The class `MaxNewTokensCriteria` is deprecated. "
"The class `MaxNewTokensCriteria` is deprecated and will be removed in v4.43. "
f"Please use `MaxLengthCriteria(max_length={start_length + max_new_tokens})` "
"with `max_length = start_length + max_new_tokens` instead.",
FutureWarning,

View File

@ -177,6 +177,13 @@ class EncoderRepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
requires_backends(self, ["torch"])
class EosTokenCriteria(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class EpsilonLogitsWarper(metaclass=DummyObject):
_backends = ["torch"]
@ -352,6 +359,13 @@ class StoppingCriteriaList(metaclass=DummyObject):
requires_backends(self, ["torch"])
class StopStringCriteria(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class SuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]