Fix overflowing bad word ids (#10889)
* Removes overflowing bad word IDs
* Raise warning
This commit is contained in:
parent
1f5ea9e04a
commit
3c12e3c1c4
|
@ -22,6 +22,10 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
from .file_utils import add_start_docstrings
|
||||
from .utils.logging import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
|
||||
|
@ -417,7 +421,14 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
|
|||
banned_mask_list = []
|
||||
for idx, batch_banned_tokens in enumerate(banned_tokens):
|
||||
for token in batch_banned_tokens:
|
||||
banned_mask_list.append([idx, token])
|
||||
# Eliminates invalid bad word IDs that are over the vocabulary size.
|
||||
if token <= scores.shape[1]:
|
||||
banned_mask_list.append([idx, token])
|
||||
else:
|
||||
logger.error(
|
||||
f"An invalid bad word ID is defined: {token}. This ID is not contained in the"
|
||||
f"vocabulary, and is therefore ignored."
|
||||
)
|
||||
if not banned_mask_list:
|
||||
return scores
|
||||
|
||||
|
|
Loading…
Reference in New Issue