diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
index b7d561445f..3935566de0 100755
--- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -36,7 +36,10 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "Wav2Vec2Config"
 
 WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "facebook/wav2vec2-base-960h"
+    "facebook/wav2vec2-base-960h",
+    "facebook/wav2vec2-large-960h",
+    "facebook/wav2vec2-large-960h-lv60",
+    "facebook/wav2vec2-large-960h-lv60-self",
     # See all Wav2Vec2 models at https://huggingface.co/models?filter=wav2vec2
 ]
 
@@ -191,7 +194,6 @@ class Wav2Vec2FeatureProjection(nn.Module):
         self.dropout = nn.Dropout(config.feat_extract_dropout)
 
     def forward(self, hidden_states):
-        hidden_states = hidden_states.transpose(1, 2)
         hidden_states = self.layer_norm(hidden_states)
         hidden_states = self.projection(hidden_states)
         hidden_states = self.dropout(hidden_states)
@@ -387,9 +389,11 @@ class Wav2Vec2EncoderLayer(nn.Module):
         self.feed_forward = Wav2Vec2FeedForward(config)
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-    def forward(self, hidden_states, output_attentions=False):
+    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
         attn_residual = hidden_states
-        hidden_states, attn_weights, _ = self.attention(hidden_states, output_attentions=output_attentions)
+        hidden_states, attn_weights, _ = self.attention(
+            hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
+        )
         hidden_states = self.dropout(hidden_states)
         hidden_states = attn_residual + hidden_states
 
@@ -414,10 +418,12 @@ class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module):
         self.feed_forward = Wav2Vec2FeedForward(config)
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-    def forward(self, hidden_states, output_attentions=False):
+    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
         attn_residual = hidden_states
         hidden_states = self.layer_norm(hidden_states)
-        hidden_states, attn_weights, _ = self.attention(hidden_states, output_attentions=output_attentions)
+        hidden_states, attn_weights, _ = self.attention(
+            hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
+        )
         hidden_states = self.dropout(hidden_states)
         hidden_states = attn_residual + hidden_states
         hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
@@ -438,6 +444,7 @@ class Wav2Vec2Encoder(nn.Module):
     def forward(
         self,
         hidden_states,
+        attention_mask=None,
         output_attentions=False,
         output_hidden_states=False,
         return_dict=True,
@@ -445,6 +452,16 @@ class Wav2Vec2Encoder(nn.Module):
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
 
+        if attention_mask is not None:
+            # make sure padded tokens output 0
+            hidden_states[~attention_mask] = 0.0
+
+            # extend attention_mask
+            attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+            attention_mask = attention_mask.expand(
+                attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
+            )
+
         position_embeddings = self.pos_conv_embed(hidden_states)
         hidden_states = hidden_states + position_embeddings
         hidden_states = self.layer_norm(hidden_states)
@@ -454,7 +471,9 @@ class Wav2Vec2Encoder(nn.Module):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
-            hidden_states, attn_weights = layer(hidden_states, output_attentions=output_attentions)
+            hidden_states, attn_weights = layer(
+                hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
+            )
 
             if output_attentions:
                 all_self_attentions = all_self_attentions + (attn_weights,)
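# --- Illustrative sketch (editor's addition, not part of the patch) ----------
# The encoder turns the boolean padding mask into an additive bias: padded key
# positions receive a large negative value, so softmax assigns them ~0 weight.
# A minimal standalone version of the same trick:
import torch

padding_mask = torch.tensor([[1, 1, 1, 0, 0]], dtype=torch.bool)  # 1 = real frame, 0 = padding
scores = torch.randn(1, 1, 5, 5)  # (batch, heads, query, key) raw attention scores

extended_mask = (1.0 - padding_mask[:, None, None, :].float()) * -10000.0
probs = torch.softmax(scores + extended_mask, dim=-1)

# the two padded key positions receive (numerically) zero attention weight
assert torch.allclose(probs[..., 3:].sum(), torch.tensor(0.0), atol=1e-4)
# ------------------------------------------------------------------------------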
@@ -486,6 +505,7 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
     def forward(
         self,
         hidden_states,
+        attention_mask=None,
         output_attentions=False,
         output_hidden_states=False,
         return_dict=True,
@@ -493,6 +513,16 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
 
+        if attention_mask is not None:
+            # make sure padded tokens are not attended to
+            hidden_states[~attention_mask] = 0
+
+            # extend attention_mask
+            attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+            attention_mask = attention_mask.expand(
+                attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
+            )
+
         position_embeddings = self.pos_conv_embed(hidden_states)
         hidden_states = hidden_states + position_embeddings
         hidden_states = self.dropout(hidden_states)
@@ -501,7 +531,9 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
-            hidden_states, attn_weights = layer(hidden_states, output_attentions=output_attentions)
+            hidden_states, attn_weights = layer(
+                hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
+            )
 
             if output_attentions:
                 all_self_attentions = all_self_attentions + (attn_weights,)
@@ -544,6 +576,21 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
         if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
             module.bias.data.zero_()
 
+    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
+        """
+        Computes the output length of the convolutional layers
+        """
+
+        def _conv_out_length(input_length, kernel_size, stride):
+            # 1D convolutional layer output length formula taken
+            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+            return torch.floor((input_length - kernel_size) / stride + 1)
+
+        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+        return input_lengths.to(torch.long)
+
 
 WAV_2_VEC_2_START_DOCSTRING = r"""
     Wav2Vec2 was proposed in `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
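# --- Illustrative sketch (editor's addition, not part of the patch) ----------
# _get_feat_extract_output_lengths simply chains the Conv1d length formula over
# the feature-extractor layers. With the default wav2vec2 conv stack
# (kernels (10, 3, 3, 3, 3, 2, 2), strides (5, 2, 2, 2, 2, 2, 2)) one second of
# 16 kHz audio maps to 49 encoder frames:
def conv_out_length(length, kernel_size, stride):
    return (length - kernel_size) // stride + 1

length = 16000
for kernel_size, stride in zip((10, 3, 3, 3, 3, 2, 2), (5, 2, 2, 2, 2, 2, 2)):
    length = conv_out_length(length, kernel_size, stride)

assert length == 49
# ------------------------------------------------------------------------------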
@@ -572,6 +619,24 @@ WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
             soundfile`). To prepare the array into `input_values`, the :class:`~transformers.Wav2Vec2Tokenizer` should
             be used for padding and conversion into a tensor of type `torch.FloatTensor`. See
             :meth:`transformers.Wav2Vec2Tokenizer.__call__` for details.
+        attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Mask to avoid performing convolution and attention on padding token indices. Mask values selected in ``[0,
+            1]``:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            `What are attention masks? <../glossary.html#attention-mask>`__
+
+            .. warning::
+                :obj:`attention_mask` should only be passed if the corresponding tokenizer has
+                ``config.return_attention_mask == True``. For all models whose tokenizer has
+                ``config.return_attention_mask == False``, such as `wav2vec2-base
+                <https://huggingface.co/facebook/wav2vec2-base-960h>`__, :obj:`attention_mask` should **not** be passed
+                to avoid degraded performance when doing batched inference. For such models :obj:`input_values` should
+                simply be padded with 0 and passed without :obj:`attention_mask`. Be aware that these models also yield
+                slightly different results depending on whether :obj:`input_values` is padded or not.
+
         output_attentions (:obj:`bool`, `optional`):
             Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
             tensors for more detail.
@@ -606,6 +671,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
     def forward(
         self,
         input_values,
+        attention_mask=None,
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
@@ -641,14 +707,33 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         hidden_states = self.feature_extractor(input_values)
+        hidden_states = hidden_states.transpose(1, 2)
+
+        if attention_mask is not None:
+            # compute real output lengths according to convolution formula
+            output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
+
+            attention_mask = torch.zeros(
+                hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device
+            )
+
+            # these two operations make sure that all values
+            # before the output length indices are attended to
+            attention_mask[
+                (torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)
+            ] = 1
+            attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+
         hidden_states = self.feature_projection(hidden_states)
 
         encoder_outputs = self.encoder(
             hidden_states,
+            attention_mask=attention_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
+
         hidden_states = encoder_outputs[0]
 
         if not return_dict:
@@ -681,6 +766,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
     def forward(
         self,
         input_values,
+        attention_mask=None,
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
@@ -755,6 +841,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
     def forward(
         self,
         input_values,
+        attention_mask=None,
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
@@ -795,6 +882,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
 
         outputs = self.wav2vec2(
             input_values,
+            attention_mask=attention_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -802,6 +890,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
 
         hidden_states = outputs[0]
         hidden_states = self.dropout(hidden_states)
+
         logits = self.lm_head(hidden_states)
 
         if not return_dict:
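# --- Illustrative sketch (editor's addition, not part of the patch) ----------
# Wav2Vec2Model.forward rebuilds the attention mask at feature resolution: it
# writes a 1 at index (output_length - 1) for each sample and then uses
# flip/cumsum/flip to set every earlier position to 1 as well. Minimal version:
import torch

output_lengths = torch.tensor([3, 5])  # frames per sample after the conv stack
mask = torch.zeros(2, 5)               # (batch, max_frames)
mask[(torch.arange(2), output_lengths - 1)] = 1
mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()

assert mask.tolist() == [
    [True, True, True, False, False],
    [True, True, True, True, True],
]
# ------------------------------------------------------------------------------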
diff --git a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py
index ce46a011d1..0cc491cf37 100644
--- a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py
@@ -87,6 +87,26 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
             The token used for padding, for example when batching sequences of different lengths.
         word_delimiter_token (:obj:`str`, `optional`, defaults to :obj:`"|"`):
             The token used for defining the end of a word.
+        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to lowercase the output when decoding.
+        do_normalize (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
+            improve the performance for some models, *e.g.*, `wav2vec2-lv60
+            <https://huggingface.co/models?search=lv60>`__.
+        return_attention_mask (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not :meth:`~transformers.Wav2Vec2Tokenizer.__call__` should return :obj:`attention_mask`.
+
+            .. note::
+
+                Wav2Vec2 models that have set ``config.feat_extract_norm == "group"``, such as `wav2vec2-base
+                <https://huggingface.co/facebook/wav2vec2-base-960h>`__, have **not** been trained using
+                :obj:`attention_mask`. For such models, :obj:`input_values` should simply be padded with 0 and no
+                :obj:`attention_mask` should be passed.
+
+                For Wav2Vec2 models that have set ``config.feat_extract_norm == "layer"``, such as `wav2vec2-lv60
+                <https://huggingface.co/models?search=lv60>`__, :obj:`attention_mask` should be
+                passed for batched inference.
+
         **kwargs
             Additional keyword arguments passed along to :class:`~transformers.PreTrainedTokenizer`
     """
@@ -100,7 +120,7 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
             "facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/tokenizer.json",
         },
     }
-    model_input_names = ["input_values"]
+    model_input_names = ["input_values", "attention_mask"]
 
     def __init__(
         self,
         vocab_file,
         bos_token="<s>",
         eos_token="</s>",
         unk_token="<unk>",
         pad_token="<pad>",
         word_delimiter_token="|",
         do_lower_case=False,
+        do_normalize=False,
+        return_attention_mask=False,
         **kwargs
     ):
         super().__init__(
             bos_token=bos_token,
             unk_token=unk_token,
             eos_token=eos_token,
             pad_token=pad_token,
             do_lower_case=do_lower_case,
+            do_normalize=do_normalize,
+            return_attention_mask=return_attention_mask,
             word_delimiter_token=word_delimiter_token,
             **kwargs,
         )
 
         self._word_delimiter_token = word_delimiter_token
+        self.do_lower_case = do_lower_case
+        self.return_attention_mask = return_attention_mask
+        self.do_normalize = do_normalize
 
         with open(vocab_file, encoding="utf-8") as vocab_handle:
             self.encoder = json.load(vocab_handle)
@@ -193,6 +220,10 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
         if not is_batched:
             raw_speech = [raw_speech]
 
+        # zero-mean and unit-variance normalization
+        if self.do_normalize:
+            raw_speech = [(x - np.mean(x)) / np.sqrt(np.var(x) + 1e-5) for x in raw_speech]
+
         # convert into correct format for padding
         encoded_inputs = BatchEncoding({"input_values": raw_speech})
 
@@ -201,7 +232,7 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
             padding=padding,
             max_length=max_length,
             pad_to_multiple_of=pad_to_multiple_of,
-            return_attention_mask=False,
+            return_attention_mask=self.return_attention_mask,
             return_tensors=return_tensors,
             verbose=verbose,
         )
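# --- Illustrative sketch (editor's addition, not part of the patch) ----------
# do_normalize applies per-utterance zero-mean, unit-variance normalization to
# the raw waveform before padding; the 1e-5 term only guards against division
# by zero for (near-)silent inputs:
import numpy as np

raw = np.random.randn(1600).astype(np.float32) * 0.1 + 0.5  # toy waveform
normalized = (raw - np.mean(raw)) / np.sqrt(np.var(raw) + 1e-5)

assert abs(np.mean(normalized)) < 1e-3
assert abs(np.var(normalized) - 1.0) < 1e-2
# ------------------------------------------------------------------------------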
diff --git a/tests/test_modeling_wav2vec2.py b/tests/test_modeling_wav2vec2.py
index b9e726633d..5cb23672e4 100644
--- a/tests/test_modeling_wav2vec2.py
+++ b/tests/test_modeling_wav2vec2.py
@@ -18,7 +18,7 @@ import math
 import unittest
 
-from tests.test_modeling_common import floats_tensor
+from tests.test_modeling_common import floats_tensor, random_attention_mask
 from transformers import is_torch_available
 from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
 
@@ -93,6 +93,7 @@ class Wav2Vec2ModelTester:
     def prepare_config_and_inputs(self):
         input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
+        attention_mask = random_attention_mask([self.batch_size, self.seq_length])
 
         config = Wav2Vec2Config(
             hidden_size=self.hidden_size,
@@ -115,20 +116,48 @@ class Wav2Vec2ModelTester:
             vocab_size=self.vocab_size,
         )
 
-        return config, input_values
+        return config, input_values, attention_mask
 
-    def create_and_check_model(self, config, input_values):
+    def create_and_check_model(self, config, input_values, attention_mask):
         model = Wav2Vec2Model(config=config)
         model.to(torch_device)
         model.eval()
-        result = model(input_values)
+        result = model(input_values, attention_mask=attention_mask)
         self.parent.assertEqual(
             result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
         )
 
+    def create_and_check_batch_inference(self, config, input_values, *args):
+        # Not sure how to make this test pass at the moment. Batched input yields
+        # same results as official fairseq implementation, but gives different results
+        # depending on whether batched input is used or not
+        # check: https://github.com/pytorch/fairseq/issues/3227
+        model = Wav2Vec2Model(config=config)
+        model.to(torch_device)
+        model.eval()
+
+        input_values = input_values[:3]
+        attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.bool)
+
+        input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
+
+        # pad input
+        for i in range(len(input_lengths)):
+            input_values[i, input_lengths[i] :] = 0.0
+            attention_mask[i, input_lengths[i] :] = 0.0
+
+        batch_outputs = model(input_values, attention_mask=attention_mask).last_hidden_state
+
+        for i in range(input_values.shape[0]):
+            input_slice = input_values[i : i + 1, : input_lengths[i]]
+            output = model(input_slice).last_hidden_state
+
+            batch_output = batch_outputs[i : i + 1, : output.shape[1]]
+            self.parent.assertTrue(torch.allclose(output, batch_output, atol=1e-3))
+
     def prepare_config_and_inputs_for_common(self):
-        config, input_values = self.prepare_config_and_inputs()
-        inputs_dict = {"input_values": input_values}
+        config, input_values, attention_mask = self.prepare_config_and_inputs()
+        inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
         return config, inputs_dict
 
@@ -222,6 +251,10 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)
 
+    def test_batched_inference(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_batch_inference(*config_and_inputs)
+
     # Wav2Vec2 has no inputs_embeds
     def test_inputs_embeds(self):
         pass
@@ -288,7 +321,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
 
         return ds["speech"][:num_samples]
 
-    def test_inference_masked_lm_normal(self):
+    def test_inference_ctc_normal(self):
         model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
         model.to(torch_device)
         tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
@@ -306,16 +339,16 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"]
         self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
 
-    def test_inference_masked_lm_normal_batched(self):
+    def test_inference_ctc_normal_batched(self):
         model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
         model.to(torch_device)
         tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
 
         input_speech = self._load_datasamples(2)
 
-        input_values = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True).input_values.to(
-            torch_device
-        )
+        inputs = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True)
+
+        input_values = inputs.input_values.to(torch_device)
 
         with torch.no_grad():
             logits = model(input_values).logits
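# --- Illustrative sketch (editor's addition, not part of the patch) ----------
# Self-contained version of the check in create_and_check_batch_inference:
# zero-pad a batch, pass the matching attention_mask, and compare each sample's
# valid slice against an unpadded single-sample forward pass. The tiny config
# values below are arbitrary, chosen only to keep the randomly initialized
# model small; it mirrors the "robust" (layer-norm) setup used by the tester.
import torch
from transformers import Wav2Vec2Config, Wav2Vec2Model

config = Wav2Vec2Config(
    hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64,
    conv_dim=(32, 32), conv_kernel=(8, 4), conv_stride=(4, 2),
    feat_extract_norm="layer", do_stable_layer_norm=True,
)
model = Wav2Vec2Model(config).eval()

lengths = [800, 1200]
input_values = torch.zeros(2, max(lengths))
attention_mask = torch.zeros(2, max(lengths), dtype=torch.long)
for i, length in enumerate(lengths):
    input_values[i, :length] = torch.randn(length)
    attention_mask[i, :length] = 1

with torch.no_grad():
    batch_out = model(input_values, attention_mask=attention_mask).last_hidden_state
    for i, length in enumerate(lengths):
        single_out = model(input_values[i : i + 1, :length]).last_hidden_state
        assert torch.allclose(single_out, batch_out[i : i + 1, : single_out.shape[1]], atol=1e-3)
# ------------------------------------------------------------------------------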
@@ -329,18 +362,19 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         ]
         self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
 
-    def test_inference_masked_lm_robust_batched(self):
+    def test_inference_ctc_robust_batched(self):
         model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
         tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)
 
         input_speech = self._load_datasamples(4)
 
-        input_values = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True).input_values.to(
-            torch_device
-        )
+        inputs = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True)
+
+        input_values = inputs.input_values.to(torch_device)
+        attention_mask = inputs.attention_mask.to(torch_device)
 
         with torch.no_grad():
-            logits = model(input_values).logits
+            logits = model(input_values, attention_mask=attention_mask).logits
 
         predicted_ids = torch.argmax(logits, dim=-1)
         predicted_trans = tokenizer.batch_decode(predicted_ids)
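# --- Illustrative sketch (editor's addition, not part of the patch) ----------
# End-to-end batched CTC decoding with the new attention_mask, mirroring
# test_inference_ctc_robust_batched above. `speech_samples` is assumed to be a
# list of 1-D float arrays sampled at 16 kHz; this lv60 checkpoint's tokenizer
# is configured to return an attention_mask.
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer

model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)

inputs = tokenizer(speech_samples, return_tensors="pt", padding=True)

with torch.no_grad():
    logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits

predicted_ids = torch.argmax(logits, dim=-1)
transcriptions = tokenizer.batch_decode(predicted_ids)
# ------------------------------------------------------------------------------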
diff --git a/tests/test_tokenization_wav2vec2.py b/tests/test_tokenization_wav2vec2.py
index 38ab17e099..9b7c0c33b4 100644
--- a/tests/test_tokenization_wav2vec2.py
+++ b/tests/test_tokenization_wav2vec2.py
@@ -23,7 +23,10 @@ import unittest
 
 import numpy as np
 
-from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES, Wav2Vec2Tokenizer
+from transformers import WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST
+from transformers.models.wav2vec2 import Wav2Vec2Config, Wav2Vec2Tokenizer
+from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
+from transformers.testing_utils import slow
 
 global_rng = random.Random()
 
@@ -299,3 +302,46 @@ class Wav2Vec2TokenizerTest(unittest.TestCase):
         for parameter_name, parameter in signature.parameters.items():
             if parameter.default != inspect.Parameter.empty:
                 self.assertIn(parameter_name, tokenizer.init_kwargs)
+
+    def test_zero_mean_unit_variance_normalization(self):
+        tokenizer = self.get_tokenizer(do_normalize=True)
+        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
+        processed = tokenizer(speech_inputs, padding="longest")
+        input_values = processed.input_values
+
+        def _check_zero_mean_unit_variance(input_vector):
+            self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
+            self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
+
+        _check_zero_mean_unit_variance(input_values[0, :800])
+        _check_zero_mean_unit_variance(input_values[1, :1000])
+        _check_zero_mean_unit_variance(input_values[2])
+
+    def test_return_attention_mask(self):
+        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
+
+        # default case -> no attention_mask is returned
+        tokenizer = self.get_tokenizer()
+        processed = tokenizer(speech_inputs)
+        self.assertNotIn("attention_mask", processed)
+
+        # wav2vec2-lv60 -> return attention_mask
+        tokenizer = self.get_tokenizer(return_attention_mask=True)
+        processed = tokenizer(speech_inputs, padding="longest")
+
+        self.assertIn("attention_mask", processed)
+        self.assertListEqual(list(processed.attention_mask.shape), list(processed.input_values.shape))
+        self.assertListEqual(processed.attention_mask.sum(-1).tolist(), [800, 1000, 1200])
+
+    @slow
+    def test_pretrained_checkpoints_are_set_correctly(self):
+        # this test makes sure that models that are using
+        # group norm don't have their tokenizer return the
+        # attention_mask
+        for model_id in WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST:
+            config = Wav2Vec2Config.from_pretrained(model_id)
+            tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_id)
+
+            # only "layer" feature extraction norm should make use of
+            # attention_mask
+            self.assertEqual(tokenizer.return_attention_mask, config.feat_extract_norm == "layer")
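# --- Illustrative sketch (editor's addition, not part of the patch) ----------
# The convention enforced by the test above, written out as a small helper:
# checkpoints with feat_extract_norm == "layer" (the lv60 family) expect an
# attention_mask, while "group"-norm checkpoints (e.g. wav2vec2-base-960h)
# should be fed zero-padded input_values without one. The helper name is the
# editor's, not part of the library.
from transformers import Wav2Vec2Config


def should_pass_attention_mask(model_id: str) -> bool:
    config = Wav2Vec2Config.from_pretrained(model_id)
    return config.feat_extract_norm == "layer"


# e.g. False for "facebook/wav2vec2-base-960h",
#      True for "facebook/wav2vec2-large-960h-lv60-self"
# ------------------------------------------------------------------------------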