From c8545d2a9c38aa998b874e982b379610829af867 Mon Sep 17 00:00:00 2001 From: bofeng huang Date: Fri, 24 Feb 2023 11:07:52 +0100 Subject: [PATCH] [Whisper] Add SpecAugment (#21298) * Return and rescale attention_mask * Add SpecAugment to Whisper modeling * Fix test * Update docstring * Add SpecAug related parameters to model config * Add the _mask_input_features function to doc * Fix quality * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Remove dev comments * Add test * Resolve conflict * feat: mask {feature, time} prob fast tests * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: sanchit-gandhi Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- docs/source/en/model_doc/whisper.mdx | 1 + .../models/whisper/configuration_whisper.py | 44 +++++ .../whisper/feature_extraction_whisper.py | 5 + .../models/whisper/modeling_whisper.py | 181 +++++++++++++++++- tests/models/whisper/test_modeling_whisper.py | 64 +++++++ 5 files changed, 293 insertions(+), 2 deletions(-) diff --git a/docs/source/en/model_doc/whisper.mdx b/docs/source/en/model_doc/whisper.mdx index 9bd1094ce3..353348730a 100644 --- a/docs/source/en/model_doc/whisper.mdx +++ b/docs/source/en/model_doc/whisper.mdx @@ -72,6 +72,7 @@ The original code can be found [here](https://github.com/openai/whisper). [[autodoc]] WhisperModel - forward + - _mask_input_features ## WhisperForConditionalGeneration diff --git a/src/transformers/models/whisper/configuration_whisper.py b/src/transformers/models/whisper/configuration_whisper.py index 934d54be48..7aceffd5b4 100644 --- a/src/transformers/models/whisper/configuration_whisper.py +++ b/src/transformers/models/whisper/configuration_whisper.py @@ -136,6 +136,35 @@ class WhisperConfig(PretrainedConfig): begin_suppress_tokens (`List[int]`, *optional*, defaults to `[220,50256]`): A list containing tokens that will be supressed at the beginning of the sampling process. Initialized as the token for `" "` (`blank_token_id`) and the `eos_token_id` + apply_spec_augment (`bool`, *optional*, defaults to `False`): + Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see + [SpecAugment: A Simple Data Augmentation Method for Automatic Speech + Recognition](https://arxiv.org/abs/1904.08779). + mask_time_prob (`float`, *optional*, defaults to 0.05): + Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking + procecure generates `mask_time_prob*len(time_axis)/mask_time_length` independent masks over the axis. If + reasoning from the propability of each feature vector to be chosen as the start of the vector span to be + masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the + actual percentage of masked vectors. This is only relevant if `apply_spec_augment == True`. + mask_time_length (`int`, *optional*, defaults to 10): + Length of vector span along the time axis. + mask_time_min_masks (`int`, *optional*, defaults to 2),: + The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step, + irrespectively of `mask_feature_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length < + mask_time_min_masks'' + mask_feature_prob (`float`, *optional*, defaults to 0.0): + Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The + masking procecure generates `mask_feature_prob*len(feature_axis)/mask_time_length` independent masks over + the axis. If reasoning from the propability of each feature vector to be chosen as the start of the vector + span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap + may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is + True`. + mask_feature_length (`int`, *optional*, defaults to 10): + Length of vector span along the feature axis. + mask_feature_min_masks (`int`, *optional*, defaults to 0),: + The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time + step, irrespectively of `mask_feature_prob`. Only relevant if + `mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`. Example: @@ -185,6 +214,13 @@ class WhisperConfig(PretrainedConfig): eos_token_id=50256, suppress_tokens=None, begin_suppress_tokens=[220, 50256], + apply_spec_augment=False, + mask_time_prob=0.05, + mask_time_length=10, + mask_time_min_masks=2, + mask_feature_prob=0.0, + mask_feature_length=10, + mask_feature_min_masks=0, **kwargs, ): self.vocab_size = vocab_size @@ -208,6 +244,14 @@ class WhisperConfig(PretrainedConfig): self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True self.max_source_positions = max_source_positions self.max_target_positions = max_target_positions + # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779 + self.apply_spec_augment = apply_spec_augment + self.mask_time_prob = mask_time_prob + self.mask_time_length = mask_time_length + self.mask_time_min_masks = mask_time_min_masks + self.mask_feature_prob = mask_feature_prob + self.mask_feature_length = mask_feature_length + self.mask_feature_min_masks = mask_feature_min_masks super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, diff --git a/src/transformers/models/whisper/feature_extraction_whisper.py b/src/transformers/models/whisper/feature_extraction_whisper.py index 9526e2815a..f234912028 100644 --- a/src/transformers/models/whisper/feature_extraction_whisper.py +++ b/src/transformers/models/whisper/feature_extraction_whisper.py @@ -307,6 +307,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): max_length=max_length if max_length else self.n_samples, truncation=truncation, pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, ) # make sure list is in array format input_features = padded_inputs.get("input_features").transpose(2, 0, 1) @@ -318,6 +319,10 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): else: padded_inputs["input_features"] = input_features + if return_attention_mask: + # rescale from sample (48000) to feature (3000) + padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length] + if return_tensors is not None: padded_inputs = padded_inputs.convert_to_tensors(return_tensors) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index d2e28972ca..49c801987e 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -19,6 +19,7 @@ import math import random from typing import Optional, Tuple, Union +import numpy as np import torch import torch.utils.checkpoint from torch import nn @@ -97,6 +98,126 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) +# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.LongTensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. + mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. + """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.sum(-1).detach().tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. + if len(spec_aug_mask_idx) == 0: + # this case can only happen if `input_length` is strictly smaller then + # `sequence_length` in which case the last token has to be a padding + # token which we can use as a dummy mask id + dummy_mask_idx = sequence_length - 1 + else: + dummy_mask_idx = spec_aug_mask_idx[0] + + spec_aug_mask_idx = np.concatenate( + [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] + ) + spec_aug_mask_idxs.append(spec_aug_mask_idx) + + spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) + + # expand masked indices to masked spans + spec_aug_mask_idxs = np.broadcast_to( + spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) + ) + spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) + + # add offset to the starting indexes so that indexes now create a span + offsets = np.arange(mask_length)[None, None, :] + offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( + batch_size, max_num_masked_span * mask_length + ) + spec_aug_mask_idxs = spec_aug_mask_idxs + offsets + + # ensure that we cannot have indices larger than sequence_length + if spec_aug_mask_idxs.max() > sequence_length - 1: + spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + + # scatter indices to mask + np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) + + return spec_aug_mask + + class WhisperPositionalEmbedding(nn.Embedding): def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): super().__init__(num_positions, embedding_dim) @@ -503,6 +624,14 @@ WHISPER_INPUTS_DOCSTRING = r""" the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing *SpecAugment* data augmentation on padding token indices. Mask values selected in + `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): Indices of decoder input sequence tokens in the vocabulary. @@ -999,11 +1128,55 @@ class WhisperModel(WhisperPreTrainedModel): """ self.encoder._freeze_parameters() + def _mask_input_features( + self, + input_features: torch.FloatTensor, + attention_mask: Optional[torch.LongTensor] = None, + ): + """ + Masks extracted features along time axis and/or along feature axis according to + [SpecAugment](https://arxiv.org/abs/1904.08779). + """ + + # `config.apply_spec_augment` can set masking to False + if not getattr(self.config, "apply_spec_augment", True): + return input_features + + # generate indices & apply SpecAugment along time axis + batch_size, hidden_size, sequence_length = input_features.size() + + if self.config.mask_time_prob > 0 and self.training: + # generate indices & apply SpecAugment along time axis + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + mask_prob=self.config.mask_time_prob, + mask_length=self.config.mask_time_length, + attention_mask=attention_mask, + min_masks=self.config.mask_time_min_masks, + ) + mask_time_indices = torch.tensor(mask_time_indices, device=input_features.device, dtype=torch.bool) + mask_time_indices = mask_time_indices[:, None].expand(-1, hidden_size, -1) + input_features[mask_time_indices] = 0 + + if self.config.mask_feature_prob > 0 and self.training: + # generate indices & apply SpecAugment along feature axis + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + mask_prob=self.config.mask_feature_prob, + mask_length=self.config.mask_feature_length, + min_masks=self.config.mask_feature_min_masks, + ) + mask_feature_indices = torch.tensor(mask_feature_indices, device=input_features.device, dtype=torch.bool) + input_features[mask_feature_indices] = 0 + + return input_features + @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) def forward( self, - input_features: Optional[torch.LongTensor] = None, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, @@ -1044,6 +1217,8 @@ class WhisperModel(WhisperPreTrainedModel): return_dict = return_dict if return_dict is not None else self.config.use_return_dict if encoder_outputs is None: + input_features = self._mask_input_features(input_features, attention_mask=attention_mask) + encoder_outputs = self.encoder( input_features, head_mask=head_mask, @@ -1139,7 +1314,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) def forward( self, - input_features: Optional[torch.LongTensor] = None, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, @@ -1193,6 +1369,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): outputs = self.model( input_features, + attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask, diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index bf092f9d35..b5d3b2e648 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -383,6 +383,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas expected_arg_names = [ "input_features", + "attention_mask", "decoder_input_ids", "decoder_attention_mask", ] @@ -909,6 +910,34 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas self.assertEqual(fx_keys, pt_keys) self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class) + def test_mask_feature_prob(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.mask_feature_prob = 0.2 + config.mask_feature_length = 2 + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.train() + + # forward pass + encoder_last_hidden_state = model(**input_dict).encoder_last_hidden_state + self.assertTrue(encoder_last_hidden_state.shape, (13, 30, 16)) + + def test_mask_time_prob(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.mask_time_prob = 0.2 + config.mask_time_length = 2 + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.train() + + # forward pass + encoder_last_hidden_state = model(**input_dict).encoder_last_hidden_state + self.assertTrue(encoder_last_hidden_state.shape, (13, 30, 16)) + @require_torch @require_torchaudio @@ -1289,3 +1318,38 @@ class WhisperModelIntegrationTests(unittest.TestCase): transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True) self.assertEqual(transcript, EXPECTED_TRANSCRIPT) + + @slow + def test_tiny_specaugment_librispeech(self): + torch_device = "cpu" + set_seed(0) + # Apply SpecAugment + model = WhisperModel.from_pretrained("openai/whisper-tiny", apply_spec_augment=True) + # Set model to training mode to enable SpecAugment + model.train() + model.to(torch_device) + input_speech = self._load_datasamples(1) + feature_extractor = WhisperFeatureExtractor() + input_features = feature_extractor(input_speech, return_tensors="pt").input_features + + with torch.no_grad(): + logits = model( + input_features, + decoder_input_ids=torch.tensor([[50258, 50259, 50359]]), + output_hidden_states=False, + output_attentions=False, + return_dict=False, + use_cache=False, + ) + + # fmt: off + EXPECTED_LOGITS = torch.tensor( + [ + 0.9362, -4.7105, 5.0879, 3.9642, 1.0013, -6.0096, 4.7285, -3.1847, + -0.8648, 1.9631, 6.2653, 3.6936, 0.3575, -4.5818, 3.0564, 7.8712, + 2.9951, 0.6848, 9.9497, -2.6638, 1.1571, -6.8546, -1.4333, -7.7584, + 1.1200, 3.9030, 4.4655, -4.4919, -1.1703, 9.6241 + ] + ) + # fmt: on + self.assertTrue(torch.allclose(logits[0][0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))