Switch Transformers: remove overwritten beam sample test (#25458)
This commit is contained in:
parent
41d56ea6dd
commit
4692d26194
|
@ -37,7 +37,6 @@ if is_torch_available():
|
|||
SwitchTransformersModel,
|
||||
SwitchTransformersTop1Router,
|
||||
)
|
||||
from transformers.generation import BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput
|
||||
from transformers.models.switch_transformers.modeling_switch_transformers import (
|
||||
SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
load_balancing_loss_func,
|
||||
|
@ -613,101 +612,6 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
|
|||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_beam_sample_generate_dict_output(self):
|
||||
r"""
|
||||
This test needs to be overriden with a larger model since it fails for very small models due to precision issues.
|
||||
"""
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
# disable cache
|
||||
config.use_cache = False
|
||||
|
||||
# It is important set set the eos_token_id to None to ensure that no sequences
|
||||
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
||||
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
||||
config.eos_token_id = None
|
||||
config.forced_eos_token_id = None
|
||||
|
||||
model = model_class.from_pretrained("google/switch-base-8").to(torch_device).eval()
|
||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
|
||||
|
||||
num_return_sequences = 2
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
|
||||
input_ids.shape[0] * num_return_sequences, max_length
|
||||
)
|
||||
beam_kwargs["num_return_sequences"] = num_return_sequences
|
||||
|
||||
output_beam_sample, output_generate = self._beam_sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
num_return_sequences=num_return_sequences,
|
||||
beam_scorer=beam_scorer,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_warper=logits_warper,
|
||||
logits_warper_kwargs=logits_warper_kwargs,
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
|
||||
|
||||
self.assertListEqual(output_generate.sequences.tolist(), output_beam_sample.sequences.tolist())
|
||||
|
||||
@slow
|
||||
def test_beam_sample_generate(self):
|
||||
r"""
|
||||
This test needs to be overriden with a larger model since it fails for very small models due to precision issues.
|
||||
"""
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
# It is important set set the eos_token_id to None to ensure that no sequences
|
||||
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
||||
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
||||
config.eos_token_id = None
|
||||
config.forced_eos_token_id = None
|
||||
|
||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
|
||||
|
||||
model = model_class.from_pretrained("google/switch-base-8").to(torch_device).eval()
|
||||
|
||||
# check `generate()` and `beam_search()` are equal
|
||||
# change `num_return_sequences = 2` but not for `beam_scorer`
|
||||
num_return_sequences = 2
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
|
||||
input_ids.shape[0] * num_return_sequences, max_length
|
||||
)
|
||||
beam_kwargs["num_return_sequences"] = num_return_sequences
|
||||
|
||||
output_generate, output_beam_sample = self._beam_sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
num_return_sequences=num_return_sequences,
|
||||
beam_scorer=beam_scorer,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_warper=logits_warper,
|
||||
logits_warper_kwargs=logits_warper_kwargs,
|
||||
)
|
||||
|
||||
self.assertListEqual(output_generate.tolist(), output_beam_sample.tolist())
|
||||
|
||||
def test_decoder_model_past_with_3d_attn_mask(self):
|
||||
(
|
||||
config,
|
||||
|
|
Loading…
Reference in New Issue