diff --git a/docs/source/index.rst b/docs/source/index.rst index 4f466878c4..9c0db9f120 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -411,7 +411,7 @@ Flax), PyTorch, and/or TensorFlow. +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ -| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ❌ | +| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | XLM | ✅ | ❌ | ✅ | ✅ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ diff --git a/docs/source/model_doc/wav2vec2.rst b/docs/source/model_doc/wav2vec2.rst index dd3af77b52..df92a06386 100644 --- a/docs/source/model_doc/wav2vec2.rst +++ b/docs/source/model_doc/wav2vec2.rst @@ -99,3 +99,23 @@ TFWav2Vec2ForCTC .. autoclass:: transformers.TFWav2Vec2ForCTC :members: call + + +FlaxWav2Vec2Model +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxWav2Vec2Model + :members: __call__ + + +FlaxWav2Vec2ForCTC +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxWav2Vec2ForCTC + :members: __call__ + +FlaxWav2Vec2ForPreTraining +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxWav2Vec2ForPreTraining + :members: __call__ diff --git a/examples/research_projects/jax-projects/wav2vec2/README.md b/examples/research_projects/jax-projects/wav2vec2/README.md new file mode 100644 index 0000000000..f41e605585 --- /dev/null +++ b/examples/research_projects/jax-projects/wav2vec2/README.md @@ -0,0 +1,119 @@ +# Wav2Vec2 Contrastive Loss PreTraining examples + +The following example showcases how to pretrain a wav2vec2 model using the JAX/Flax backend. +Pretraining Wav2Vec2 is rather complex, so it is highly recommended to read the +[official paper](https://arxiv.org/abs/2006.11477). + +JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU. +Models written in JAX/Flax are **immutable** and updated in a purely functional +way which enables simple and efficient model parallelism. + +`run_wav2vec2_pretrain_flax.py` is a lightweight example of how to download and preprocess a dataset from the 🤗 Datasets library or use your own files (jsonlines or csv), then pretrain the wav2vec2 architectures above on it. + +For custom datasets in `jsonlines` format please see: [the Datasets documentation](https://huggingface.co/docs/datasets/loading_datasets.html#json-files) and you also will find examples of these below. + +Let's start by creating a model repository to save the trained model and logs. +Here we call the model `"wav2vec2-base-robust"`, but you can change the model name as you like. + +You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that +you are logged in) or via the command line: + +``` +huggingface-cli repo create wav2vec2-base-robust +``` + +Next we clone the model repository to add the tokenizer and model files. 
+
+```
+git clone https://huggingface.co/<your-username>/wav2vec2-base-robust
+```
+
+To ensure that all tensorboard traces will be uploaded correctly, we need to
+track them with git-lfs. You can run the following command inside your model repo to do so.
+
+```
+cd wav2vec2-base-robust
+git lfs track "*tfevents*"
+```
+
+Great, we have set up our model repository. During training, we will automatically
+push the training logs and model weights to the repo.
+
+Next, let's add a symbolic link to the `run_wav2vec2_pretrain_flax.py` script.
+
+```bash
+export MODEL_DIR="./wav2vec2-base-robust"
+ln -s ~/transformers/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py ./
+```
+
+### Create the model configuration
+
+Let's first create the model configuration and store it in the model repository.
+Note that many training parameters can be set in the model configuration, including
+the configuration of the masking distribution (`mask_time_length`, `mask_time_prob`),
+dropout (`attention_dropout`, ...), the trade-off between the contrastive loss and
+the diversity loss, etc...
+Most likely you will need to change these parameters depending on your use case.
+Again, we highly recommend reading the [official paper](https://arxiv.org/abs/2006.11477)
+to better understand which parameters can be set for pretraining.
+
+For this example, we will be using a `"base"`-sized model of Wav2Vec2 with robust
+layer norm and keep most of the default settings.
+
+```python
+model_dir="./wav2vec2-base-robust"
+
+from transformers import Wav2Vec2Config
+config = Wav2Vec2Config.from_pretrained(
+    "facebook/wav2vec2-base",
+    mask_time_length=10,
+    mask_time_prob=0.05,
+    diversity_loss_weight=0.1,
+    num_negatives=100,
+    do_stable_layer_norm=True,
+    feat_extract_norm="layer",
+)
+config.save_pretrained(model_dir)
+```
+
+### Create a feature extractor configuration
+
+Before we can start the training, we need to define
+a feature extractor that takes care of normalization, etc...
+
+Here we can also re-use the feature extractor of [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) while making sure that an attention mask is returned for padded inputs.
+
+
+```python
+model_dir="./wav2vec2-base-robust"
+
+from transformers import Wav2Vec2FeatureExtractor
+feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base", return_attention_mask=True)
+feature_extractor.save_pretrained(model_dir)
+```
+
+### Train the model
+Finally, we can run the example script to train the model:
+
+```bash
+./run_wav2vec2_pretrain_flax.py \
+    --output_dir=${MODEL_DIR} \
+    --num_train_epochs="5" \
+    --per_device_train_batch_size="32" \
+    --per_device_eval_batch_size="32" \
+    --learning_rate="5e-4" \
+    --weight_decay="0.01" \
+    --warmup_steps="2000" \
+    --model_name_or_path=${MODEL_DIR} \
+    --dataset_name="librispeech_asr" \
+    --dataset_config_name="clean" \
+    --train_split_name="train.100" \
+    --preprocessing_num_workers="4" \
+    --max_duration_in_seconds="10.0" \
+    --adam_beta1="0.9" \
+    --adam_beta2="0.98" \
+    --push_to_hub
+```
+
+Note that this script has not been fully tested yet, so we cannot guarantee that
+the above command leads to satisfying results.
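+
+### Check the pretrained model
+
+Once a checkpoint has been saved, you can sanity-check the pretrained weights by loading them back
+into the bare Flax backbone. The snippet below is only a minimal sketch (it assumes the checkpoint was
+saved to `./wav2vec2-base-robust` as above); loading weights trained with `FlaxWav2Vec2ForPreTraining`
+into `FlaxWav2Vec2Model` will warn that the quantizer and projection heads are unused, which is expected.
+
+```python
+import numpy as np
+
+from transformers import FlaxWav2Vec2Model, Wav2Vec2FeatureExtractor
+
+model_dir = "./wav2vec2-base-robust"  # local checkpoint directory from the run above
+
+# load the feature extractor and the pretrained backbone
+feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_dir)
+model = FlaxWav2Vec2Model.from_pretrained(model_dir)
+
+# run a forward pass on one second of silence sampled at 16 kHz
+dummy_speech = np.zeros(16000, dtype=np.float32)
+inputs = feature_extractor(dummy_speech, sampling_rate=16000, return_tensors="np")
+
+outputs = model(inputs["input_values"])
+print(outputs.last_hidden_state.shape)  # (1, num_frames, hidden_size)
+```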
diff --git a/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py b/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py
new file mode 100755
index 0000000000..e4bad89256
--- /dev/null
+++ b/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py
@@ -0,0 +1,566 @@
+#!/usr/bin/env python3
+import logging
+import sys
+import time
+from dataclasses import field
+from pathlib import Path
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+from datasets import DatasetDict, load_dataset
+from tqdm import tqdm
+
+import flax
+import jax
+import jax.numpy as jnp
+import librosa
+import optax
+from flax import jax_utils, traverse_util
+from flax.training import train_state
+from flax.training.common_utils import get_metrics, onehot, shard
+from transformers import (
+    FlaxWav2Vec2ForPreTraining,
+    HfArgumentParser,
+    TrainingArguments,
+    Wav2Vec2Config,
+    Wav2Vec2FeatureExtractor,
+    is_tensorboard_available,
+)
+from transformers.models.wav2vec2.modeling_flax_wav2vec2 import _compute_mask_indices, _sample_negative_indices
+
+
+logger = logging.getLogger(__name__)
+
+
+@flax.struct.dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
+    """
+
+    model_name_or_path: str = field(
+        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
+    )
+    cache_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
+    )
+    freeze_feature_extractor: Optional[bool] = field(
+        default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
+    )
+    gradient_checkpointing: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to use gradient checkpointing to save memory at the expense of a slower backward pass."},
+    )
+    verbose_logging: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to log verbose messages or not."},
+    )
+    max_gumbel_temperature: Optional[float] = field(
+        default=2.0, metadata={"help": "Maximum temperature for gumbel softmax."}
+    )
+    min_gumbel_temperature: Optional[float] = field(
+        default=0.1, metadata={"help": "Minimum temperature for gumbel softmax."}
+    )
+    gumbel_temperature_decay: Optional[float] = field(
+        default=0.999995, metadata={"help": "Decay of gumbel temperature during training."}
+    )
+    dtype: Optional[str] = field(
+        default="float32",
+        metadata={"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."},
+    )
+
+
+@flax.struct.dataclass
+class DataTrainingArguments:
+    """
+    Arguments pertaining to what data we are going to input our model for training and eval.
+
+    Using `HfArgumentParser` we can turn this class
+    into argparse arguments to be able to specify them on
+    the command line.
+    """
+
+    dataset_name: str = field(
+        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
+    )
+    dataset_config_name: Optional[str] = field(
+        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+    )
+    validation_split_percentage: Optional[int] = field(
+        default=5,
+        metadata={"help": "The percentage of the train set used as validation set in case there's no validation split"},
+    )
+    train_split_name: Optional[str] = field(
+        default="train",
+        metadata={
+            "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
+        },
+    )
+    validation_split_name: Optional[str] = field(
+        default="validation",
+        metadata={
+            "help": "The name of the validation data set split to use (via the datasets library).
Defaults to 'validation'" + }, + ) + speech_file_column: Optional[str] = field( + default="file", + metadata={"help": "Column in the dataset that contains speech file path. Defaults to 'file'"}, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_duration_in_seconds: Optional[float] = field( + default=20.0, metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"} + ) + pad_to_multiple_of: Optional[int] = field( + default=1024, + metadata={ + "help": "If set will pad the sequence to a multiple of the provided value. This is important to avoid triggering recompilations on TPU" + }, + ) + + +@flax.struct.dataclass +class FlaxDataCollatorForWav2Vec2Pretraining: + """ + Data collator that will dynamically pad the inputs received and prepare masked indices + for self-supervised pretraining. + + Args: + model (:class:`~transformers.FlaxWav2Vec2ForPreTraining`): + The Wav2Vec2 model used for pretraining. The data collator needs to have access + to config and ``_get_feat_extract_output_lengths`` function for correct padding. + feature_extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`): + The processor used for proccessing the data. + padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding index) + among: + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + max_length (:obj:`int`, `optional`): + Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). + pad_to_multiple_of (:obj:`int`, `optional`): + If set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). 
+ """ + + model: FlaxWav2Vec2ForPreTraining + feature_extractor: Wav2Vec2FeatureExtractor + padding: Union[bool, str] = "longest" + pad_to_multiple_of: Optional[int] = None + max_length: Optional[int] = None + + def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]: + # reformat list to dict and set to pytorch format + batch = self.feature_extractor.pad( + features, + max_length=self.max_length, + padding=self.padding, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors="np", + ) + mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1]) + + # sample randomly masked indices + batch["mask_time_indices"] = _compute_mask_indices( + (batch["input_values"].shape[0], mask_indices_seq_length), + self.model.config.mask_time_prob, + self.model.config.mask_time_length, + min_masks=2, + ) + + # sample indices to take for negative vectors + batch["sampled_negative_indices"] = _sample_negative_indices( + (batch["mask_time_indices"].shape + (self.model.config.proj_codevector_dim,)), + self.model.config.num_negatives, + ) + + return batch + + +def configure_logger(model_args: ModelArguments, training_args: TrainingArguments): + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + logging_level = logging.WARNING + if model_args.verbose_logging: + logging_level = logging.DEBUG + logger.setLevel(logging_level) + + +def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): + summary_writer.scalar("train_time", train_time, step) + + train_metrics = get_metrics(train_metrics) + for key, vals in train_metrics.items(): + tag = f"train_{key}" + for i, val in enumerate(vals): + summary_writer.scalar(tag, val, step - len(vals) + i + 1) + + for metric_name, value in eval_metrics.items(): + summary_writer.scalar(f"eval_{metric_name}", value, step) + + +def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray: + num_samples = len(samples_idx) + samples_to_remove = num_samples % batch_size + + if samples_to_remove != 0: + samples_idx = samples_idx[:-samples_to_remove] + sections_split = num_samples // batch_size + batch_idx = np.split(samples_idx, sections_split) + return batch_idx + + +def compute_contrastive_loss( + quantized_features, transformer_features, negative_indices, mask_time_indices, logits_temp, num_negatives +): + batch_size, sequence_length, hidden_size = quantized_features.shape + + # take negative vectors from sampled indices + quantized_negatives = quantized_features.reshape(-1, hidden_size)[negative_indices.reshape(-1)] + quantized_negatives = quantized_negatives.reshape( + batch_size, sequence_length, num_negatives, hidden_size + ).transpose(2, 0, 1, 3) + + target_features = jnp.concatenate([quantized_features[None, :], quantized_negatives], axis=0) + loss_logits = optax.cosine_similarity(transformer_features, target_features) + loss_logits = loss_logits / logits_temp + + neg_is_pos = (quantized_features == quantized_negatives).all(-1) + neg_is_pos = jnp.concatenate([jnp.full((1,) + loss_logits.shape[1:], False), neg_is_pos], axis=0) + + # make sure incorrectly sampled vectors don't contribute to loss + loss_logits = jnp.where(neg_is_pos, -1e9, loss_logits) + + predictions = loss_logits.transpose(2, 1, 0).reshape(-1, loss_logits.shape[0]) + targets = ((1 - mask_time_indices) * -100).transpose(1, 0).flatten() + + target_mask = 
jnp.where(targets >= 0, 1.0, 0.0) + contrastive_loss = optax.softmax_cross_entropy(predictions, onehot(targets, predictions.shape[-1])) * target_mask + + contrastive_loss = contrastive_loss.sum() + + return contrastive_loss + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + configure_logger(model_args, training_args) + + # Downloading and loading a dataset from the hub. + datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir) + + if "validation" not in datasets.keys(): + # make sure only "validation" and "train" keys remain" + datasets = DatasetDict() + datasets["validation"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=f"{data_args.train_split_name}[:{data_args.validation_split_percentage}%]", + cache_dir=model_args.cache_dir, + ) + datasets["train"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=f"{data_args.train_split_name}[{data_args.validation_split_percentage}%:]", + cache_dir=model_args.cache_dir, + ) + else: + # make sure only "validation" and "train" keys remain" + datasets = DatasetDict() + datasets["validation"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split="validation", + cache_dir=model_args.cache_dir, + ) + datasets["train"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=f"{data_args.train_split_name}", + cache_dir=model_args.cache_dir, + ) + + # only normalized-inputs-training is supported + feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( + model_args.model_name_or_path, cache_dir=model_args.cache_dir, do_normalize=True + ) + + def prepare_dataset(batch): + # check that all files have the correct sampling rate + batch["speech"], _ = librosa.load(batch[data_args.speech_file_column], sr=feature_extractor.sampling_rate) + return batch + + # load audio files into numpy arrays + vectorized_datasets = datasets.map( + prepare_dataset, num_proc=data_args.preprocessing_num_workers, remove_columns=datasets["train"].column_names + ) + + # filter audio files that are too long + vectorized_datasets = vectorized_datasets.filter( + lambda data: len(data["speech"]) < int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate) + ) + + def normalize(batch): + return feature_extractor(batch["speech"], sampling_rate=feature_extractor.sampling_rate) + + # normalize and transform to `BatchFeatures` + vectorized_datasets = vectorized_datasets.map( + normalize, + batched=True, + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=not data_args.overwrite_cache, + remove_columns=vectorized_datasets["train"].column_names, + ) + + # pretraining is only supported for "newer" stable layer norm architecture + # apply_spec_augment has to be True, mask_feature_prob has to be 0.0 + config = Wav2Vec2Config.from_pretrained( + model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + gradient_checkpointing=model_args.gradient_checkpointing, + ) + + if not config.do_stable_layer_norm or config.feat_extract_norm != "layer": + raise ValueError( + "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and 
``config.feat_extract_norm='layer'``"
+        )
+
+    model = FlaxWav2Vec2ForPreTraining(
+        config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+    )
+
+    data_collator = FlaxDataCollatorForWav2Vec2Pretraining(
+        model=model, feature_extractor=feature_extractor, pad_to_multiple_of=data_args.pad_to_multiple_of
+    )
+
+    # Enable tensorboard only on the master node
+    has_tensorboard = is_tensorboard_available()
+    if has_tensorboard and jax.process_index() == 0:
+        try:
+            from flax.metrics.tensorboard import SummaryWriter
+
+            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+        except ImportError as ie:
+            has_tensorboard = False
+            logger.warning(
+                f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
+            )
+    else:
+        logger.warning(
+            "Unable to display metrics through TensorBoard because the package is not installed: "
+            "Please run pip install tensorboard to enable."
+        )
+
+    # Initialize our training
+    rng = jax.random.PRNGKey(training_args.seed)
+    dropout_rngs = jax.random.split(rng, jax.local_device_count())
+    gumbel_rngs = jax.random.split(rng, jax.local_device_count())
+
+    num_epochs = int(training_args.num_train_epochs)
+    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
+    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+
+    num_train_steps = len(vectorized_datasets["train"]) // train_batch_size * num_epochs
+
+    # Create learning rate schedule
+    warmup_fn = optax.linear_schedule(
+        init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
+    )
+    decay_fn = optax.linear_schedule(
+        init_value=training_args.learning_rate,
+        end_value=0,
+        transition_steps=num_train_steps - training_args.warmup_steps,
+    )
+    linear_decay_lr_schedule_fn = optax.join_schedules(
+        schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
+    )
+
+    # We use Optax's "masking" functionality to not apply weight decay
+    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
+    # mask boolean with the same structure as the parameters.
+    # The mask is True for parameters that should be decayed.
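+    # For example, a kernel parameter reached via a path like
+    # ("wav2vec2", "encoder", "layers", "0", "attention", "q_proj", "kernel") gets weight decay,
+    # while any path ending in "bias", ("layer_norm", "scale") or ("final_layer_norm", "scale")
+    # is excluded (the exact path names above are illustrative).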
+ def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + flat_mask = { + path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")]) + for path in flat_params + } + return traverse_util.unflatten_dict(flat_mask) + + # create adam optimizer + adamw = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=training_args.adam_beta1, + b2=training_args.adam_beta2, + eps=training_args.adam_epsilon, + weight_decay=training_args.weight_decay, + mask=decay_mask_fn, + ) + + # Setup train state and define training hyper-parameters + state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw) + num_negatives = model.config.num_negatives + contrastive_logits_temperature = model.config.contrastive_logits_temperature + num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups + diversity_loss_weight = model.config.diversity_loss_weight + + # Define gradient update step fn + def train_step(state, batch, dropout_rng, gumbel_rng): + dropout_rng, new_dropout_rng = jax.random.split(dropout_rng) + gumbel_rng, new_gumbel_rng = jax.random.split(gumbel_rng) + + def loss_fn(params): + negative_indices = batch.pop("sampled_negative_indices") + + gumbel_temperature = jnp.clip( + model_args.max_gumbel_temperature * model_args.gumbel_temperature_decay ** state.step, + a_min=model_args.min_gumbel_temperature, + ) + + outputs = state.apply_fn( + **batch, + gumbel_temperature=gumbel_temperature, + params=params, + dropout_rng=dropout_rng, + gumbel_rng=gumbel_rng, + train=True, + ) + + contrastive_loss = compute_contrastive_loss( + outputs.projected_quantized_states, + outputs.projected_states, + negative_indices, + batch["mask_time_indices"], + contrastive_logits_temperature, + num_negatives, + ) + + diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors + loss = contrastive_loss + diversity_loss_weight * diversity_loss + + return loss + + grad_fn = jax.value_and_grad(loss_fn) + loss, grad = grad_fn(state.params) + grad = jax.lax.pmean(grad, "batch") + new_state = state.apply_gradients(grads=grad) + + metrics = jax.lax.pmean( + {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch" + ) + + return new_state, metrics, new_dropout_rng, new_gumbel_rng + + # Create parallel version of the train step + p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,)) + + # Define eval fn + def eval_step(params, batch): + negative_indices = batch.pop("sampled_negative_indices") + + outputs = model(**batch, params=params, train=False) + + contrastive_loss = compute_contrastive_loss( + outputs.projected_quantized_states, + outputs.projected_states, + negative_indices, + batch["mask_time_indices"], + contrastive_logits_temperature, + num_negatives, + ) + + diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors + loss = contrastive_loss + diversity_loss_weight * diversity_loss + + # summarize metrics + metrics = {"loss": loss.mean(), "codevector_perplexity": outputs.codevector_perplexity} + metrics = jax.lax.pmean(metrics, axis_name="batch") + + return metrics + + p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,)) + + # Replicate the train state on each device + state = jax_utils.replicate(state) + + train_time = 0 + epochs = tqdm(range(num_epochs), desc=f"Epoch ... 
(1/{num_epochs})", position=0) + for epoch in epochs: + # ======================== Training ================================ + train_start = time.time() + train_metrics = [] + + # Create sampling rng + rng, input_rng = jax.random.split(rng) + + # Generate an epoch by shuffling sampling indices from the train dataset + num_train_samples = len(vectorized_datasets["train"]) + train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples)) + train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size) + + # Gather the indexes for creating the batch and do a training step + for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)): + samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx] + model_inputs = data_collator(samples) + model_inputs = shard(model_inputs.data) + + # Model forward + state, train_metric, dropout_rngs, gumbel_rngs = p_train_step( + state, model_inputs, dropout_rngs, gumbel_rngs + ) + train_metrics.append(train_metric) + + train_time += time.time() - train_start + + epochs.write( + f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})" + ) + + # ======================== Evaluating ============================== + num_eval_samples = len(vectorized_datasets["validation"]) + eval_samples_idx = jnp.arange(num_eval_samples) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) + + eval_metrics = [] + for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): + samples = [vectorized_datasets["validation"][int(idx)] for idx in batch_idx] + model_inputs = data_collator(samples) + + # Model forward + model_inputs = shard(model_inputs.data) + metrics = p_eval_step(state.params, model_inputs) + eval_metrics.append(metrics) + + # get eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_map(jnp.mean, eval_metrics) + + # Update progress bar + epochs.write( + f"Epoch... 
({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Perplexity: {eval_metrics['codevector_perplexity']})" + ) + + # Save metrics + if has_tensorboard and jax.process_index() == 0: + cur_step = epoch * (len(vectorized_datasets["train"]) // train_batch_size) + write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step) + + # save checkpoint after each epoch and push checkpoint to the hub + if jax.process_index() == 0: + params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) + model.save_pretrained(training_args.output_dir, params=params, push_to_hub=training_args.push_to_hub) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 8b339622a7..2f2a35c75d 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1643,6 +1643,9 @@ if is_flax_available(): ) _import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"]) _import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"]) + _import_structure["models.wav2vec2"].extend( + ["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", "FlaxWav2Vec2PreTrainedModel"] + ) else: from .utils import dummy_flax_objects @@ -3023,6 +3026,12 @@ if TYPE_CHECKING: ) from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel from .models.vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel + from .models.wav2vec2 import ( + FlaxWav2Vec2ForCTC, + FlaxWav2Vec2ForPreTraining, + FlaxWav2Vec2Model, + FlaxWav2Vec2PreTrainedModel, + ) else: # Import the same objects as dummies to get them in the namespace. # They will raise an import error if the user tries to instantiate / use them. diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 8ba020615a..0e01f19fb9 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -64,6 +64,7 @@ from ..roberta.modeling_flax_roberta import ( ) from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel +from ..wav2vec2.modeling_flax_wav2vec2 import FlaxWav2Vec2ForPreTraining, FlaxWav2Vec2Model from .auto_factory import auto_class_factory from .configuration_auto import ( BartConfig, @@ -75,6 +76,7 @@ from .configuration_auto import ( RobertaConfig, T5Config, ViTConfig, + Wav2Vec2Config, ) @@ -93,6 +95,7 @@ FLAX_MODEL_MAPPING = OrderedDict( (CLIPConfig, FlaxCLIPModel), (ViTConfig, FlaxViTModel), (T5Config, FlaxT5Model), + (Wav2Vec2Config, FlaxWav2Vec2Model), ] ) @@ -105,6 +108,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict( (BartConfig, FlaxBartForConditionalGeneration), (ElectraConfig, FlaxElectraForPreTraining), (T5Config, FlaxT5ForConditionalGeneration), + (Wav2Vec2Config, FlaxWav2Vec2ForPreTraining), ] ) diff --git a/src/transformers/models/wav2vec2/__init__.py b/src/transformers/models/wav2vec2/__init__.py index aaa5a5d29a..75dc4cbd91 100644 --- a/src/transformers/models/wav2vec2/__init__.py +++ b/src/transformers/models/wav2vec2/__init__.py @@ -17,7 +17,7 @@ # limitations under the License. 
from typing import TYPE_CHECKING

-from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available
+from ...file_utils import _BaseLazyModule, is_flax_available, is_tf_available, is_torch_available


 _import_structure = {
@@ -37,7 +37,6 @@ if is_torch_available():
         "Wav2Vec2PreTrainedModel",
     ]

-
 if is_tf_available():
     _import_structure["modeling_tf_wav2vec2"] = [
         "TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -46,6 +45,14 @@ if is_tf_available():
         "TFWav2Vec2PreTrainedModel",
     ]

+if is_flax_available():
+    _import_structure["modeling_flax_wav2vec2"] = [
+        "FlaxWav2Vec2ForCTC",
+        "FlaxWav2Vec2ForPreTraining",
+        "FlaxWav2Vec2Model",
+        "FlaxWav2Vec2PreTrainedModel",
+    ]
+
 if TYPE_CHECKING:
     from .configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config

@@ -71,6 +78,14 @@ if TYPE_CHECKING:
             TFWav2Vec2PreTrainedModel,
         )

+    if is_flax_available():
+        from .modeling_flax_wav2vec2 import (
+            FlaxWav2Vec2ForCTC,
+            FlaxWav2Vec2ForPreTraining,
+            FlaxWav2Vec2Model,
+            FlaxWav2Vec2PreTrainedModel,
+        )
+
 else:
     import importlib
diff --git a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
new file mode 100644
index 0000000000..12764a40ac
--- /dev/null
+++ b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
@@ -0,0 +1,1216 @@
+# coding=utf-8
+# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Flax Wav2Vec2 model. """
+
+from functools import partial
+from typing import Optional, Tuple, Union
+
+import numpy as np
+
+import flax
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict
+from flax.linen.attention import dot_product_attention_weights
+from jax import lax
+
+from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward
+from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
+from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
+from ...utils import logging
+from .configuration_wav2vec2 import Wav2Vec2Config
+
+
+logger = logging.get_logger(__name__)
+
+
+@flax.struct.dataclass
+class FlaxWav2Vec2BaseModelOutput(ModelOutput):
+    """
+    Output type of :class:`~transformers.FlaxWav2Vec2BaseModelOutput`, with potential hidden states and attentions.
+
+    Args:
+        last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        extract_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, last_conv_dim)`):
+            Sequence of extracted feature vectors of the last convolutional layer of the model with ``last_conv_dim``
+            being the dimension of the last convolutional layer.
+ hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: jnp.ndarray = None + extract_features: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxWav2Vec2ForPreTrainingOutput(ModelOutput): + """ + Output type of :class:`~transformers.FlaxWav2Vec2ForPreTrainingOutput`, with potential hidden states and + attentions. + + Args: + loss (`optional`, returned when model is in train mode, ``jnp.ndarray`` of shape :obj:`(1,)`): + Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the `official + paper `__ . (classification) loss. + projected_states (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.proj_codevector_dim)`): + Hidden-states of the model projected to `config.proj_codevector_dim` that can be used to predict the masked + projected quantized states. + projected_quantized_states (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.proj_codevector_dim)`): + Quantized extracted feature vectors projected to `config.proj_codevector_dim` representing the positive + target vectors for contrastive loss. + hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + projected_states: jnp.ndarray = None + projected_quantized_states: jnp.ndarray = None + codevector_perplexity: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None + + +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement `SpecAugment: A Simple Data Augmentation Method for + ASR `__. Note that this method is not optimized to run on TPU and should be run + on CPU as part of the preprocessing during training. + + Args: + shape: the the shape for which to compute masks. 
+ should be of size 2 where first element is batch size and 2nd is timesteps + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_length: size of the mask + min_masks: minimum number of masked spans + + """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`" + ) + + # compute number of masked spans in batch + num_masked_spans = int(mask_prob * sequence_length / mask_length + np.random.rand(1).item()) + num_masked_spans = max(num_masked_spans, min_masks) + + # make sure num masked indices <= sequence_length + if num_masked_spans * mask_length > sequence_length: + num_masked_spans = sequence_length // mask_length + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=np.bool) + + # get random indices to mask + spec_aug_mask_idxs = np.array( + [ + np.random.choice(np.arange(sequence_length - (mask_length - 1)), num_masked_spans, replace=False) + for _ in range(batch_size) + ] + ) + + # expand masked indices to masked spans + spec_aug_mask_idxs = np.broadcast_to(spec_aug_mask_idxs[:, :, None], (batch_size, num_masked_spans, mask_length)) + spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, num_masked_spans * mask_length) + + offsets = np.arange(mask_length)[None, None, :] + offsets = np.broadcast_to(offsets, (batch_size, num_masked_spans, mask_length)).reshape( + batch_size, num_masked_spans * mask_length + ) + spec_aug_mask_idxs = spec_aug_mask_idxs + offsets + + # scatter indices to mask + np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) + + return spec_aug_mask + + +def _sample_negative_indices(features_shape: Tuple, num_negatives: int): + """ + Sample `num_negatives` vectors from feature vectors. + """ + batch_size, sequence_length, hidden_size = features_shape + if sequence_length <= 1: + raise ValueError( + f"`features should have `sequence_length` > 1, but are of shape " + f"(batch_size, sequence_length, hidden_size) = ({batch_size, sequence_length, hidden_size})." + ) + + # get `num_negatives` random vector indices from the same utterance + sampled_negative_indices = np.random.randint( + low=0, + high=sequence_length - 1, + size=(batch_size, num_negatives * sequence_length), + ) + + # generate indices of the positive vectors themselves, repeat them `num_negatives` times + feature_indices = np.broadcast_to(np.arange(sequence_length)[:, None], (sequence_length, num_negatives)).flatten() + + # avoid sampling the same positive vector, but keep the distribution uniform + sampled_negative_indices[sampled_negative_indices >= feature_indices] += 1 + + # correct for batch size + for batch_idx in range(1, batch_size): + sampled_negative_indices[batch_idx] += batch_idx * sequence_length + + return sampled_negative_indices + + +WAV_2_VEC_2_START_DOCSTRING = r""" + Wav2Vec2 was proposed in `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations + `__ by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. + + This model inherits from :class:`~transformers.FlaxPreTrainedModel`. 
Check the superclass documentation for the + generic methods the library implements for all its model (such as downloading or saving, resizing the input + embeddings, pruning heads etc.) + + This model is also a Flax Linen `flax.nn.Module + `__ subclass. Use it as a regular Flax + Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - `Just-In-Time (JIT) compilation `__ + - `Automatic Differentiation `__ + - `Vectorization `__ + - `Parallelization `__ + + Parameters: + config (:class:`~transformers.Wav2Vec2Config`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the + model weights. +""" + + +WAV_2_VEC_2_INPUTS_DOCSTRING = r""" + Args: + input_values (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`): + Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file + into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install + soundfile`). To prepare the array into `input_values`, the :class:`~transformers.Wav2Vec2Processor` should + be used for padding and conversion into a tensor of type `jnp.ndarray`. See + :meth:`transformers.Wav2Vec2Processor.__call__` for details. + attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in ``[0, + 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ .. warning:: :obj:`attention_mask` should + only be passed if the corresponding processor has ``config.return_attention_mask == True``. For all models + whose processor has ``config.return_attention_mask == False``, such as `wav2vec2-base + `__, :obj:`attention_mask` should **not** be passed to + avoid degraded performance when doing batched inference. For such models :obj:`input_values` should simply + be padded with 0 and passed without :obj:`attention_mask`. Be aware that these models also yield slightly + different results depending on whether :obj:`input_values` is padded or not. + mask_time_indices (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict + masked extracted features in `config.proj_codevector_dim` space. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 
+""" + + +class FlaxWav2Vec2LayerNormConvLayer(nn.Module): + config: Wav2Vec2Config + layer_id: int = 0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.in_conv_dim = self.config.conv_dim[self.layer_id] if self.layer_id > 0 else 1 + self.out_conv_dim = self.config.conv_dim[self.layer_id] + + self.conv = nn.Conv( + features=self.config.conv_dim[self.layer_id], + kernel_size=self.config.conv_kernel[self.layer_id], + strides=(self.config.conv_stride[self.layer_id],), + use_bias=self.config.conv_bias, + kernel_init=jax.nn.initializers.he_normal(dtype=self.dtype), + padding="VALID", + dtype=self.dtype, + ) + self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.activation = ACT2FN[self.config.feat_extract_activation] + + def __call__(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class FlaxConvWithWeightNorm(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv = nn.Conv( + features=self.config.hidden_size, + kernel_size=self.config.num_conv_pos_embeddings, + kernel_init=jax.nn.initializers.he_normal(dtype=self.dtype), + padding="VALID", + feature_group_count=self.config.num_conv_pos_embedding_groups, + dtype=self.dtype, + ) + weight_shape = ( + self.conv.features, + self.conv.features // self.conv.feature_group_count, + self.conv.kernel_size, + ) + self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(dtype=self.dtype), weight_shape) + self.weight_g = self.param("weight_g", lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]) + self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,)) + self.prev_padding = self.conv.kernel_size // 2 + + def _get_normed_weights(self): + weight_v_norm = jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :] + normed_weight_v = jnp.divide(self.weight_v, weight_v_norm) + normed_kernel = jnp.multiply(normed_weight_v, self.weight_g) + return normed_kernel + + def __call__(self, hidden_states): + kernel = self._get_normed_weights() + hidden_states = jnp.pad(hidden_states, ((0, 0), (self.prev_padding, self.prev_padding), (0, 0))) + hidden_states = self.conv.apply({"params": {"kernel": kernel.T, "bias": self.bias}}, hidden_states) + return hidden_states + + +class FlaxWav2Vec2PositionalConvEmbedding(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv = FlaxConvWithWeightNorm(self.config, dtype=self.dtype) + self.activation = ACT2FN[self.config.feat_extract_activation] + self.num_pad_remove = 1 if self.config.num_conv_pos_embeddings % 2 == 0 else 0 + + def __call__(self, hidden_states): + hidden_states = hidden_states.transpose((0, 1, 2)) + + hidden_states = self.conv(hidden_states) + + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, : -self.num_pad_remove, :] + hidden_states = self.activation(hidden_states) + + hidden_states = hidden_states.transpose((0, 1, 2)) + return hidden_states + + +class FlaxConvLayersCollection(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + if self.config.feat_extract_norm == "layer": + self.layers = [ + FlaxWav2Vec2LayerNormConvLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype) + for i in range(self.config.num_feat_extract_layers) + ] + elif self.config.feat_extract_norm == "group": + raise NotImplementedError("At the moment 
only ``config.feat_extact_norm == 'layer'`` is supported") + else: + raise ValueError( + f"`config.feat_extract_norm` is {self.config.feat_extract_norm}, but has to be one of ['group', 'layer']" + ) + + def __call__(self, hidden_states): + for i, conv_layer in enumerate(self.layers): + hidden_states = conv_layer(hidden_states) + return hidden_states + + +class FlaxWav2Vec2FeatureExtractor(nn.Module): + """Construct the featurs from raw audio waveform""" + + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype) + + def __call__(self, input_values): + hidden_states = input_values[:, :, None] + hidden_states = self.conv_layers(hidden_states) + return hidden_states + + +class FlaxWav2Vec2FeatureProjection(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.projection = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.feat_proj_dropout) + + def __call__(self, hidden_states, deterministic=True): + norm_hidden_states = self.layer_norm(hidden_states) + hidden_states = self.projection(norm_hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states, norm_hidden_states + + +class FlaxWav2Vec2Attention(nn.Module): + config: Wav2Vec2Config + embed_dim: int + num_heads: int + dropout: float = 0.0 + bias: bool = True + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self) -> None: + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." + ) + + dense = partial( + nn.Dense, + self.embed_dim, + use_bias=self.bias, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + ) + + self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() + self.out_proj = dense() + + self.dropout_layer = nn.Dropout(rate=self.dropout) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) + + def __call__( + self, + hidden_states: jnp.ndarray, + key_value_states: Optional[jnp.ndarray] = None, + attention_mask: Optional[jnp.ndarray] = None, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + """Input shape: Batch x Time x Channel""" + + # get query proj + query_states = self.q_proj(hidden_states) + + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + if attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # Convert the boolean attention mask to an attention bias. 
+ if attention_mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = self._merge_heads(attn_output) + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class FlaxWav2Vec2FeedForward(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.intermediate_dropout = nn.Dropout(rate=self.config.activation_dropout) + + self.intermediate_dense = nn.Dense( + self.config.intermediate_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + dtype=self.dtype, + ) + if isinstance(self.config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[self.config.hidden_act] + else: + self.intermediate_act_fn = self.config.hidden_act + + self.output_dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + dtype=self.dtype, + ) + self.output_dropout = nn.Dropout(rate=self.config.hidden_dropout) + + def __call__(self, hidden_states, deterministic=True): + hidden_states = self.intermediate_dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.intermediate_dropout(hidden_states, deterministic=deterministic) + + hidden_states = self.output_dense(hidden_states) + hidden_states = self.output_dropout(hidden_states, deterministic=deterministic) + return hidden_states + + +class FlaxWav2Vec2EncoderLayerStableLayerNorm(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.attention = FlaxWav2Vec2Attention( + config=self.config, + embed_dim=self.config.hidden_size, + num_heads=self.config.num_attention_heads, + dropout=self.config.attention_dropout, + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout) + self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.feed_forward = FlaxWav2Vec2FeedForward(self.config, dtype=self.dtype) + self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__(self, hidden_states, attention_mask=None, deterministic=True, output_attentions=False): + attn_residual = hidden_states + hidden_states = self.layer_norm(hidden_states) + hidden_states, attn_weights = self.attention( + hidden_states, attention_mask=attention_mask, deterministic=deterministic + ) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = attn_residual + hidden_states + hidden_states = hidden_states + self.feed_forward( + self.final_layer_norm(hidden_states), deterministic=deterministic + ) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class FlaxWav2Vec2EncoderLayerStableLayerNormCollection(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 
+ + def setup(self): + self.layers = [ + FlaxWav2Vec2EncoderLayerStableLayerNorm(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + attention_mask=None, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer( + hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states,) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +class FlaxWav2Vec2StableLayerNormEncoder(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.pos_conv_embed = FlaxWav2Vec2PositionalConvEmbedding(self.config, dtype=self.dtype) + self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout) + self.layers = FlaxWav2Vec2EncoderLayerStableLayerNormCollection(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask=None, + deterministic=True, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + + if attention_mask is not None: + # make sure padded tokens are not attended to + hidden_states = jnp.where( + jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape), hidden_states, 0 + ) + + position_embeddings = self.pos_conv_embed(hidden_states) + + hidden_states = hidden_states + position_embeddings + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + + outputs = self.layers( + hidden_states, + attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = self.layer_norm(outputs[0]) + + if not return_dict: + return (hidden_states,) + outputs[1:] + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +class FlaxWav2Vec2GumbelVectorQuantizer(nn.Module): + """ + Vector quantization using gumbel softmax. See `CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX + `__ for more information. 
+    """
+
+    config: Wav2Vec2Config
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        self.num_groups = self.config.num_codevector_groups
+        self.num_vars = self.config.num_codevectors_per_group
+
+        if self.config.codevector_dim % self.num_groups != 0:
+            raise ValueError(
+                f"`config.codevector_dim` {self.config.codevector_dim} must be divisible by"
+                f" `config.num_codevector_groups` {self.num_groups} for concatenation"
+            )
+
+        # storage for codebook variables (codewords)
+        self.codevectors = self.param(
+            "codevectors",
+            jax.nn.initializers.uniform(),
+            (1, self.num_groups * self.num_vars, self.config.codevector_dim // self.num_groups),
+        )
+        self.weight_proj = nn.Dense(
+            self.num_groups * self.num_vars,
+            kernel_init=jax.nn.initializers.normal(1.0, self.dtype),
+            dtype=self.dtype,
+        )
+
+    @staticmethod
+    def _compute_perplexity(probs, mask=None):
+        if mask is not None:
+            mask_extended = jnp.broadcast_to(mask.flatten()[:, None, None], probs.shape)
+            probs = jnp.where(mask_extended, probs, jnp.zeros_like(probs))
+            marginal_probs = probs.sum(axis=0) / mask.sum()
+        else:
+            marginal_probs = probs.mean(axis=0)
+
+        perplexity = jnp.exp(-jnp.sum(marginal_probs * jnp.log(marginal_probs + 1e-7), axis=-1)).sum()
+        return perplexity
+
+    def __call__(self, hidden_states, mask_time_indices=None, deterministic=True, temperature=1):
+        batch_size, sequence_length, hidden_size = hidden_states.shape
+
+        # project to codevector dim
+        hidden_states = self.weight_proj(hidden_states)
+        hidden_states = hidden_states.reshape(batch_size * sequence_length * self.num_groups, -1)
+
+        if not deterministic:
+            # sample code vector probs via gumbel in differentiable way
+            gumbel_rng = self.make_rng("gumbel")
+            gumbels = jax.random.gumbel(gumbel_rng, hidden_states.shape)
+            codevector_probs = nn.softmax((hidden_states + gumbels) / temperature)
+
+            # compute perplexity
+            codevector_soft_dist = nn.softmax(
+                hidden_states.reshape(batch_size * sequence_length, self.num_groups, -1), axis=-1
+            )
+            perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
+        else:
+            # take argmax in non-differentiable way
+            # compute hard codevector distribution (one hot)
+            codevector_idx = hidden_states.argmax(axis=-1)
+            codevector_probs = jax.nn.one_hot(codevector_idx, hidden_states.shape[-1]) * 1.0
+            codevector_probs = codevector_probs.reshape(batch_size * sequence_length, self.num_groups, -1)
+            perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
+
+        codevector_probs = codevector_probs.reshape(batch_size * sequence_length, -1)
+        # use probs to retrieve codevectors
+        codevectors_per_group = jnp.expand_dims(codevector_probs, axis=-1) * self.codevectors
+        codevectors = codevectors_per_group.reshape(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
+        codevectors = codevectors.sum(-2).reshape(batch_size, sequence_length, -1)
+
+        return codevectors, perplexity
+
+
+class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
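+
+    As with all Flax models, the module itself is stateless: parameters are created by ``init_weights`` and passed to
+    the module on every call, defaulting to ``self.params``.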
+ """ + + config_class = Wav2Vec2Config + base_model_prefix: str = "wav2vec2" + module_class: nn.Module = None + + def __init__( + self, + config: Wav2Vec2Config, + input_shape: Tuple = (1, 1024), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + # init input tensors + input_values = jnp.zeros(input_shape, dtype="i4") + attention_mask = jnp.ones_like(input_values) + params_rng, dropout_rng = jax.random.split(rng, 2) + rngs = {"params": params_rng, "dropout": dropout_rng} + + return self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"] + + @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) + def __call__( + self, + input_values, + attention_mask=None, + mask_time_indices=None, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + batch_size, sequence_length = input_values.shape + + if attention_mask is None: + attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + return self.module.apply( + inputs, + jnp.array(input_values, dtype="f4"), + jnp.array(attention_mask, dtype="i4"), + mask_time_indices, + not train, + output_attentions, + output_hidden_states, + return_dict, + rngs=rngs, + ) + + def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]): + return self.module._get_feat_extract_output_lengths(input_lengths) + + +class FlaxWav2Vec2Module(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.feature_extractor = FlaxWav2Vec2FeatureExtractor(self.config, dtype=self.dtype) + self.feature_projection = FlaxWav2Vec2FeatureProjection(self.config, dtype=self.dtype) + self.masked_spec_embed = self.param( + "masked_spec_embed", jax.nn.initializers.uniform(), (self.config.hidden_size,) + ) + + if self.config.do_stable_layer_norm: + self.encoder = FlaxWav2Vec2StableLayerNormEncoder(self.config, dtype=self.dtype) + else: + raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.") + + def __call__( + self, + input_values, + attention_mask=None, + mask_time_indices=None, + deterministic=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + """ + + Returns: + + Example:: + + >>> from transformers import Wav2Vec2Processor, FlaxWav2Vec2Model + >>> from datasets import load_dataset + >>> import soundfile as sf + + >>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h") + >>> model = FlaxWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") + + >>> def map_to_array(batch): + >>> speech, _ = sf.read(batch["file"]) + >>> batch["speech"] = speech + >>> return batch + + >>> ds = 
load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.map(map_to_array) + + >>> input_values = processor(ds["speech"][0], return_tensors="np").input_values # Batch size 1 + >>> hidden_states = model(input_values).last_hidden_state + + """ + extract_features = self.feature_extractor(input_values) + + if attention_mask is not None: + # compute real output lengths according to convolution formula + output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1).astype("i4")) + + attention_mask = jnp.zeros(extract_features.shape[:2], dtype=self.dtype) + + # these two operations makes sure that all values + # before the output lengths indices are attended to + attention_mask = jax.ops.index_update( + attention_mask, jax.ops.index[jnp.arange(attention_mask.shape[0]), output_lengths - 1], 1 + ) + attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool") + + hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic) + if mask_time_indices is not None: # apply SpecAugment along time axis with given indices + hidden_states = jnp.where( + jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape), + jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape), + hidden_states, + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if not return_dict: + return (hidden_states, extract_features) + encoder_outputs[1:] + + return FlaxWav2Vec2BaseModelOutput( + last_hidden_state=hidden_states, + extract_features=extract_features, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return (input_length - kernel_size) // stride + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + return input_lengths + + +@add_start_docstrings( + "The bare Wav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.", + WAV_2_VEC_2_START_DOCSTRING, +) +class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel): + module_class = FlaxWav2Vec2Module + + +class FlaxWav2Vec2ForCTCModule(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.final_dropout) + self.lm_head = nn.Dense( + self.config.vocab_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + dtype=self.dtype, + ) + + def __call__( + self, + input_values, + attention_mask=None, + mask_time_indices=None, + deterministic=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Returns: + + Example:: + + >>> import jax.numpy as jnp + >>> from transformers import Wav2Vec2Processor, FlaxWav2Vec2ForCTC + >>> from datasets import load_dataset + >>> import soundfile 
as sf + + >>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h") + >>> model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h") + + >>> def map_to_array(batch): + >>> speech, _ = sf.read(batch["file"]) + >>> batch["speech"] = speech + >>> return batch + + >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.map(map_to_array) + + >>> input_values = processor(ds["speech"][0], return_tensors="np").input_values # Batch size 1 + >>> logits = model(input_values).logits + >>> predicted_ids = jnp.argmax(logits, axis=-1) + + >>> transcription = processor.decode(predicted_ids[0]) + >>> # should give: "A MAN SAID TO THE UNIVERSE SIR I EXIST" + + """ + + outputs = self.wav2vec2( + input_values, + attention_mask=attention_mask, + mask_time_indices=mask_time_indices, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + + logits = self.lm_head(hidden_states) + + if not return_dict: + return (logits,) + outputs[2:] + + return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) + + def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return (input_length - kernel_size) // stride + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + return input_lengths + + +@add_start_docstrings( + "Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).", + WAV_2_VEC_2_START_DOCSTRING, +) +class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel): + module_class = FlaxWav2Vec2ForCTCModule + + +class FlaxWav2Vec2ForPreTrainingModule(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype) + self.dropout_features = nn.Dropout(self.config.feat_quantizer_dropout) + + self.quantizer = FlaxWav2Vec2GumbelVectorQuantizer(self.config, dtype=self.dtype) + self.project_q = nn.Dense( + self.config.proj_codevector_dim, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + dtype=self.dtype, + ) + self.project_hid = nn.Dense( + self.config.proj_codevector_dim, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + dtype=self.dtype, + ) + + def __call__( + self, + input_values, + attention_mask=None, + mask_time_indices=None, + gumbel_temperature: int = 1, + deterministic: bool = True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Returns: + + Example:: + + >>> import optax + >>> import numpy as np + >>> import jax.numpy as jnp + >>> from transformers import Wav2Vec2FeatureExtractor, FlaxWav2Vec2ForPreTraining + >>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices + >>> from datasets import load_dataset + >>> import soundfile as sf + + >>> feature_extractor = 
Wav2Vec2FeatureExtractor.from_pretrained("patrickvonplaten/wav2vec2-base") + >>> model = FlaxWav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base") + + + >>> def map_to_array(batch): + ... speech, _ = sf.read(batch["file"]) + ... batch["speech"] = speech + ... return batch + + + >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.map(map_to_array) + + >>> input_values = feature_extractor(ds["speech"][0], return_tensors="np").input_values # Batch size 1 + + >>> # compute masked indices + >>> batch_size, raw_sequence_length = input_values.shape + >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length) + >>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2) + + >>> outputs = model(input_values, mask_time_indices=mask_time_indices) + + >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states) + >>> cosine_sim = optax.cosine_similarity( + ... outputs.projected_states, outputs.projected_quantized_states, axis=-1 + ... ) + + >>> # show that cosine similarity is much higher than random + >>> assert np.asarray(cosine_sim)[mask_time_indices].mean() > 0.5 + + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.wav2vec2( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + mask_time_indices=mask_time_indices, + deterministic=deterministic, + return_dict=return_dict, + ) + + # project all transformed features (including masked) to final vq dim + transformer_features = self.project_hid(outputs[0]) + + # quantize all (unmasked) extracted features and project to final vq dim + extract_features = self.dropout_features(outputs[1], deterministic=deterministic) + quantized_features, codevector_perplexity = self.quantizer( + extract_features, mask_time_indices, deterministic=deterministic, temperature=gumbel_temperature + ) + quantized_features = self.project_q(quantized_features) + + if not return_dict: + return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:] + + return FlaxWav2Vec2ForPreTrainingOutput( + projected_states=transformer_features, + projected_quantized_states=quantized_features, + codevector_perplexity=codevector_perplexity, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return (input_length - kernel_size) // stride + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + return input_lengths + + +@add_start_docstrings("""Wav2Vec2 Model with a quantizer and `VQ` head on top. 
""", WAV_2_VEC_2_START_DOCSTRING) +class FlaxWav2Vec2ForPreTraining(FlaxWav2Vec2PreTrainedModel): + module_class = FlaxWav2Vec2ForPreTrainingModule + + @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) + # overwrite since has `gumbel_temperature` input + def __call__( + self, + input_values, + attention_mask=None, + mask_time_indices=None, + gumbel_temperature: int = 1, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + gumbel_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + batch_size, sequence_length = input_values.shape + + if attention_mask is None: + attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + if gumbel_rng is not None: + rngs["gumbel"] = gumbel_rng + + inputs = {"params": params or self.params} + + return self.module.apply( + inputs, + jnp.array(input_values, dtype="f4"), + jnp.array(attention_mask, dtype="i4"), + mask_time_indices, + gumbel_temperature, + not train, + output_attentions, + output_hidden_states, + return_dict, + rngs=rngs, + ) diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index e4a56113d2..a24fb4c9e7 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -654,3 +654,31 @@ class FlaxViTPreTrainedModel: @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) + + +class FlaxWav2Vec2ForCTC: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxWav2Vec2ForPreTraining: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxWav2Vec2Model: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxWav2Vec2PreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) diff --git a/tests/test_modeling_flax_wav2vec2.py b/tests/test_modeling_flax_wav2vec2.py new file mode 100644 index 0000000000..9b33a1d2ba --- /dev/null +++ b/tests/test_modeling_flax_wav2vec2.py @@ -0,0 +1,398 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +import math +import unittest + +import numpy as np + +from transformers import Wav2Vec2Config, is_flax_available +from transformers.testing_utils import require_datasets, require_flax, require_soundfile, slow + +from .test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, random_attention_mask + + +if is_flax_available(): + import jax + import jax.numpy as jnp + import optax + from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Processor + from transformers.models.wav2vec2.modeling_flax_wav2vec2 import ( + FlaxWav2Vec2ForCTC, + FlaxWav2Vec2ForPreTraining, + FlaxWav2Vec2GumbelVectorQuantizer, + FlaxWav2Vec2Model, + _compute_mask_indices, + _sample_negative_indices, + ) + + +class FlaxWav2Vec2ModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=1024, # speech is longer + is_training=False, + hidden_size=24, + feat_extract_norm="layer", + feat_extract_dropout=0.0, + feat_extract_activation="gelu", + conv_dim=(32, 32, 32), + conv_stride=(4, 4, 4), + conv_kernel=(8, 8, 8), + conv_bias=False, + num_conv_pos_embeddings=16, + num_conv_pos_embedding_groups=2, + num_hidden_layers=4, + num_attention_heads=2, + hidden_dropout_prob=0.1, # this is most likely not correctly set yet + intermediate_size=20, + layer_norm_eps=1e-5, + hidden_act="gelu", + initializer_range=0.02, + vocab_size=32, + do_stable_layer_norm=True, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.hidden_size = hidden_size + self.feat_extract_norm = feat_extract_norm + self.feat_extract_dropout = feat_extract_dropout + self.feat_extract_activation = feat_extract_activation + self.conv_dim = conv_dim + self.conv_stride = conv_stride + self.conv_kernel = conv_kernel + self.conv_bias = conv_bias + self.num_conv_pos_embeddings = num_conv_pos_embeddings + self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_dropout_prob = hidden_dropout_prob + self.intermediate_size = intermediate_size + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.vocab_size = vocab_size + self.do_stable_layer_norm = do_stable_layer_norm + self.scope = scope + + output_seq_length = self.seq_length + for kernel, stride in zip(self.conv_kernel, self.conv_stride): + output_seq_length = (output_seq_length - (kernel - 1)) / stride + self.output_seq_length = int(math.ceil(output_seq_length)) + self.encoder_seq_length = self.output_seq_length + + def prepare_config_and_inputs(self): + input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size) + attention_mask = random_attention_mask([self.batch_size, self.seq_length]) + + config = Wav2Vec2Config( + do_stable_layer_norm=self.do_stable_layer_norm, + hidden_size=self.hidden_size, + feat_extract_norm=self.feat_extract_norm, + feat_extract_dropout=self.feat_extract_dropout, + feat_extract_activation=self.feat_extract_activation, + conv_dim=self.conv_dim, + conv_stride=self.conv_stride, + conv_kernel=self.conv_kernel, + conv_bias=self.conv_bias, + num_conv_pos_embeddings=self.num_conv_pos_embeddings, + num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + hidden_dropout_prob=self.hidden_dropout_prob, + intermediate_size=self.intermediate_size, + 
layer_norm_eps=self.layer_norm_eps, + hidden_act=self.hidden_act, + initializer_range=self.initializer_range, + vocab_size=self.vocab_size, + ) + + return config, input_values, attention_mask + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_values, attention_mask = config_and_inputs + inputs_dict = {"input_values": input_values, "attention_mask": attention_mask} + return config, inputs_dict + + +@require_flax +class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase): + all_model_classes = ( + (FlaxWav2Vec2Model, FlaxWav2Vec2ForCTC, FlaxWav2Vec2ForPreTraining) if is_flax_available() else () + ) + + def setUp(self): + self.model_tester = FlaxWav2Vec2ModelTester(self) + + def test_train(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + input_values = inputs_dict["input_values"] + attention_mask = inputs_dict["attention_mask"] + + model = FlaxWav2Vec2ForPreTraining(config) + + features_shape = ( + input_values.shape[0], + model._get_feat_extract_output_lengths(np.array(input_values.shape[1])), + ) + + batch_size, sequence_length = features_shape[:2] + + mask_prob = 0.5 + mask_length = 4 + mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length) + + dropout_rng, gumbel_rng = jax.random.split(jax.random.PRNGKey(0)) + + output = model( + input_values, + attention_mask=attention_mask, + mask_time_indices=mask_time_indices, + train=True, + dropout_rng=dropout_rng, + gumbel_rng=gumbel_rng, + )[0] + + self.assertTrue(output.shape == (batch_size, sequence_length, model.config.proj_codevector_dim)) + + # overwrite because of `input_values` + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.__call__) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["input_values", "attention_mask"] + self.assertListEqual(arg_names[:2], expected_arg_names) + + @slow + # overwrite because of `input_values` + def test_jit_compilation(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config) + + @jax.jit + def model_jitted(input_values, attention_mask=None, **kwargs): + return model(input_values=input_values, attention_mask=attention_mask, **kwargs) + + with self.subTest("JIT Enabled"): + jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple() + + with self.subTest("JIT Disabled"): + with jax.disable_jit(): + outputs = model_jitted(**prepared_inputs_dict).to_tuple() + + self.assertEqual(len(outputs), len(jitted_outputs)) + for jitted_output, output in zip(jitted_outputs, outputs): + + self.assertEqual(jitted_output.shape, output.shape) + + @slow + def test_model_from_pretrained(self): + for model_class_name in self.all_model_classes: + model = model_class_name.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", from_pt=True) + outputs = model(np.ones((1, 1024), dtype="f4")) + self.assertIsNotNone(outputs) + + +@require_flax +class FlaxWav2Vec2UtilsTest(unittest.TestCase): + def test_compute_mask_indices(self): + batch_size = 4 + sequence_length = 60 + 
mask_prob = 0.5
+        mask_length = 1
+
+        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
+
+        self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
+
+    def test_compute_mask_indices_overlap(self):
+        batch_size = 4
+        sequence_length = 80
+        mask_prob = 0.5
+        mask_length = 4
+
+        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
+
+        # because of overlap, the masks don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller than or equal to it
+        for batch_sum in mask.sum(axis=-1):
+            self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
+
+    def test_compute_perplexity(self):
+        probs = np.arange(100).reshape(2, 5, 10) / 100
+
+        ppl = FlaxWav2Vec2GumbelVectorQuantizer._compute_perplexity(probs)
+        self.assertTrue(abs(ppl.item() - 141.4291) < 1e-3)
+
+        # mask half of the input
+        mask = np.ones((2,), dtype=np.bool)
+        mask[0] = 0
+
+        ppl = FlaxWav2Vec2GumbelVectorQuantizer._compute_perplexity(probs, mask)
+        self.assertTrue(abs(ppl.item() - 58.6757) < 1e-3)
+
+    def test_sample_negatives(self):
+        batch_size = 2
+        sequence_length = 10
+        hidden_size = 4
+        num_negatives = 3
+
+        features = (np.arange(sequence_length * hidden_size) // hidden_size).reshape(
+            sequence_length, hidden_size
+        )  # each value in vector consists of same value
+        features = np.broadcast_to(features[None, :], (batch_size, sequence_length, hidden_size))
+
+        negative_indices = _sample_negative_indices(features.shape, num_negatives)
+
+        features = features.reshape(-1, hidden_size)  # BTC => (BxT)C
+        # take negative vectors from sampled indices
+        sampled_negatives = features[negative_indices.reshape(-1)]
+        negatives = sampled_negatives.reshape(batch_size, sequence_length, num_negatives, hidden_size).transpose(
+            2, 0, 1, 3
+        )
+
+        self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
+
+        # make sure no negatively sampled vector is actually a positive one
+        for negative in negatives:
+            self.assertTrue(((negative - features.reshape(negative.shape)) == 0).sum() == 0.0)
+
+        # make sure that full vectors are sampled and not values of vectors
+        # => this means that `unique()` yields a single value for `hidden_size` dim
+        self.assertEqual(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))
+
+
+@require_flax
+@require_datasets
+@require_soundfile
+@slow
+class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
+    def _load_datasamples(self, num_samples):
+        from datasets import load_dataset
+
+        import soundfile as sf
+
+        ids = [f"1272-141231-000{i}" for i in range(num_samples)]
+
+        # map files to raw
+        def map_to_array(batch):
+            speech, _ = sf.read(batch["file"])
+            batch["speech"] = speech
+            return batch
+
+        ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
+
+        ds = ds.filter(lambda x: x["id"] in ids).sort("id").map(map_to_array)
+
+        return ds["speech"][:num_samples]
+
+    def test_inference_ctc_robust_batched(self):
+        model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", from_pt=True)
+        processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)
+
+        input_speech = self._load_datasamples(4)
+
+        inputs = processor(input_speech, return_tensors="pt", padding=True, truncation=True)
+
+        input_values = inputs.input_values
+        attention_mask = inputs.attention_mask
+
+        logits = model(input_values, attention_mask=attention_mask).logits
+
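+        # greedy CTC decoding: take the highest-scoring vocabulary entry at every frame;
+        # `batch_decode` then collapses repeated tokens and strips the padding (blank) token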
+ predicted_ids = jnp.argmax(logits, axis=-1) + predicted_trans = processor.batch_decode(predicted_ids) + + EXPECTED_TRANSCRIPTIONS = [ + "a man said to the universe sir i exist", + "sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore", + "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about", + "his instant panic was followed by a small sharp blow high on his chest", + ] + self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS) + + def test_inference_pretrained(self): + model = FlaxWav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-large-lv60", from_pt=True) + feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( + "facebook/wav2vec2-large-lv60", return_attention_mask=True + ) + input_speech = self._load_datasamples(2) + + inputs_dict = feature_extractor(input_speech, return_tensors="np", padding=True) + + features_shape = ( + inputs_dict["input_values"].shape[0], + model._get_feat_extract_output_lengths(np.array(inputs_dict["input_values"].shape[1])), + ) + + mask_time_indices = _compute_mask_indices( + features_shape, + model.config.mask_time_prob, + model.config.mask_time_length, + min_masks=2, + ) + + outputs = model( + inputs_dict.input_values, + attention_mask=inputs_dict.attention_mask, + mask_time_indices=mask_time_indices, + ) + + # compute cosine similarity + cosine_sim = optax.cosine_similarity( + outputs.projected_states, outputs.projected_quantized_states, epsilon=1e-8 + ) + + # retrieve cosine sim of masked features + cosine_sim_masked = cosine_sim[mask_time_indices] + + # ... now compare to randomly initialized model + + config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-large-lv60") + model_rand = FlaxWav2Vec2ForPreTraining(config) + + outputs_rand = model_rand( + inputs_dict.input_values, + attention_mask=inputs_dict.attention_mask, + mask_time_indices=mask_time_indices, + ) + + # compute cosine similarity + cosine_sim_rand = optax.cosine_similarity( + outputs_rand.projected_states, outputs_rand.projected_quantized_states + ) + + # retrieve cosine sim of masked features + cosine_sim_masked_rand = cosine_sim_rand[mask_time_indices] + + # a pretrained wav2vec2 model has learned to predict the quantized latent states + # => the cosine similarity between quantized states and predicted states > 0.5 + # a random wav2vec2 model has not learned to predict the quantized latent states + # => the cosine similarity between quantized states and predicted states is very likely < 0.1 + self.assertTrue(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item() > 0) diff --git a/utils/check_repo.py b/utils/check_repo.py index 244bd20185..6a17ab5b29 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -102,6 +102,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [ "CLIPVisionModel", "FlaxCLIPTextModel", "FlaxCLIPVisionModel", + "FlaxWav2Vec2ForCTC", "DetrForSegmentation", "DPRReader", "FlaubertForQuestionAnswering",