Generate: replace breaks by a loop condition (#29662)
* replace breaks by a loop condition
* Update src/transformers/generation/utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
28de2f4de3
commit
9e4df7c424
|
@ -1778,6 +1778,24 @@ class GenerationMixin:
|
|||
|
||||
return result
|
||||
|
||||
def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
    """
    Return `True` while generation should keep running, i.e. while at least one sequence is unfinished.

    The local completion status is passed in through `this_peer_finished`. When `synced_gpus` is set
    (e.g. ZeRO stage 3), every rank must keep calling `forward` until *all* ranks are done, so the
    decision is taken collectively across peers rather than locally.
    """
    if not synced_gpus:
        # Single-peer case: the local flag alone decides whether to continue.
        return not this_peer_finished
    # Collective case: each peer contributes 0.0 if it finished, 1.0 otherwise.
    peer_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
    dist.all_reduce(peer_flag, op=dist.ReduceOp.SUM)
    # The reduced sum is 0.0 exactly when every peer has finished; keep looping otherwise
    # even if this peer is done, so all ranks stay in lock-step for the synced forward calls.
    return peer_flag.item() != 0.0
|
||||
|
||||
def contrastive_search(self, *args, **kwargs):
|
||||
logger.warning_once(
|
||||
"Calling `contrastive_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
|
||||
|
@ -1939,19 +1957,9 @@ class GenerationMixin:
|
|||
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
|
||||
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
|
||||
|
||||
this_peer_finished = False # used by synced_gpus only
|
||||
|
||||
while True:
|
||||
if synced_gpus:
|
||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||
# The following logic allows an early break if all peers finished generating their sequence
|
||||
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
||||
# send 0.0 if we finished, 1.0 otherwise
|
||||
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
||||
# did all peers finish? the reduced sum will be 0.0 then
|
||||
if this_peer_finished_flag.item() == 0.0:
|
||||
break
|
||||
this_peer_finished = False
|
||||
|
||||
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
||||
# if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
|
||||
# (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
|
||||
if model_kwargs.get("past_key_values") is None:
|
||||
|
@ -2187,12 +2195,7 @@ class GenerationMixin:
|
|||
|
||||
# stop when each sentence is finished
|
||||
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
||||
|
||||
if unfinished_sequences.max() == 0:
|
||||
this_peer_finished = True
|
||||
|
||||
if this_peer_finished and not synced_gpus:
|
||||
break
|
||||
this_peer_finished = unfinished_sequences.max() == 0
|
||||
|
||||
if streamer is not None:
|
||||
streamer.end()
|
||||
|
@ -2395,6 +2398,7 @@ class GenerationMixin:
|
|||
)
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
this_peer_finished = False
|
||||
batch_size, cur_len = (
|
||||
model_kwargs["attention_mask"].shape
|
||||
if model_kwargs.get("attention_mask", None) is not None
|
||||
|
@ -2403,18 +2407,7 @@ class GenerationMixin:
|
|||
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
|
||||
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
|
||||
|
||||
this_peer_finished = False # used by synced_gpus only
|
||||
while True:
|
||||
if synced_gpus:
|
||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||
# The following logic allows an early break if all peers finished generating their sequence
|
||||
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
||||
# send 0.0 if we finished, 1.0 otherwise
|
||||
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
||||
# did all peers finish? the reduced sum will be 0.0 then
|
||||
if this_peer_finished_flag.item() == 0.0:
|
||||
break
|
||||
|
||||
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
||||
# prepare model inputs
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
|
||||
|
@ -2480,13 +2473,7 @@ class GenerationMixin:
|
|||
)
|
||||
|
||||
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
||||
|
||||
# stop when each sentence is finished
|
||||
if unfinished_sequences.max() == 0:
|
||||
this_peer_finished = True
|
||||
|
||||
if this_peer_finished and not synced_gpus:
|
||||
break
|
||||
this_peer_finished = unfinished_sequences.max() == 0
|
||||
|
||||
if streamer is not None:
|
||||
streamer.end()
|
||||
|
@ -2699,6 +2686,7 @@ class GenerationMixin:
|
|||
)
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
this_peer_finished = False
|
||||
batch_size, cur_len = (
|
||||
model_kwargs["attention_mask"].shape
|
||||
if model_kwargs.get("attention_mask", None) is not None
|
||||
|
@ -2707,19 +2695,7 @@ class GenerationMixin:
|
|||
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
|
||||
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
|
||||
|
||||
this_peer_finished = False # used by synced_gpus only
|
||||
# auto-regressive generation
|
||||
while True:
|
||||
if synced_gpus:
|
||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||
# The following logic allows an early break if all peers finished generating their sequence
|
||||
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
||||
# send 0.0 if we finished, 1.0 otherwise
|
||||
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
||||
# did all peers finish? the reduced sum will be 0.0 then
|
||||
if this_peer_finished_flag.item() == 0.0:
|
||||
break
|
||||
|
||||
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
||||
# prepare model inputs
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
|
||||
|
@ -2787,13 +2763,7 @@ class GenerationMixin:
|
|||
)
|
||||
|
||||
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
||||
|
||||
# stop when each sentence is finished
|
||||
if unfinished_sequences.max() == 0:
|
||||
this_peer_finished = True
|
||||
|
||||
if this_peer_finished and not synced_gpus:
|
||||
break
|
||||
this_peer_finished = unfinished_sequences.max() == 0
|
||||
|
||||
if streamer is not None:
|
||||
streamer.end()
|
||||
|
@ -3052,20 +3022,11 @@ class GenerationMixin:
|
|||
beam_scores[:, 1:] = -1e9
|
||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||
|
||||
this_peer_finished = False # used by synced_gpus only
|
||||
this_peer_finished = False
|
||||
|
||||
decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder
|
||||
while True:
|
||||
if synced_gpus:
|
||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||
# The following logic allows an early break if all peers finished generating their sequence
|
||||
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
||||
# send 0.0 if we finished, 1.0 otherwise
|
||||
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
||||
# did all peers finish? the reduced sum will be 0.0 then
|
||||
if this_peer_finished_flag.item() == 0.0:
|
||||
break
|
||||
|
||||
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
|
||||
# if sequential is True, split the input to batches of batch_size and run sequentially
|
||||
|
@ -3192,10 +3153,7 @@ class GenerationMixin:
|
|||
cur_len = cur_len + 1
|
||||
|
||||
if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
|
||||
if not synced_gpus:
|
||||
break
|
||||
else:
|
||||
this_peer_finished = True
|
||||
this_peer_finished = True
|
||||
|
||||
sequence_outputs = beam_scorer.finalize(
|
||||
input_ids,
|
||||
|
@ -3441,20 +3399,10 @@ class GenerationMixin:
|
|||
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||
|
||||
this_peer_finished = False # used by synced_gpus only
|
||||
this_peer_finished = False
|
||||
|
||||
decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder
|
||||
while True:
|
||||
if synced_gpus:
|
||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||
# The following logic allows an early break if all peers finished generating their sequence
|
||||
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
||||
# send 0.0 if we finished, 1.0 otherwise
|
||||
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
||||
# did all peers finish? the reduced sum will be 0.0 then
|
||||
if this_peer_finished_flag.item() == 0.0:
|
||||
break
|
||||
|
||||
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
|
||||
outputs = self(
|
||||
|
@ -3549,10 +3497,7 @@ class GenerationMixin:
|
|||
cur_len = cur_len + 1
|
||||
|
||||
if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
|
||||
if not synced_gpus:
|
||||
break
|
||||
else:
|
||||
this_peer_finished = True
|
||||
this_peer_finished = True
|
||||
|
||||
sequence_outputs = beam_scorer.finalize(
|
||||
input_ids,
|
||||
|
@ -3804,20 +3749,10 @@ class GenerationMixin:
|
|||
beam_scores[:, ::num_sub_beams] = 0
|
||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||
|
||||
this_peer_finished = False # used by synced_gpus only
|
||||
this_peer_finished = False
|
||||
|
||||
decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder
|
||||
while True:
|
||||
if synced_gpus:
|
||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||
# The following logic allows an early break if all peers finished generating their sequence
|
||||
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
||||
# send 0.0 if we finished, 1.0 otherwise
|
||||
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
||||
# did all peers finish? the reduced sum will be 0.0 then
|
||||
if this_peer_finished_flag.item() == 0.0:
|
||||
break
|
||||
|
||||
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
||||
# predicted tokens in cur_len step
|
||||
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
|
||||
|
||||
|
@ -3955,10 +3890,7 @@ class GenerationMixin:
|
|||
cur_len = cur_len + 1
|
||||
|
||||
if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
|
||||
if not synced_gpus:
|
||||
break
|
||||
else:
|
||||
this_peer_finished = True
|
||||
this_peer_finished = True
|
||||
|
||||
final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
|
||||
sequence_outputs = beam_scorer.finalize(
|
||||
|
@ -4213,20 +4145,10 @@ class GenerationMixin:
|
|||
beam_scores[:, 1:] = -1e9
|
||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||
|
||||
this_peer_finished = False # used by synced_gpus only
|
||||
this_peer_finished = False
|
||||
|
||||
decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder
|
||||
while True:
|
||||
if synced_gpus:
|
||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||
# The following logic allows an early break if all peers finished generating their sequence
|
||||
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
||||
# send 0.0 if we finished, 1.0 otherwise
|
||||
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
||||
# did all peers finish? the reduced sum will be 0.0 then
|
||||
if this_peer_finished_flag.item() == 0.0:
|
||||
break
|
||||
|
||||
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
|
||||
outputs = self(
|
||||
|
@ -4320,10 +4242,7 @@ class GenerationMixin:
|
|||
cur_len = cur_len + 1
|
||||
|
||||
if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
|
||||
if not synced_gpus:
|
||||
break
|
||||
else:
|
||||
this_peer_finished = True
|
||||
this_peer_finished = True
|
||||
|
||||
sequence_outputs = constrained_beam_scorer.finalize(
|
||||
input_ids,
|
||||
|
@ -4553,18 +4472,8 @@ class GenerationMixin:
|
|||
# other auxiliary variables
|
||||
max_len = stopping_criteria[0].max_length
|
||||
|
||||
this_peer_finished = False # used by synced_gpus only
|
||||
while True:
|
||||
if synced_gpus:
|
||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||
# The following logic allows an early break if all peers finished generating their sequence
|
||||
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
||||
# send 0.0 if we finished, 1.0 otherwise
|
||||
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
||||
# did all peers finish? the reduced sum will be 0.0 then
|
||||
if this_peer_finished_flag.item() == 0.0:
|
||||
break
|
||||
|
||||
this_peer_finished = False
|
||||
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
||||
cur_len = input_ids.shape[-1]
|
||||
|
||||
# 1. Fetch candidate sequences from a `CandidateGenerator`
|
||||
|
@ -4733,13 +4642,7 @@ class GenerationMixin:
|
|||
)
|
||||
|
||||
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
||||
|
||||
# stop when each sentence is finished
|
||||
if unfinished_sequences.max() == 0:
|
||||
this_peer_finished = True
|
||||
|
||||
if this_peer_finished and not synced_gpus:
|
||||
break
|
||||
this_peer_finished = unfinished_sequences.max() == 0
|
||||
|
||||
if streamer is not None:
|
||||
streamer.end()
|
||||
|
|
Loading…
Reference in New Issue