diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index 5d1a44c00a..1db656da60 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -727,7 +727,7 @@ class WavLMEncoder(nn.Module): hidden_states, position_bias = layer_outputs[:2] if skip_the_layer: - layer_outputs = (None, None) + layer_outputs = (None, None, None) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[2],) @@ -810,7 +810,7 @@ class WavLMEncoderStableLayerNorm(nn.Module): hidden_states, position_bias = layer_outputs[:2] if skip_the_layer: - layer_outputs = (None, None) + layer_outputs = (None, None, None) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[2],) diff --git a/tests/models/wavlm/test_modeling_wavlm.py b/tests/models/wavlm/test_modeling_wavlm.py index c0a8eed209..3cf4348f6c 100644 --- a/tests/models/wavlm/test_modeling_wavlm.py +++ b/tests/models/wavlm/test_modeling_wavlm.py @@ -288,6 +288,15 @@ class WavLMModelTester: loss.backward() + def check_output_attentions(self, config, input_values, attention_mask): + model = WavLMModel(config=config) + model.config.layerdrop = 1.0 + model.to(torch_device) + model.train() + + outputs = model(input_values, attention_mask=attention_mask, output_attentions=True) + self.parent.assertTrue(len(outputs.attentions) > 0) + def check_labels_out_of_vocab(self, config, input_values, *args): model = WavLMForCTC(config) model.to(torch_device) @@ -354,6 +363,10 @@ class WavLMModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_seq_classifier_training(*config_and_inputs) + def test_output_attentions(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_output_attentions(*config_and_inputs) + def test_labels_out_of_vocab(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_labels_out_of_vocab(*config_and_inputs)