Track each row separately for stopping criteria (#29116)
This commit is contained in:
parent
ece1b62b93
commit
8f2f0f0f85
|
@ -29,7 +29,8 @@ STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
|
|||
Additional stopping criteria specific kwargs.
|
||||
|
||||
Return:
|
||||
`bool`. `False` indicates we should continue, `True` indicates we should stop.
|
||||
`torch.BoolTensor`. (`torch.BoolTensor` of shape `(batch_size, 1)`), where `True` indicates we stop generation
|
||||
for a particular row, `True` indicates we should continue.
|
||||
|
||||
"""
|
||||
|
||||
|
@ -42,7 +43,7 @@ class StoppingCriteria(ABC):
|
|||
"""
|
||||
|
||||
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
||||
raise NotImplementedError("StoppingCriteria needs to be subclassed")
|
||||
|
||||
|
||||
|
@ -63,7 +64,7 @@ class MaxLengthCriteria(StoppingCriteria):
|
|||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
||||
cur_len = input_ids.shape[-1]
|
||||
is_done = cur_len >= self.max_length
|
||||
if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings:
|
||||
|
@ -72,7 +73,7 @@ class MaxLengthCriteria(StoppingCriteria):
|
|||
f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe "
|
||||
"exceptions, performance degradation, or nothing at all."
|
||||
)
|
||||
return is_done
|
||||
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device)
|
||||
|
||||
|
||||
class MaxNewTokensCriteria(StoppingCriteria):
|
||||
|
@ -100,8 +101,9 @@ class MaxNewTokensCriteria(StoppingCriteria):
|
|||
self.max_length = start_length + max_new_tokens
|
||||
|
||||
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||
return input_ids.shape[-1] >= self.max_length
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
||||
is_done = input_ids.shape[-1] >= self.max_length
|
||||
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device)
|
||||
|
||||
|
||||
class MaxTimeCriteria(StoppingCriteria):
|
||||
|
@ -122,14 +124,18 @@ class MaxTimeCriteria(StoppingCriteria):
|
|||
self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp
|
||||
|
||||
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||
return time.time() - self.initial_timestamp > self.max_time
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
||||
is_done = time.time() - self.initial_timestamp > self.max_time
|
||||
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device)
|
||||
|
||||
|
||||
class StoppingCriteriaList(list):
|
||||
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||
return any(criteria(input_ids, scores, **kwargs) for criteria in self)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
||||
is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device)
|
||||
for criteria in self:
|
||||
is_done = is_done | criteria(input_ids, scores, **kwargs)
|
||||
return is_done
|
||||
|
||||
@property
|
||||
def max_length(self) -> Optional[int]:
|
||||
|
|
|
@ -2195,11 +2195,9 @@ class GenerationMixin:
|
|||
)
|
||||
|
||||
# stop when each sentence is finished
|
||||
if unfinished_sequences.max() == 0:
|
||||
this_peer_finished = True
|
||||
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
||||
|
||||
# stop if we exceed the maximum length
|
||||
if stopping_criteria(input_ids, scores):
|
||||
if unfinished_sequences.max() == 0:
|
||||
this_peer_finished = True
|
||||
|
||||
if this_peer_finished and not synced_gpus:
|
||||
|
@ -2478,14 +2476,12 @@ class GenerationMixin:
|
|||
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
||||
)
|
||||
|
||||
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
||||
|
||||
# stop when each sentence is finished
|
||||
if unfinished_sequences.max() == 0:
|
||||
this_peer_finished = True
|
||||
|
||||
# stop if we exceed the maximum length
|
||||
if stopping_criteria(input_ids, scores):
|
||||
this_peer_finished = True
|
||||
|
||||
if this_peer_finished and not synced_gpus:
|
||||
break
|
||||
|
||||
|
@ -2772,14 +2768,12 @@ class GenerationMixin:
|
|||
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
||||
)
|
||||
|
||||
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
||||
|
||||
# stop when each sentence is finished
|
||||
if unfinished_sequences.max() == 0:
|
||||
this_peer_finished = True
|
||||
|
||||
# stop if we exceed the maximum length
|
||||
if stopping_criteria(input_ids, scores):
|
||||
this_peer_finished = True
|
||||
|
||||
if this_peer_finished and not synced_gpus:
|
||||
break
|
||||
|
||||
|
@ -3169,7 +3163,7 @@ class GenerationMixin:
|
|||
# increase cur_len
|
||||
cur_len = cur_len + 1
|
||||
|
||||
if beam_scorer.is_done or stopping_criteria(input_ids, scores):
|
||||
if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
|
||||
if not synced_gpus:
|
||||
break
|
||||
else:
|
||||
|
@ -3516,7 +3510,7 @@ class GenerationMixin:
|
|||
# increase cur_len
|
||||
cur_len = cur_len + 1
|
||||
|
||||
if beam_scorer.is_done or stopping_criteria(input_ids, scores):
|
||||
if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
|
||||
if not synced_gpus:
|
||||
break
|
||||
else:
|
||||
|
@ -3912,7 +3906,7 @@ class GenerationMixin:
|
|||
# increase cur_len
|
||||
cur_len = cur_len + 1
|
||||
|
||||
if beam_scorer.is_done or stopping_criteria(input_ids, scores):
|
||||
if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
|
||||
if not synced_gpus:
|
||||
break
|
||||
else:
|
||||
|
@ -4267,7 +4261,7 @@ class GenerationMixin:
|
|||
# increase cur_len
|
||||
cur_len = cur_len + 1
|
||||
|
||||
if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores):
|
||||
if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
|
||||
if not synced_gpus:
|
||||
break
|
||||
else:
|
||||
|
@ -4657,14 +4651,12 @@ class GenerationMixin:
|
|||
.prod(dim=0)
|
||||
)
|
||||
|
||||
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
||||
|
||||
# stop when each sentence is finished
|
||||
if unfinished_sequences.max() == 0:
|
||||
this_peer_finished = True
|
||||
|
||||
# stop if we exceed the maximum length
|
||||
if stopping_criteria(input_ids, scores):
|
||||
this_peer_finished = True
|
||||
|
||||
if this_peer_finished and not synced_gpus:
|
||||
break
|
||||
|
||||
|
|
|
@ -54,37 +54,37 @@ class StoppingCriteriaTestCase(unittest.TestCase):
|
|||
]
|
||||
)
|
||||
|
||||
self.assertFalse(criteria(input_ids, scores))
|
||||
self.assertFalse(all(criteria(input_ids, scores)))
|
||||
|
||||
input_ids, scores = self._get_tensors(9)
|
||||
self.assertFalse(criteria(input_ids, scores))
|
||||
self.assertFalse(all(criteria(input_ids, scores)))
|
||||
|
||||
input_ids, scores = self._get_tensors(10)
|
||||
self.assertTrue(criteria(input_ids, scores))
|
||||
self.assertTrue(all(criteria(input_ids, scores)))
|
||||
|
||||
def test_max_length_criteria(self):
|
||||
criteria = MaxLengthCriteria(max_length=10)
|
||||
|
||||
input_ids, scores = self._get_tensors(5)
|
||||
self.assertFalse(criteria(input_ids, scores))
|
||||
self.assertFalse(all(criteria(input_ids, scores)))
|
||||
|
||||
input_ids, scores = self._get_tensors(9)
|
||||
self.assertFalse(criteria(input_ids, scores))
|
||||
self.assertFalse(all(criteria(input_ids, scores)))
|
||||
|
||||
input_ids, scores = self._get_tensors(10)
|
||||
self.assertTrue(criteria(input_ids, scores))
|
||||
self.assertTrue(all(criteria(input_ids, scores)))
|
||||
|
||||
def test_max_new_tokens_criteria(self):
|
||||
criteria = MaxNewTokensCriteria(start_length=5, max_new_tokens=5)
|
||||
|
||||
input_ids, scores = self._get_tensors(5)
|
||||
self.assertFalse(criteria(input_ids, scores))
|
||||
self.assertFalse(all(criteria(input_ids, scores)))
|
||||
|
||||
input_ids, scores = self._get_tensors(9)
|
||||
self.assertFalse(criteria(input_ids, scores))
|
||||
self.assertFalse(all(criteria(input_ids, scores)))
|
||||
|
||||
input_ids, scores = self._get_tensors(10)
|
||||
self.assertTrue(criteria(input_ids, scores))
|
||||
self.assertTrue(all(criteria(input_ids, scores)))
|
||||
|
||||
criteria_list = StoppingCriteriaList([criteria])
|
||||
self.assertEqual(criteria_list.max_length, 10)
|
||||
|
@ -93,10 +93,10 @@ class StoppingCriteriaTestCase(unittest.TestCase):
|
|||
input_ids, scores = self._get_tensors(5)
|
||||
|
||||
criteria = MaxTimeCriteria(max_time=0.1)
|
||||
self.assertFalse(criteria(input_ids, scores))
|
||||
self.assertFalse(all(criteria(input_ids, scores)))
|
||||
|
||||
criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2)
|
||||
self.assertTrue(criteria(input_ids, scores))
|
||||
self.assertTrue(all(criteria(input_ids, scores)))
|
||||
|
||||
def test_validate_stopping_criteria(self):
|
||||
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)
|
||||
|
|
Loading…
Reference in New Issue