diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py
index c28e0afbb9..58a2bf13ba 100644
--- a/src/transformers/generation/flax_utils.py
+++ b/src/transformers/generation/flax_utils.py
@@ -463,6 +463,7 @@ class FlaxGenerationMixin:
                 logits_processor=logits_processor,
                 trace=trace,
                 params=params,
+                num_return_sequences=generation_config.num_return_sequences,
                 model_kwargs=model_kwargs,
             )
         else:
@@ -749,6 +750,7 @@ class FlaxGenerationMixin:
         logits_processor: Optional[FlaxLogitsProcessorList] = None,
         trace: bool = True,
         params: Optional[Dict[str, jnp.ndarray]] = None,
+        num_return_sequences: Optional[int] = None,
         model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
     ):
         """
@@ -793,6 +795,9 @@ class FlaxGenerationMixin:
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
         length_penalty = length_penalty if length_penalty is not None else self.generation_config.length_penalty
         early_stopping = early_stopping if early_stopping is not None else self.generation_config.early_stopping
+        num_return_sequences = (
+            num_return_sequences if num_return_sequences is not None else self.generation_config.num_return_sequences
+        )
 
         batch_size, num_beams, cur_len = input_ids.shape
 
@@ -996,8 +1001,8 @@ class FlaxGenerationMixin:
         sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences)
         scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)
 
-        # take best beam for each batch
-        sequences = sequences[:, 0]
-        scores = scores[:, 0]
+        # Take best beams for each batch (the score is sorted in descending order)
+        sequences = flatten_beam_dim(sequences[:, :num_return_sequences, :])
+        scores = flatten_beam_dim(scores[:, :num_return_sequences])
 
         return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
diff --git a/tests/generation/test_flax_utils.py b/tests/generation/test_flax_utils.py
index c6182a2386..647482b88c 100644
--- a/tests/generation/test_flax_utils.py
+++ b/tests/generation/test_flax_utils.py
@@ -158,6 +158,19 @@ class FlaxGenerationTesterMixin:
 
             self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
 
+    def test_beam_search_generate_num_return_sequences(self):
+        config, input_ids, _, max_length = self._get_input_ids_and_config()
+        config.do_sample = False
+        config.max_length = max_length
+        config.num_beams = 2
+        config.num_return_sequences = 2
+
+        for model_class in self.all_generative_model_classes:
+            model = model_class(config)
+
+            generation_outputs = model.generate(input_ids).sequences
+            self.assertEqual(generation_outputs.shape[0], input_ids.shape[0] * config.num_return_sequences)
+
     def test_sample_generate_logits_warper(self):
         config, input_ids, _, max_length = self._get_input_ids_and_config()
         config.do_sample = True
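
For context, here is a rough usage sketch of what this patch enables (not part of the diff itself; the `t5-small` checkpoint, prompt, and printed shapes are illustrative assumptions): with `num_beams >= num_return_sequences`, Flax beam search now returns the top `num_return_sequences` beams per input, flattened along the batch dimension, instead of only the single best beam.

```python
# A minimal sketch, assuming a Flax seq2seq checkpoint such as "t5-small" is
# available; model name, prompt, and shapes are illustrative only.
from transformers import AutoTokenizer, FlaxAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = FlaxAutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer(
    ["translate English to German: Hello, how are you?"], return_tensors="np"
)

# With this patch, beam search keeps the top `num_return_sequences` beams per
# input, so the leading dimension of `sequences` becomes
# batch_size * num_return_sequences rather than batch_size.
outputs = model.generate(**inputs, num_beams=4, num_return_sequences=2)

print(outputs.sequences.shape)  # (1 * 2, generated_length)
for seq in outputs.sequences:
    print(tokenizer.decode(seq, skip_special_tokens=True))
```

The new test above asserts exactly this shape contract: the number of returned sequences equals `input_ids.shape[0] * config.num_return_sequences`.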