Switch Transformers: remove overwritten beam sample test (#25458)

Joao Gante 2023-08-11 13:16:01 +01:00 committed by GitHub
parent 41d56ea6dd
commit 4692d26194
1 changed file with 0 additions and 96 deletions
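For context, the two deleted overrides exercised beam-sample decoding against the google/switch-base-8 checkpoint. A minimal sketch of that decoding mode through the public generate() API is given below; the prompt text and generation settings are illustrative and are not taken from this commit.

    import torch
    from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration

    tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
    model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8").eval()

    # Illustrative prompt; any text works for a smoke test of the decoding path.
    inputs = tokenizer("summarize: The overridden tests below were removed.", return_tensors="pt")

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            do_sample=True,             # sampling ...
            num_beams=2,                # ... combined with beam search -> beam-sample decoding
            num_return_sequences=2,
            max_new_tokens=4,
            return_dict_in_generate=True,
            output_scores=True,
        )

    # At the time of this commit, the returned object for this configuration was a
    # BeamSampleEncoderDecoderOutput, the type the removed tests asserted on.
    print(type(outputs).__name__, outputs.sequences.shape)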


@@ -37,7 +37,6 @@ if is_torch_available():
         SwitchTransformersModel,
         SwitchTransformersTop1Router,
     )
-    from transformers.generation import BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput
     from transformers.models.switch_transformers.modeling_switch_transformers import (
         SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST,
         load_balancing_loss_func,
@@ -613,101 +612,6 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
-    @slow
-    def test_beam_sample_generate_dict_output(self):
-        r"""
-        This test needs to be overriden with a larger model since it fails for very small models due to precision issues.
-        """
-        for model_class in self.all_generative_model_classes:
-            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
-            # disable cache
-            config.use_cache = False
-            # It is important set set the eos_token_id to None to ensure that no sequences
-            # shorter than `max_length` can be generated which could lead to flaky circle ci
-            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
-            config.eos_token_id = None
-            config.forced_eos_token_id = None
-            model = model_class.from_pretrained("google/switch-base-8").to(torch_device).eval()
-            logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
-            num_return_sequences = 2
-            if model.config.is_encoder_decoder:
-                max_length = 4
-            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
-                input_ids.shape[0] * num_return_sequences, max_length
-            )
-            beam_kwargs["num_return_sequences"] = num_return_sequences
-            output_beam_sample, output_generate = self._beam_sample_generate(
-                model=model,
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                max_length=max_length,
-                num_return_sequences=num_return_sequences,
-                beam_scorer=beam_scorer,
-                beam_kwargs=beam_kwargs,
-                logits_warper=logits_warper,
-                logits_warper_kwargs=logits_warper_kwargs,
-                output_scores=True,
-                output_hidden_states=True,
-                output_attentions=True,
-                return_dict_in_generate=True,
-            )
-            if model.config.is_encoder_decoder:
-                self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput)
-                self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
-            else:
-                self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput)
-                self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
-            self.assertListEqual(output_generate.sequences.tolist(), output_beam_sample.sequences.tolist())
-    @slow
-    def test_beam_sample_generate(self):
-        r"""
-        This test needs to be overriden with a larger model since it fails for very small models due to precision issues.
-        """
-        for model_class in self.all_generative_model_classes:
-            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
-            # It is important set set the eos_token_id to None to ensure that no sequences
-            # shorter than `max_length` can be generated which could lead to flaky circle ci
-            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
-            config.eos_token_id = None
-            config.forced_eos_token_id = None
-            logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
-            model = model_class.from_pretrained("google/switch-base-8").to(torch_device).eval()
-            # check `generate()` and `beam_search()` are equal
-            # change `num_return_sequences = 2` but not for `beam_scorer`
-            num_return_sequences = 2
-            if model.config.is_encoder_decoder:
-                max_length = 4
-            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
-                input_ids.shape[0] * num_return_sequences, max_length
-            )
-            beam_kwargs["num_return_sequences"] = num_return_sequences
-            output_generate, output_beam_sample = self._beam_sample_generate(
-                model=model,
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                max_length=max_length,
-                num_return_sequences=num_return_sequences,
-                beam_scorer=beam_scorer,
-                beam_kwargs=beam_kwargs,
-                logits_warper=logits_warper,
-                logits_warper_kwargs=logits_warper_kwargs,
-            )
-            self.assertListEqual(output_generate.tolist(), output_beam_sample.tolist())
     def test_decoder_model_past_with_3d_attn_mask(self):
         (
             config,