Track each row separately for stopping criteria (#29116)

This commit is contained in:
Raushan Turganbay 2024-02-26 21:06:16 +05:00 committed by GitHub
parent ece1b62b93
commit 8f2f0f0f85
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 43 additions and 45 deletions

View File

@ -29,7 +29,8 @@ STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
Additional stopping criteria specific kwargs.
Return:
`bool`. `False` indicates we should continue, `True` indicates we should stop.
`torch.BoolTensor`. (`torch.BoolTensor` of shape `(batch_size, 1)`), where `True` indicates we stop generation
for a particular row, `True` indicates we should continue.
"""
@ -42,7 +43,7 @@ class StoppingCriteria(ABC):
"""
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
raise NotImplementedError("StoppingCriteria needs to be subclassed")
@ -63,7 +64,7 @@ class MaxLengthCriteria(StoppingCriteria):
self.max_position_embeddings = max_position_embeddings
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
cur_len = input_ids.shape[-1]
is_done = cur_len >= self.max_length
if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings:
@ -72,7 +73,7 @@ class MaxLengthCriteria(StoppingCriteria):
f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe "
"exceptions, performance degradation, or nothing at all."
)
return is_done
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device)
class MaxNewTokensCriteria(StoppingCriteria):
@ -100,8 +101,9 @@ class MaxNewTokensCriteria(StoppingCriteria):
self.max_length = start_length + max_new_tokens
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return input_ids.shape[-1] >= self.max_length
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
is_done = input_ids.shape[-1] >= self.max_length
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device)
class MaxTimeCriteria(StoppingCriteria):
@ -122,14 +124,18 @@ class MaxTimeCriteria(StoppingCriteria):
self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return time.time() - self.initial_timestamp > self.max_time
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
is_done = time.time() - self.initial_timestamp > self.max_time
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device)
class StoppingCriteriaList(list):
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return any(criteria(input_ids, scores, **kwargs) for criteria in self)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device)
for criteria in self:
is_done = is_done | criteria(input_ids, scores, **kwargs)
return is_done
@property
def max_length(self) -> Optional[int]:

View File

@ -2195,11 +2195,9 @@ class GenerationMixin:
)
# stop when each sentence is finished
if unfinished_sequences.max() == 0:
this_peer_finished = True
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
# stop if we exceed the maximum length
if stopping_criteria(input_ids, scores):
if unfinished_sequences.max() == 0:
this_peer_finished = True
if this_peer_finished and not synced_gpus:
@ -2478,14 +2476,12 @@ class GenerationMixin:
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
)
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
# stop when each sentence is finished
if unfinished_sequences.max() == 0:
this_peer_finished = True
# stop if we exceed the maximum length
if stopping_criteria(input_ids, scores):
this_peer_finished = True
if this_peer_finished and not synced_gpus:
break
@ -2772,14 +2768,12 @@ class GenerationMixin:
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
)
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
# stop when each sentence is finished
if unfinished_sequences.max() == 0:
this_peer_finished = True
# stop if we exceed the maximum length
if stopping_criteria(input_ids, scores):
this_peer_finished = True
if this_peer_finished and not synced_gpus:
break
@ -3169,7 +3163,7 @@ class GenerationMixin:
# increase cur_len
cur_len = cur_len + 1
if beam_scorer.is_done or stopping_criteria(input_ids, scores):
if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
if not synced_gpus:
break
else:
@ -3516,7 +3510,7 @@ class GenerationMixin:
# increase cur_len
cur_len = cur_len + 1
if beam_scorer.is_done or stopping_criteria(input_ids, scores):
if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
if not synced_gpus:
break
else:
@ -3912,7 +3906,7 @@ class GenerationMixin:
# increase cur_len
cur_len = cur_len + 1
if beam_scorer.is_done or stopping_criteria(input_ids, scores):
if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
if not synced_gpus:
break
else:
@ -4267,7 +4261,7 @@ class GenerationMixin:
# increase cur_len
cur_len = cur_len + 1
if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores):
if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
if not synced_gpus:
break
else:
@ -4657,14 +4651,12 @@ class GenerationMixin:
.prod(dim=0)
)
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
# stop when each sentence is finished
if unfinished_sequences.max() == 0:
this_peer_finished = True
# stop if we exceed the maximum length
if stopping_criteria(input_ids, scores):
this_peer_finished = True
if this_peer_finished and not synced_gpus:
break

View File

@ -54,37 +54,37 @@ class StoppingCriteriaTestCase(unittest.TestCase):
]
)
self.assertFalse(criteria(input_ids, scores))
self.assertFalse(all(criteria(input_ids, scores)))
input_ids, scores = self._get_tensors(9)
self.assertFalse(criteria(input_ids, scores))
self.assertFalse(all(criteria(input_ids, scores)))
input_ids, scores = self._get_tensors(10)
self.assertTrue(criteria(input_ids, scores))
self.assertTrue(all(criteria(input_ids, scores)))
def test_max_length_criteria(self):
criteria = MaxLengthCriteria(max_length=10)
input_ids, scores = self._get_tensors(5)
self.assertFalse(criteria(input_ids, scores))
self.assertFalse(all(criteria(input_ids, scores)))
input_ids, scores = self._get_tensors(9)
self.assertFalse(criteria(input_ids, scores))
self.assertFalse(all(criteria(input_ids, scores)))
input_ids, scores = self._get_tensors(10)
self.assertTrue(criteria(input_ids, scores))
self.assertTrue(all(criteria(input_ids, scores)))
def test_max_new_tokens_criteria(self):
criteria = MaxNewTokensCriteria(start_length=5, max_new_tokens=5)
input_ids, scores = self._get_tensors(5)
self.assertFalse(criteria(input_ids, scores))
self.assertFalse(all(criteria(input_ids, scores)))
input_ids, scores = self._get_tensors(9)
self.assertFalse(criteria(input_ids, scores))
self.assertFalse(all(criteria(input_ids, scores)))
input_ids, scores = self._get_tensors(10)
self.assertTrue(criteria(input_ids, scores))
self.assertTrue(all(criteria(input_ids, scores)))
criteria_list = StoppingCriteriaList([criteria])
self.assertEqual(criteria_list.max_length, 10)
@ -93,10 +93,10 @@ class StoppingCriteriaTestCase(unittest.TestCase):
input_ids, scores = self._get_tensors(5)
criteria = MaxTimeCriteria(max_time=0.1)
self.assertFalse(criteria(input_ids, scores))
self.assertFalse(all(criteria(input_ids, scores)))
criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2)
self.assertTrue(criteria(input_ids, scores))
self.assertTrue(all(criteria(input_ids, scores)))
def test_validate_stopping_criteria(self):
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)