🚨🚨 Fix beam score calculation issue for decoder-only models (#27351)

* Fix beam score calculation issue for decoder-only models

* Update beam search test and fix code quality issue

* Fix beam_sample, group_beam_search and constrained_beam_search

* Split test for pytorch and TF, add documentation

---------

Co-authored-by: Xin Qiu <xin.qiu@sentient.ai>
Authored by Xin Qiu on 2023-11-15 20:49:14 +08:00, committed by GitHub
parent 3d1a7bf476
commit 453079c7f8
3 changed files with 55 additions and 9 deletions
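The core of the change, as a minimal standalone sketch with toy numbers (the values below are illustrative, not taken from the diff): for decoder-only models, `input_ids` already contains the user prompt, so the old length penalty was applied to prompt plus generated tokens, while encoder-decoder models start decoding from a single start token. The fix normalizes finished-hypothesis scores by the number of generated tokens only.

# Illustrative sketch of the scoring change; toy numbers, not code from the commit.
length_penalty = 1.0
sum_logprobs = -6.0   # cumulative log-probability of a finished hypothesis
prompt_len = 10       # decoder_prompt_len: tokens fed in as the prompt
total_len = 15        # prompt plus 5 generated tokens

old_score = sum_logprobs / (total_len ** length_penalty)                 # -0.40
new_score = sum_logprobs / ((total_len - prompt_len) ** length_penalty)  # -1.20
print(old_score, new_score)

With the old formula, a long prompt inflates the denominator and flattens the score differences between hypotheses, so decoder-only beam search could rank candidates differently from what the length penalty intends; the new formula depends only on what beam search actually generated.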

src/transformers/generation/beam_search.py

@@ -222,8 +222,10 @@ class BeamSearchScorer(BeamScorer):
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
group_index: Optional[int] = 0,
decoder_prompt_len: Optional[int] = 0,
) -> Dict[str, torch.Tensor]:
cur_len = input_ids.shape[-1] + 1 # add up to the length which the next_scores is calculated on
# add up to the length which the next_scores is calculated on
cur_len = input_ids.shape[-1] - decoder_prompt_len + 1
batch_size = len(self._beam_hyps) // self.num_beam_groups
if not (batch_size == (input_ids.shape[0] // self.group_size)):
@@ -277,10 +279,15 @@ class BeamSearchScorer(BeamScorer):
else:
beam_index = None
# skip the corner case where the very first generated token is eos_token
if decoder_prompt_len == input_ids.shape[-1]:
continue
self._beam_hyps[batch_group_idx].add(
input_ids[batch_beam_idx].clone(),
next_score.item(),
beam_indices=beam_index,
decoder_prompt_len=decoder_prompt_len,
)
else:
# add next predicted token since it is not eos_token
@@ -322,6 +329,7 @@ class BeamSearchScorer(BeamScorer):
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
) -> Tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps) // self.num_beam_groups
@@ -340,7 +348,7 @@ class BeamSearchScorer(BeamScorer):
final_score = final_beam_scores[batch_beam_idx].item()
final_tokens = input_ids[batch_beam_idx]
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len)
# select the best hypotheses
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
@@ -511,6 +519,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
) -> Tuple[torch.Tensor]:
r"""
Args:
@@ -535,7 +544,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
beam_indices (`torch.LongTensor`, *optional*):
Beam indices indicating to which beam hypothesis each token correspond.
decoder_prompt_len (`int`, *optional*):
The length of the prompt that is included in the input to the decoder.
Return:
`UserDict`: A dictionary composed of the fields as defined above:
@@ -550,7 +560,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
indicating to which beam the next tokens shall be added.
"""
cur_len = input_ids.shape[-1] + 1 # add up to the length which the next_scores is calculated on
# add up to the length which the next_scores is calculated on
cur_len = input_ids.shape[-1] - decoder_prompt_len + 1
batch_size = len(self._beam_hyps)
if not (batch_size == (input_ids.shape[0] // self.group_size)):
if self.num_beam_groups > 1:
@@ -606,10 +617,16 @@ class ConstrainedBeamSearchScorer(BeamScorer):
else:
beam_index = None
# skip the corner case where the only constraint token is
# eos_token and the very first generated token is eos_token
if decoder_prompt_len == input_ids.shape[-1]:
continue
beam_hyp.add(
input_ids[batch_beam_idx].clone(),
next_score.item(),
beam_indices=beam_index,
decoder_prompt_len=decoder_prompt_len,
)
else:
# add next predicted token since it is not eos_token
@@ -805,6 +822,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
) -> Tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps)
@@ -828,7 +846,9 @@ class ConstrainedBeamSearchScorer(BeamScorer):
completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
if completes_constraint:
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)
beam_hyp.add(
final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len
)
ids_collect.append(beam_id)
# due to overly complex constraints or other factors, sometimes we can't guarantee a successful
@@ -839,7 +859,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
batch_beam_idx = batch_idx * self.num_beams + beam_id
final_score = final_beam_scores[batch_beam_idx].item()
final_tokens = input_ids[batch_beam_idx]
beam_hyp.add(final_tokens, final_score)
beam_hyp.add(final_tokens, final_score, decoder_prompt_len=decoder_prompt_len)
if len(ids_collect) >= self.num_beam_hyps_to_keep:
break
@@ -931,11 +951,17 @@ class BeamHypotheses:
"""
return len(self.beams)
def add(self, hyp: torch.LongTensor, sum_logprobs: float, beam_indices: Optional[torch.LongTensor] = None):
def add(
self,
hyp: torch.LongTensor,
sum_logprobs: float,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
score = sum_logprobs / ((hyp.shape[-1] - decoder_prompt_len) ** self.length_penalty)
if len(self) < self.num_beams or score > self.worst_score:
self.beams.append((score, hyp, beam_indices))
if len(self) > self.num_beams:
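One detail worth spelling out: the new `continue` guards and this `decoder_prompt_len` argument work together. If the very first generated token is `eos_token`, the hypothesis handed to `add` is still just the prompt, so `hyp.shape[-1] - decoder_prompt_len` is zero and, for a positive length penalty, the normalization above would divide by zero. A toy illustration (assumed values, not library code):

# Toy illustration of the corner case that the new `continue` statements skip.
length_penalty = 1.0
sum_logprobs = -0.5
prompt_len = 10
hyp_len = 10  # eos came first, so the stored hypothesis would be the bare prompt

try:
    score = sum_logprobs / ((hyp_len - prompt_len) ** length_penalty)
except ZeroDivisionError:
    score = None  # what `if decoder_prompt_len == input_ids.shape[-1]: continue` avoids
print(score)  # None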

src/transformers/generation/utils.py

@@ -3172,6 +3172,8 @@ class GenerationMixin:
beam_scores = beam_scores.view((batch_size * num_beams,))
this_peer_finished = False # used by synced_gpus only
decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -3246,6 +3248,7 @@ class GenerationMixin:
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=beam_indices,
decoder_prompt_len=decoder_prompt_len,
)
beam_scores = beam_outputs["next_beam_scores"]
@@ -3281,6 +3284,7 @@ class GenerationMixin:
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=beam_indices,
decoder_prompt_len=decoder_prompt_len,
)
if return_dict_in_generate:
@@ -3500,6 +3504,8 @@ class GenerationMixin:
beam_scores = beam_scores.view((batch_size * num_beams,))
this_peer_finished = False # used by synced_gpus only
decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -3578,6 +3584,7 @@ class GenerationMixin:
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=beam_indices,
decoder_prompt_len=decoder_prompt_len,
)
beam_scores = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
@@ -3612,6 +3619,7 @@ class GenerationMixin:
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=beam_indices,
decoder_prompt_len=decoder_prompt_len,
)
if return_dict_in_generate:
@@ -3837,6 +3845,8 @@ class GenerationMixin:
beam_scores = beam_scores.view((batch_size * num_beams,))
this_peer_finished = False # used by synced_gpus only
decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -3924,6 +3934,7 @@ class GenerationMixin:
eos_token_id=eos_token_id,
beam_indices=process_beam_indices,
group_index=beam_group_idx,
decoder_prompt_len=decoder_prompt_len,
)
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
@@ -3993,6 +4004,7 @@ class GenerationMixin:
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=final_beam_indices,
decoder_prompt_len=decoder_prompt_len,
)
if return_dict_in_generate:
@@ -4220,6 +4232,8 @@ class GenerationMixin:
beam_scores = beam_scores.view((batch_size * num_beams,))
this_peer_finished = False # used by synced_gpus only
decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -4298,6 +4312,7 @@ class GenerationMixin:
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=beam_indices,
decoder_prompt_len=decoder_prompt_len,
)
beam_scores = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
@@ -4331,6 +4346,7 @@ class GenerationMixin:
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=beam_indices,
decoder_prompt_len=decoder_prompt_len,
)
if return_dict_in_generate:
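For context, each beam-search loop in `GenerationMixin` now records the prompt length up front and threads it through the scorer's `process` and `finalize` calls. A minimal, hypothetical usage sketch of the behaviour this affects (the model name and generation arguments are illustrative, not taken from the commit):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Any decoder-only (causal) model is affected; "gpt2" is just an example.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello, my dog is cute and", return_tensors="pt")

# Beam search on a decoder-only model: with this fix, the length penalty is
# applied to (sequence_length - prompt_length), so the prompt no longer skews
# how finished hypotheses are ranked against one another.
outputs = model.generate(**inputs, do_sample=False, num_beams=3, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))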

tests/generation/test_framework_agnostic.py

@@ -633,7 +633,11 @@ class GenerationIntegrationTestsMixin:
"do_sample": False,
"num_beams": 3,
}
expectation = 13
if is_pt:
expectation = 20
else:
# TODO (joao): fix me
expectation = 13
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = """Hello, my dog is cute and"""