🚨🚨 Fix beam score calculation issue for decoder-only models (#27351)

* Fix beam score calculation issue for decoder-only models
* Update beam search test and fix code quality issue
* Fix beam_sample, group_beam_search and constrained_beam_search
* Split test for pytorch and TF, add documentation

Co-authored-by: Xin Qiu <xin.qiu@sentient.ai>

parent 3d1a7bf476
commit 453079c7f8
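For context before the diff: in decoder-only models the `input_ids` that reach the beam scorer already contain the prompt, so hypothesis scores were normalized over prompt + generated length instead of generated length only, and long prompts could flip the ranking between hypotheses. A minimal sketch of the effect with made-up numbers (not from the commit); for encoder-decoder models the decoder prompt is a single start token, so the correction there is only one token:

# Two finished hypotheses behind a 20-token prompt (hypothetical numbers).
prompt_len, length_penalty = 20, 1.0
hyp_a = {"sum_logprobs": -4.0, "gen_len": 2}   # short, weaker continuation
hyp_b = {"sum_logprobs": -9.0, "gen_len": 10}  # longer, better per generated token

def old_score(h):  # pre-fix: normalize over prompt + generated tokens
    return h["sum_logprobs"] / ((prompt_len + h["gen_len"]) ** length_penalty)

def new_score(h):  # post-fix: normalize over generated tokens only
    return h["sum_logprobs"] / (h["gen_len"] ** length_penalty)

print(old_score(hyp_a), old_score(hyp_b))  # -0.18 vs -0.30: the short hypothesis wins
print(new_score(hyp_a), new_score(hyp_b))  # -2.00 vs -0.90: per-token quality decides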
@@ -222,8 +222,10 @@ class BeamSearchScorer(BeamScorer):
        eos_token_id: Optional[Union[int, List[int]]] = None,
        beam_indices: Optional[torch.LongTensor] = None,
        group_index: Optional[int] = 0,
+        decoder_prompt_len: Optional[int] = 0,
    ) -> Dict[str, torch.Tensor]:
-        cur_len = input_ids.shape[-1] + 1  # add up to the length which the next_scores is calculated on
+        # add up to the length which the next_scores is calculated on
+        cur_len = input_ids.shape[-1] - decoder_prompt_len + 1
        batch_size = len(self._beam_hyps) // self.num_beam_groups

        if not (batch_size == (input_ids.shape[0] // self.group_size)):
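To make the new bookkeeping concrete, a small sketch with assumed shapes (the variable names mirror the diff; the numbers are illustrative):

import torch

decoder_prompt_len = 10  # prompt length recorded before the generation loop
input_ids = torch.zeros((6, 15), dtype=torch.long)  # batch*beams x (10 prompt + 5 generated)

old_cur_len = input_ids.shape[-1] + 1                       # 16: prompt tokens counted
new_cur_len = input_ids.shape[-1] - decoder_prompt_len + 1  # 6: generated tokens plus the one being scored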
@@ -277,10 +279,15 @@ class BeamSearchScorer(BeamScorer):
                    else:
                        beam_index = None

+                    # skip the corner case where the very first generated token is eos_token
+                    if decoder_prompt_len == input_ids.shape[-1]:
+                        continue
+
                    self._beam_hyps[batch_group_idx].add(
                        input_ids[batch_beam_idx].clone(),
                        next_score.item(),
                        beam_indices=beam_index,
+                        decoder_prompt_len=decoder_prompt_len,
                    )
                else:
                    # add next predicted token since it is not eos_token
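The new guard handles a real corner case: if the very first generated token is `eos_token`, the hypothesis is the prompt alone, `hyp.shape[-1] - decoder_prompt_len` is 0, and the corrected length penalty would divide by zero. A sketch with illustrative values:

hyp_len = 10             # the hypothesis is just the prompt: the first generated token was EOS
decoder_prompt_len = 10
length_penalty = 1.0

try:
    score = -2.5 / ((hyp_len - decoder_prompt_len) ** length_penalty)
except ZeroDivisionError:
    score = float("-inf")  # the `continue` in the diff skips the hypothesis before this can happen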
@@ -322,6 +329,7 @@ class BeamSearchScorer(BeamScorer):
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        beam_indices: Optional[torch.LongTensor] = None,
+        decoder_prompt_len: Optional[int] = 0,
    ) -> Tuple[torch.LongTensor]:
        batch_size = len(self._beam_hyps) // self.num_beam_groups

@@ -340,7 +348,7 @@ class BeamSearchScorer(BeamScorer):
                final_score = final_beam_scores[batch_beam_idx].item()
                final_tokens = input_ids[batch_beam_idx]
                beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
-                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)
+                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len)

        # select the best hypotheses
        sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
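Threading `decoder_prompt_len` through `finalize` is about consistency: open beams that never emitted EOS are added here, and they must be normalized the same way as the EOS-terminated hypotheses added in `process`. A sketch of the mismatch that would otherwise arise (hypothetical numbers):

# One 15-token sequence: 10 prompt tokens, 5 generated, cumulative logprob -6.0.
scored_in_process = -6.0 / 5     # -1.2: prompt excluded (with the fix)
scored_in_finalize = -6.0 / 15   # -0.4: prompt still counted if finalize were left unfixed
# Mixing the two normalizations would systematically favor beams finished in finalize.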
@@ -511,6 +519,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        beam_indices: Optional[torch.LongTensor] = None,
+        decoder_prompt_len: Optional[int] = 0,
    ) -> Tuple[torch.Tensor]:
        r"""
        Args:
@@ -535,7 +544,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
                The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
            beam_indices (`torch.LongTensor`, *optional*):
                Beam indices indicating to which beam hypothesis each token corresponds.
+            decoder_prompt_len (`int`, *optional*):
+                The length of the prompt that is included in the input to the decoder.

        Return:
            `UserDict`: A dictionary composed of the fields as defined above:
@@ -550,7 +560,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
                indicating to which beam the next tokens shall be added.
        """

-        cur_len = input_ids.shape[-1] + 1  # add up to the length which the next_scores is calculated on
+        # add up to the length which the next_scores is calculated on
+        cur_len = input_ids.shape[-1] - decoder_prompt_len + 1
        batch_size = len(self._beam_hyps)
        if not (batch_size == (input_ids.shape[0] // self.group_size)):
            if self.num_beam_groups > 1:
@@ -606,10 +617,16 @@ class ConstrainedBeamSearchScorer(BeamScorer):
                    else:
                        beam_index = None

+                    # skip the corner case where the only constraint token is
+                    # eos_token and the very first generated token is eos_token
+                    if decoder_prompt_len == input_ids.shape[-1]:
+                        continue
+
                    beam_hyp.add(
                        input_ids[batch_beam_idx].clone(),
                        next_score.item(),
                        beam_indices=beam_index,
+                        decoder_prompt_len=decoder_prompt_len,
                    )
                else:
                    # add next predicted token since it is not eos_token
@@ -805,6 +822,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        beam_indices: Optional[torch.LongTensor] = None,
+        decoder_prompt_len: Optional[int] = 0,
    ) -> Tuple[torch.LongTensor]:
        batch_size = len(self._beam_hyps)

@@ -828,7 +846,9 @@ class ConstrainedBeamSearchScorer(BeamScorer):
                completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
                if completes_constraint:
                    beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
-                    beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)
+                    beam_hyp.add(
+                        final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len
+                    )
                    ids_collect.append(beam_id)

            # due to overly complex constraints or other factors, sometimes we can't guarantee a successful
@@ -839,7 +859,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
                        batch_beam_idx = batch_idx * self.num_beams + beam_id
                        final_score = final_beam_scores[batch_beam_idx].item()
                        final_tokens = input_ids[batch_beam_idx]
-                        beam_hyp.add(final_tokens, final_score)
+                        beam_hyp.add(final_tokens, final_score, decoder_prompt_len=decoder_prompt_len)
                    if len(ids_collect) >= self.num_beam_hyps_to_keep:
                        break
@@ -931,11 +951,17 @@ class BeamHypotheses:
        """
        return len(self.beams)

-    def add(self, hyp: torch.LongTensor, sum_logprobs: float, beam_indices: Optional[torch.LongTensor] = None):
+    def add(
+        self,
+        hyp: torch.LongTensor,
+        sum_logprobs: float,
+        beam_indices: Optional[torch.LongTensor] = None,
+        decoder_prompt_len: Optional[int] = 0,
+    ):
        """
        Add a new hypothesis to the list.
        """
-        score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
+        score = sum_logprobs / ((hyp.shape[-1] - decoder_prompt_len) ** self.length_penalty)
        if len(self) < self.num_beams or score > self.worst_score:
            self.beams.append((score, hyp, beam_indices))
            if len(self) > self.num_beams:
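`BeamHypotheses.add` is where the fix ultimately lands: the length penalty now divides by the generated length only. A before/after sketch with placeholder values; note that the default `decoder_prompt_len=0` reproduces the old formula, so existing callers that don't pass it keep their previous behaviour:

import torch

hyp = torch.arange(15)  # stands in for 10 prompt tokens + 5 generated tokens
sum_logprobs, length_penalty, decoder_prompt_len = -6.0, 1.0, 10

old = sum_logprobs / (hyp.shape[-1] ** length_penalty)                         # -0.4
new = sum_logprobs / ((hyp.shape[-1] - decoder_prompt_len) ** length_penalty)  # -1.2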
@@ -3172,6 +3172,8 @@ class GenerationMixin:
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False  # used by synced_gpus only
+
+        decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
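`decoder_prompt_len` is captured once, before the generation loop, from the decoder's step-0 input; the same line is added to `beam_search`, `beam_sample`, `group_beam_search` and `constrained_beam_search` below. A sketch of what it evaluates to under each architecture (shapes assumed for illustration):

import torch

batch_size, num_beams = 2, 3
# Decoder-only (e.g. GPT-2): input_ids is the tokenized prompt itself.
input_ids = torch.randint(0, 100, (batch_size * num_beams, 7))
decoder_prompt_len = input_ids.shape[-1]  # 7, the full prompt length
# Encoder-decoder (e.g. T5): the decoder starts from a single start token,
# so the recorded length would be 1 and the correction is one token.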
@@ -3246,6 +3248,7 @@
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=beam_indices,
+                decoder_prompt_len=decoder_prompt_len,
            )

            beam_scores = beam_outputs["next_beam_scores"]
@@ -3281,6 +3284,7 @@
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=beam_indices,
+            decoder_prompt_len=decoder_prompt_len,
        )

        if return_dict_in_generate:
@@ -3500,6 +3504,8 @@
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False  # used by synced_gpus only
+
+        decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -3578,6 +3584,7 @@
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=beam_indices,
+                decoder_prompt_len=decoder_prompt_len,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
@@ -3612,6 +3619,7 @@
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=beam_indices,
+            decoder_prompt_len=decoder_prompt_len,
        )

        if return_dict_in_generate:
@@ -3837,6 +3845,8 @@
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False  # used by synced_gpus only
+
+        decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -3924,6 +3934,7 @@
                eos_token_id=eos_token_id,
                beam_indices=process_beam_indices,
                group_index=beam_group_idx,
+                decoder_prompt_len=decoder_prompt_len,
            )
            beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
@@ -3993,6 +4004,7 @@
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=final_beam_indices,
+            decoder_prompt_len=decoder_prompt_len,
        )

        if return_dict_in_generate:
@@ -4220,6 +4232,8 @@
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False  # used by synced_gpus only
+
+        decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -4298,6 +4312,7 @@
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=beam_indices,
+                decoder_prompt_len=decoder_prompt_len,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
@@ -4331,6 +4346,7 @@
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=beam_indices,
+            decoder_prompt_len=decoder_prompt_len,
        )

        if return_dict_in_generate:
@@ -633,7 +633,11 @@ class GenerationIntegrationTestsMixin:
            "do_sample": False,
            "num_beams": 3,
        }
-        expectation = 13
+        if is_pt:
+            expectation = 20
+        else:
+            # TODO (joao): fix me
+            expectation = 13

        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        text = """Hello, my dog is cute and"""
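For reference, a hedged end-to-end sketch of the behaviour the updated test pins down. The checkpoint and prompt match the test, but the generation kwargs here are illustrative and the tiny model's weights are random, so treat the output length as an example rather than a fixed expectation:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

inputs = tokenizer("Hello, my dog is cute and", return_tensors="pt")
# With the fix, PyTorch beam search ranks hypotheses by generated tokens only,
# which is why the PT expectation changed to 20 while the TF path keeps the
# old value 13 until it is fixed (hence the TODO above).
out = model.generate(**inputs, do_sample=False, num_beams=3, max_new_tokens=20)
print(out.shape[-1])  # prompt length + generated tokens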