[Performance improvement] "Bad tokens ids" optimization (#6064)
* Optimized banned token masking * Avoid duplicate EOS masking if in bad_words_id * Updated mask generation to handle empty banned token list * Addition of unit tests for the updated bad_words_ids masking * Updated timeout handling in `test_postprocess_next_token_scores_large_bad_words_list` unit test * Updated timeout handling in `test_postprocess_next_token_scores_large_bad_words_list` unit test (timeout does not work on Windows) * Moving Marian import to the test context to allow TF only environments to run * Moving imports to torch_available test * Updated operations device and test * Updated operations device and test * Added docstring and comment for in-place scores modification * Moving test to own test_generation_utils, use of lighter models for testing * removed unneded imports in test_modeling_common * revert formatting change for ModelTesterMixin * Updated caching, simplified eos token id test, removed unnecessary @require_torch * formatting compliance
This commit is contained in:
parent
87e124c245
commit
404782912a
|
@ -0,0 +1,90 @@
|
|||
import random
|
||||
import unittest
|
||||
|
||||
import timeout_decorator
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_torch
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
MarianConfig,
|
||||
MarianMTModel,
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class GenerationUtilsTest(unittest.TestCase):
|
||||
@cached_property
|
||||
def config(self):
|
||||
config = MarianConfig.from_pretrained("sshleifer/tiny-marian-en-de")
|
||||
return config
|
||||
|
||||
@cached_property
|
||||
def model(self):
|
||||
return MarianMTModel(self.config)
|
||||
|
||||
def test_postprocess_next_token_scores(self):
|
||||
config = self.config
|
||||
model = self.model
|
||||
# Initialize an input id tensor with batch size 8 and sequence length 12
|
||||
input_ids = torch.arange(0, 96, 1).view((8, 12))
|
||||
eos = config.eos_token_id
|
||||
bad_words_ids_test_cases = [[[299]], [[23, 24], [54]], [[config.eos_token_id]], []]
|
||||
masked_scores = [
|
||||
[(0, 299), (1, 299), (2, 299), (3, 299), (4, 299), (5, 299), (6, 299), (7, 299)],
|
||||
[(1, 24), (0, 54), (1, 54), (2, 54), (3, 54), (4, 54), (5, 54), (6, 54), (7, 54)],
|
||||
[(0, eos), (1, eos), (2, eos), (3, eos), (4, eos), (5, eos), (6, eos), (7, eos)],
|
||||
[],
|
||||
]
|
||||
|
||||
for test_case_index, bad_words_ids in enumerate(bad_words_ids_test_cases):
|
||||
# Initialize a scores tensor with batch size 8 and vocabulary size 300
|
||||
scores = torch.rand((8, 300))
|
||||
output = model.postprocess_next_token_scores(
|
||||
scores,
|
||||
input_ids,
|
||||
0,
|
||||
bad_words_ids,
|
||||
13,
|
||||
15,
|
||||
config.max_length,
|
||||
config.eos_token_id,
|
||||
config.repetition_penalty,
|
||||
32,
|
||||
5,
|
||||
)
|
||||
for masked_score in masked_scores[test_case_index]:
|
||||
self.assertTrue(output[masked_score[0], masked_score[1]] == -float("inf"))
|
||||
|
||||
@timeout_decorator.timeout(10)
|
||||
def test_postprocess_next_token_scores_large_bad_words_list(self):
|
||||
|
||||
config = self.config
|
||||
model = self.model
|
||||
# Initialize an input id tensor with batch size 8 and sequence length 12
|
||||
input_ids = torch.arange(0, 96, 1).view((8, 12))
|
||||
|
||||
bad_words_ids = []
|
||||
for _ in range(100):
|
||||
length_bad_word = random.randint(1, 4)
|
||||
bad_words_ids.append(random.sample(range(1, 300), length_bad_word))
|
||||
|
||||
scores = torch.rand((8, 300))
|
||||
_ = model.postprocess_next_token_scores(
|
||||
scores,
|
||||
input_ids,
|
||||
0,
|
||||
bad_words_ids,
|
||||
13,
|
||||
15,
|
||||
config.max_length,
|
||||
config.eos_token_id,
|
||||
config.repetition_penalty,
|
||||
32,
|
||||
5,
|
||||
)
|
|
@ -15,7 +15,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Iterable, Optional, Tuple
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
@ -89,11 +89,12 @@ class GenerationMixin:
|
|||
scores[i, banned_tokens] = -float("inf")
|
||||
|
||||
if bad_words_ids is not None:
|
||||
# Exclude EOS token (already processed)
|
||||
bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
|
||||
# calculate a list of banned tokens according to bad words
|
||||
banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
|
||||
|
||||
for i, banned_tokens in enumerate(banned_tokens):
|
||||
scores[i, banned_tokens] = -float("inf")
|
||||
banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids)
|
||||
# Modify the scores in place by setting the banned tokens logits to `-inf`
|
||||
set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
|
||||
|
||||
return scores
|
||||
|
||||
|
@ -893,7 +894,7 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter
|
|||
bad_words_ids
|
||||
)
|
||||
|
||||
if _tokens_match(prev_input_ids_slice.tolist(), banned_token_seq[:-1]) is False:
|
||||
if _tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False:
|
||||
# if tokens do not match continue
|
||||
continue
|
||||
|
||||
|
@ -904,6 +905,30 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter
|
|||
return banned_tokens
|
||||
|
||||
|
||||
def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
|
||||
""" Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be
|
||||
a list of list of banned tokens to ban in the format [[batch index, vocabulary position],...]
|
||||
Args:
|
||||
scores: logits distribution of shape (batch size, vocabulary size)
|
||||
banned_tokens: list of list of tokens to ban of length (batch_size)
|
||||
"""
|
||||
banned_mask_list = []
|
||||
for idx, batch_banned_tokens in enumerate(banned_tokens):
|
||||
for token in batch_banned_tokens:
|
||||
banned_mask_list.append([idx, token])
|
||||
if not banned_mask_list:
|
||||
return
|
||||
banned_mask = torch.LongTensor(banned_mask_list)
|
||||
indices = torch.ones(len(banned_mask))
|
||||
# A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
|
||||
# [ 0 1 1 ]
|
||||
# [ 0 0 0 ]
|
||||
# [ 1 0 0 ]
|
||||
|
||||
banned_mask = torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
|
||||
scores.masked_fill_(banned_mask, -float("inf"))
|
||||
|
||||
|
||||
def top_k_top_p_filtering(
|
||||
logits: Tensor,
|
||||
top_k: int = 0,
|
||||
|
|
Loading…
Reference in New Issue