Fix overflowing bad word ids (#10889)

* Removes overflowing bad word IDs

* Raise warning
This commit is contained in:
Lysandre Debut 2021-03-24 15:13:56 -04:00 committed by GitHub
parent 1f5ea9e04a
commit 3c12e3c1c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 12 additions and 1 deletion

View File

@@ -22,6 +22,10 @@ import numpy as np
import torch
from .file_utils import add_start_docstrings
from .utils.logging import get_logger
logger = get_logger(__name__)
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
@@ -417,7 +421,14 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
banned_mask_list = []
for idx, batch_banned_tokens in enumerate(banned_tokens):
for token in batch_banned_tokens:
banned_mask_list.append([idx, token])
# Eliminates invalid bad word IDs that are over the vocabulary size.
if token <= scores.shape[1]:
banned_mask_list.append([idx, token])
else:
logger.error(
f"An invalid bad word ID is defined: {token}. This ID is not contained in the"
f"vocabulary, and is therefore ignored."
)
if not banned_mask_list:
return scores