[Whisper] Add SpecAugment (#21298)
* Return and rescale attention_mask * Add SpecAugment to Whisper modeling * Fix test * Update docstring * Add SpecAug related parameters to model config * Add the _mask_input_features function to doc * Fix quality * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Remove dev comments * Add test * Resolve conflict * feat: mask {feature, time} prob fast tests * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: sanchit-gandhi <sanchit@huggingface.co> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
75bd49ff88
commit
c8545d2a9c
|
@ -72,6 +72,7 @@ The original code can be found [here](https://github.com/openai/whisper).
|
|||
|
||||
[[autodoc]] WhisperModel
|
||||
- forward
|
||||
- _mask_input_features
|
||||
|
||||
## WhisperForConditionalGeneration
|
||||
|
||||
|
|
|
@ -136,6 +136,35 @@ class WhisperConfig(PretrainedConfig):
|
|||
begin_suppress_tokens (`List[int]`, *optional*, defaults to `[220,50256]`):
|
||||
A list containing tokens that will be supressed at the beginning of the sampling process. Initialized as
|
||||
the token for `" "` (`blank_token_id`) and the `eos_token_id`
|
||||
apply_spec_augment (`bool`, *optional*, defaults to `False`):
|
||||
Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
|
||||
[SpecAugment: A Simple Data Augmentation Method for Automatic Speech
|
||||
Recognition](https://arxiv.org/abs/1904.08779).
|
||||
mask_time_prob (`float`, *optional*, defaults to 0.05):
|
||||
Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
|
||||
procecure generates `mask_time_prob*len(time_axis)/mask_time_length` independent masks over the axis. If
|
||||
reasoning from the propability of each feature vector to be chosen as the start of the vector span to be
|
||||
masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
|
||||
actual percentage of masked vectors. This is only relevant if `apply_spec_augment == True`.
|
||||
mask_time_length (`int`, *optional*, defaults to 10):
|
||||
Length of vector span along the time axis.
|
||||
mask_time_min_masks (`int`, *optional*, defaults to 2),:
|
||||
The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step,
|
||||
irrespectively of `mask_feature_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length <
|
||||
mask_time_min_masks''
|
||||
mask_feature_prob (`float`, *optional*, defaults to 0.0):
|
||||
Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
|
||||
masking procecure generates `mask_feature_prob*len(feature_axis)/mask_time_length` independent masks over
|
||||
the axis. If reasoning from the propability of each feature vector to be chosen as the start of the vector
|
||||
span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
|
||||
may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
|
||||
True`.
|
||||
mask_feature_length (`int`, *optional*, defaults to 10):
|
||||
Length of vector span along the feature axis.
|
||||
mask_feature_min_masks (`int`, *optional*, defaults to 0),:
|
||||
The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
|
||||
step, irrespectively of `mask_feature_prob`. Only relevant if
|
||||
`mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`.
|
||||
|
||||
|
||||
Example:
|
||||
|
@ -185,6 +214,13 @@ class WhisperConfig(PretrainedConfig):
|
|||
eos_token_id=50256,
|
||||
suppress_tokens=None,
|
||||
begin_suppress_tokens=[220, 50256],
|
||||
apply_spec_augment=False,
|
||||
mask_time_prob=0.05,
|
||||
mask_time_length=10,
|
||||
mask_time_min_masks=2,
|
||||
mask_feature_prob=0.0,
|
||||
mask_feature_length=10,
|
||||
mask_feature_min_masks=0,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
|
@ -208,6 +244,14 @@ class WhisperConfig(PretrainedConfig):
|
|||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||
self.max_source_positions = max_source_positions
|
||||
self.max_target_positions = max_target_positions
|
||||
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
|
||||
self.apply_spec_augment = apply_spec_augment
|
||||
self.mask_time_prob = mask_time_prob
|
||||
self.mask_time_length = mask_time_length
|
||||
self.mask_time_min_masks = mask_time_min_masks
|
||||
self.mask_feature_prob = mask_feature_prob
|
||||
self.mask_feature_length = mask_feature_length
|
||||
self.mask_feature_min_masks = mask_feature_min_masks
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
|
|
|
@ -307,6 +307,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
|||
max_length=max_length if max_length else self.n_samples,
|
||||
truncation=truncation,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_attention_mask=return_attention_mask,
|
||||
)
|
||||
# make sure list is in array format
|
||||
input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
|
||||
|
@ -318,6 +319,10 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
|||
else:
|
||||
padded_inputs["input_features"] = input_features
|
||||
|
||||
if return_attention_mask:
|
||||
# rescale from sample (48000) to feature (3000)
|
||||
padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]
|
||||
|
||||
if return_tensors is not None:
|
||||
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ import math
|
|||
import random
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
@ -97,6 +98,126 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||
|
||||
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
|
||||
def _compute_mask_indices(
|
||||
shape: Tuple[int, int],
|
||||
mask_prob: float,
|
||||
mask_length: int,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
min_masks: int = 0,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
|
||||
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
|
||||
CPU as part of the preprocessing during training.
|
||||
|
||||
Args:
|
||||
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
|
||||
the first element is the batch size and the second element is the length of the axis to span.
|
||||
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
|
||||
independently generated mask spans of length `mask_length` is computed by
|
||||
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
|
||||
actual percentage will be smaller.
|
||||
mask_length: size of the mask
|
||||
min_masks: minimum number of masked spans
|
||||
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
|
||||
each batch dimension.
|
||||
"""
|
||||
batch_size, sequence_length = shape
|
||||
|
||||
if mask_length < 1:
|
||||
raise ValueError("`mask_length` has to be bigger than 0.")
|
||||
|
||||
if mask_length > sequence_length:
|
||||
raise ValueError(
|
||||
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
|
||||
f" and `sequence_length`: {sequence_length}`"
|
||||
)
|
||||
|
||||
# epsilon is used for probabilistic rounding
|
||||
epsilon = np.random.rand(1).item()
|
||||
|
||||
def compute_num_masked_span(input_length):
|
||||
"""Given input length, compute how many spans should be masked"""
|
||||
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
|
||||
num_masked_span = max(num_masked_span, min_masks)
|
||||
|
||||
# make sure num masked span <= sequence_length
|
||||
if num_masked_span * mask_length > sequence_length:
|
||||
num_masked_span = sequence_length // mask_length
|
||||
|
||||
# make sure num_masked span is also <= input_length - (mask_length - 1)
|
||||
if input_length - (mask_length - 1) < num_masked_span:
|
||||
num_masked_span = max(input_length - (mask_length - 1), 0)
|
||||
|
||||
return num_masked_span
|
||||
|
||||
# compute number of masked spans in batch
|
||||
input_lengths = (
|
||||
attention_mask.sum(-1).detach().tolist()
|
||||
if attention_mask is not None
|
||||
else [sequence_length for _ in range(batch_size)]
|
||||
)
|
||||
|
||||
# SpecAugment mask to fill
|
||||
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
|
||||
spec_aug_mask_idxs = []
|
||||
|
||||
max_num_masked_span = compute_num_masked_span(sequence_length)
|
||||
|
||||
if max_num_masked_span == 0:
|
||||
return spec_aug_mask
|
||||
|
||||
for input_length in input_lengths:
|
||||
# compute num of masked spans for this input
|
||||
num_masked_span = compute_num_masked_span(input_length)
|
||||
|
||||
# get random indices to mask
|
||||
spec_aug_mask_idx = np.random.choice(
|
||||
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
|
||||
)
|
||||
|
||||
# pick first sampled index that will serve as a dummy index to pad vector
|
||||
# to ensure same dimension for all batches due to probabilistic rounding
|
||||
# Picking first sample just pads those vectors twice.
|
||||
if len(spec_aug_mask_idx) == 0:
|
||||
# this case can only happen if `input_length` is strictly smaller then
|
||||
# `sequence_length` in which case the last token has to be a padding
|
||||
# token which we can use as a dummy mask id
|
||||
dummy_mask_idx = sequence_length - 1
|
||||
else:
|
||||
dummy_mask_idx = spec_aug_mask_idx[0]
|
||||
|
||||
spec_aug_mask_idx = np.concatenate(
|
||||
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
|
||||
)
|
||||
spec_aug_mask_idxs.append(spec_aug_mask_idx)
|
||||
|
||||
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
|
||||
|
||||
# expand masked indices to masked spans
|
||||
spec_aug_mask_idxs = np.broadcast_to(
|
||||
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
|
||||
)
|
||||
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
|
||||
|
||||
# add offset to the starting indexes so that indexes now create a span
|
||||
offsets = np.arange(mask_length)[None, None, :]
|
||||
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
|
||||
batch_size, max_num_masked_span * mask_length
|
||||
)
|
||||
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
|
||||
|
||||
# ensure that we cannot have indices larger than sequence_length
|
||||
if spec_aug_mask_idxs.max() > sequence_length - 1:
|
||||
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
|
||||
|
||||
# scatter indices to mask
|
||||
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
|
||||
|
||||
return spec_aug_mask
|
||||
|
||||
|
||||
class WhisperPositionalEmbedding(nn.Embedding):
|
||||
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
||||
super().__init__(num_positions, embedding_dim)
|
||||
|
@ -503,6 +624,14 @@ WHISPER_INPUTS_DOCSTRING = r"""
|
|||
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
|
||||
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
|
||||
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
|
||||
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing *SpecAugment* data augmentation on padding token indices. Mask values selected in
|
||||
`[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
Indices of decoder input sequence tokens in the vocabulary.
|
||||
|
||||
|
@ -999,11 +1128,55 @@ class WhisperModel(WhisperPreTrainedModel):
|
|||
"""
|
||||
self.encoder._freeze_parameters()
|
||||
|
||||
def _mask_input_features(
|
||||
self,
|
||||
input_features: torch.FloatTensor,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
"""
|
||||
Masks extracted features along time axis and/or along feature axis according to
|
||||
[SpecAugment](https://arxiv.org/abs/1904.08779).
|
||||
"""
|
||||
|
||||
# `config.apply_spec_augment` can set masking to False
|
||||
if not getattr(self.config, "apply_spec_augment", True):
|
||||
return input_features
|
||||
|
||||
# generate indices & apply SpecAugment along time axis
|
||||
batch_size, hidden_size, sequence_length = input_features.size()
|
||||
|
||||
if self.config.mask_time_prob > 0 and self.training:
|
||||
# generate indices & apply SpecAugment along time axis
|
||||
mask_time_indices = _compute_mask_indices(
|
||||
(batch_size, sequence_length),
|
||||
mask_prob=self.config.mask_time_prob,
|
||||
mask_length=self.config.mask_time_length,
|
||||
attention_mask=attention_mask,
|
||||
min_masks=self.config.mask_time_min_masks,
|
||||
)
|
||||
mask_time_indices = torch.tensor(mask_time_indices, device=input_features.device, dtype=torch.bool)
|
||||
mask_time_indices = mask_time_indices[:, None].expand(-1, hidden_size, -1)
|
||||
input_features[mask_time_indices] = 0
|
||||
|
||||
if self.config.mask_feature_prob > 0 and self.training:
|
||||
# generate indices & apply SpecAugment along feature axis
|
||||
mask_feature_indices = _compute_mask_indices(
|
||||
(batch_size, hidden_size),
|
||||
mask_prob=self.config.mask_feature_prob,
|
||||
mask_length=self.config.mask_feature_length,
|
||||
min_masks=self.config.mask_feature_min_masks,
|
||||
)
|
||||
mask_feature_indices = torch.tensor(mask_feature_indices, device=input_features.device, dtype=torch.bool)
|
||||
input_features[mask_feature_indices] = 0
|
||||
|
||||
return input_features
|
||||
|
||||
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_features: Optional[torch.LongTensor] = None,
|
||||
input_features: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
|
@ -1044,6 +1217,8 @@ class WhisperModel(WhisperPreTrainedModel):
|
|||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if encoder_outputs is None:
|
||||
input_features = self._mask_input_features(input_features, attention_mask=attention_mask)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
input_features,
|
||||
head_mask=head_mask,
|
||||
|
@ -1139,7 +1314,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_features: Optional[torch.LongTensor] = None,
|
||||
input_features: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
|
@ -1193,6 +1369,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||
|
||||
outputs = self.model(
|
||||
input_features,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
encoder_outputs=encoder_outputs,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
|
|
|
@ -383,6 +383,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
|||
|
||||
expected_arg_names = [
|
||||
"input_features",
|
||||
"attention_mask",
|
||||
"decoder_input_ids",
|
||||
"decoder_attention_mask",
|
||||
]
|
||||
|
@ -909,6 +910,34 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
|||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
|
||||
|
||||
def test_mask_feature_prob(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.mask_feature_prob = 0.2
|
||||
config.mask_feature_length = 2
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
|
||||
# forward pass
|
||||
encoder_last_hidden_state = model(**input_dict).encoder_last_hidden_state
|
||||
self.assertTrue(encoder_last_hidden_state.shape, (13, 30, 16))
|
||||
|
||||
def test_mask_time_prob(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.mask_time_prob = 0.2
|
||||
config.mask_time_length = 2
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
|
||||
# forward pass
|
||||
encoder_last_hidden_state = model(**input_dict).encoder_last_hidden_state
|
||||
self.assertTrue(encoder_last_hidden_state.shape, (13, 30, 16))
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
|
@ -1289,3 +1318,38 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
@slow
|
||||
def test_tiny_specaugment_librispeech(self):
|
||||
torch_device = "cpu"
|
||||
set_seed(0)
|
||||
# Apply SpecAugment
|
||||
model = WhisperModel.from_pretrained("openai/whisper-tiny", apply_spec_augment=True)
|
||||
# Set model to training mode to enable SpecAugment
|
||||
model.train()
|
||||
model.to(torch_device)
|
||||
input_speech = self._load_datasamples(1)
|
||||
feature_extractor = WhisperFeatureExtractor()
|
||||
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(
|
||||
input_features,
|
||||
decoder_input_ids=torch.tensor([[50258, 50259, 50359]]),
|
||||
output_hidden_states=False,
|
||||
output_attentions=False,
|
||||
return_dict=False,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = torch.tensor(
|
||||
[
|
||||
0.9362, -4.7105, 5.0879, 3.9642, 1.0013, -6.0096, 4.7285, -3.1847,
|
||||
-0.8648, 1.9631, 6.2653, 3.6936, 0.3575, -4.5818, 3.0564, 7.8712,
|
||||
2.9951, 0.6848, 9.9497, -2.6638, 1.1571, -6.8546, -1.4333, -7.7584,
|
||||
1.1200, 3.9030, 4.4655, -4.4919, -1.1703, 9.6241
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(logits[0][0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))
|
||||
|
|
Loading…
Reference in New Issue