[Flax] Add wav2vec2 (#12271)
* fix_torch_device_generate_test
* remove @
* start flax wav2vec2
* save intermediate
* forward pass has correct shape
* add weight norm
* add files
* finish ctc
* make style
* finish gumbel quantizer
* correct docstrings
* correct some more files
* fix vit
* finish quality
* correct tests
* correct docstring
* correct tests
* start wav2vec2 pretraining script
* save intermediate
* start pretraining script
* finalize pretraining script
* finish
* finish
* small typo
* finish
* correct
* Apply suggestions from code review
* make style
* push

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in: parent 3f36a2c064 · commit 0d1f67e651
@@ -411,7 +411,7 @@ Flax), PyTorch, and/or TensorFlow.
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
-| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ❌ |
+| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
@@ -99,3 +99,23 @@ TFWav2Vec2ForCTC
 .. autoclass:: transformers.TFWav2Vec2ForCTC
     :members: call
 
+
+FlaxWav2Vec2Model
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxWav2Vec2Model
+    :members: __call__
+
+
+FlaxWav2Vec2ForCTC
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxWav2Vec2ForCTC
+    :members: __call__
+
+
+FlaxWav2Vec2ForPreTraining
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxWav2Vec2ForPreTraining
+    :members: __call__
@@ -0,0 +1,119 @@
# Wav2Vec2 Contrastive Loss PreTraining examples

The following example showcases how to pretrain a wav2vec2 model using the JAX/Flax backend.
Pretraining Wav2Vec2 is rather complex, so it is highly recommended to read the
[official paper](https://arxiv.org/abs/2006.11477).

JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.
Models written in JAX/Flax are **immutable** and updated in a purely functional
way, which enables simple and efficient model parallelism.
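
To make the functional-update idea concrete, here is a minimal, self-contained sketch (not part of the training script) of tracing a pure loss function with JAX and updating parameters without mutating them in place:

```python
import jax
import jax.numpy as jnp

# A pure function: parameters go in, a scalar loss comes out.
def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

# `jax.grad` differentiates the traced computation; `jax.jit` compiles it.
grad_fn = jax.jit(jax.grad(loss_fn))

params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
x, y = jnp.ones((8, 3)), jnp.ones((8, 1))

grads = grad_fn(params, x, y)
# Parameters are never mutated; an updated copy is returned instead.
params = jax.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```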

`run_wav2vec2_pretrain_flax.py` is a lightweight example of how to download and preprocess a dataset from the 🤗 Datasets library or use your own files (jsonlines or csv), then pretrain the wav2vec2 architecture on it.

For custom datasets in `jsonlines` format please see [the Datasets documentation](https://huggingface.co/docs/datasets/loading_datasets.html#json-files); a minimal example of such a file is shown below.
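
For instance, assuming the default `file` column that the training script reads (configurable via `--speech_file_column`), a `jsonlines` dataset could look like this (the paths are placeholders):

```json
{"file": "/data/audio/sample_0001.wav"}
{"file": "/data/audio/sample_0002.wav"}
```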

Let's start by creating a model repository to save the trained model and logs.
Here we call the model `"wav2vec2-base-robust"`, but you can change the model name as you like.

You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
you are logged in) or via the command line:

```bash
huggingface-cli repo create wav2vec2-base-robust
```

Next we clone the model repository to add the tokenizer and model files.

```bash
git clone https://huggingface.co/<your-username>/wav2vec2-base-robust
```

To ensure that all tensorboard traces will be uploaded correctly, we need to
track them. You can run the following command inside your model repo to do so.

```bash
cd wav2vec2-base-robust
git lfs track "*tfevents*"
```

Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.

Next, let's add a symbolic link to the `run_wav2vec2_pretrain_flax.py` script.

```bash
export MODEL_DIR="./wav2vec2-base-robust"
ln -s ~/transformers/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py ./
```

### Create the model configuration

Let's first create the model configuration and store it in the model repository.
Note that many training parameters can be set in the model configuration, including
the masking distribution (`mask_time_length`, `mask_time_prob`),
dropout (`attention_dropout`, ...), and the trade-off between the contrastive loss and
the diversity loss (see the formula below).
Most likely you will need to change these parameters depending on your use case.
Again, we highly recommend reading the [official paper](https://arxiv.org/abs/2006.11477)
to better understand which parameters can be set for pretraining.
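
For reference, the total pretraining loss optimized by the script is the contrastive loss plus the weighted diversity loss, exactly as computed in `run_wav2vec2_pretrain_flax.py`:

```python
# num_codevectors = num_codevectors_per_group * num_codevector_groups
diversity_loss = (num_codevectors - codevector_perplexity) / num_codevectors
loss = contrastive_loss + diversity_loss_weight * diversity_loss
```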

For this example, we will be using a `"base"`-sized model of Wav2Vec2 with robust
layer norm and keep most of the default settings.

```python
from transformers import Wav2Vec2Config

model_dir = "./wav2vec2-base-robust"

config = Wav2Vec2Config.from_pretrained(
    "facebook/wav2vec2-base",
    mask_time_length=10,
    mask_time_prob=0.05,
    diversity_loss_weight=0.1,
    num_negatives=100,
    do_stable_layer_norm=True,
    feat_extract_norm="layer",
)
config.save_pretrained(model_dir)
```

### Create a feature extractor configuration

Before we can start the training, we need to define
a feature extractor that takes care of normalization, etc...

Here we can also re-use the feature extractor of [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) while making sure that padding is allowed.

```python
from transformers import Wav2Vec2FeatureExtractor

model_dir = "./wav2vec2-base-robust"

feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base", return_attention_mask=True)
feature_extractor.save_pretrained(model_dir)
```

### Train the model

Finally, we can run the example script to train the model:

```bash
./run_wav2vec2_pretrain_flax.py \
    --output_dir=${MODEL_DIR} \
    --num_train_epochs="5" \
    --per_device_train_batch_size="32" \
    --per_device_eval_batch_size="32" \
    --learning_rate="5e-4" \
    --weight_decay="0.01" \
    --warmup_steps="2000" \
    --model_name_or_path=${MODEL_DIR} \
    --dataset_name="librispeech_asr" \
    --dataset_config_name="clean" \
    --train_split_name="train.100" \
    --preprocessing_num_workers="4" \
    --max_duration_in_seconds="10.0" \
    --adam_beta1="0.9" \
    --adam_beta2="0.98" \
    --push_to_hub
```

Note that this script is not fully tested yet, so we cannot guarantee that
the above command leads to satisfying results.
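
Once training has finished (or a checkpoint has been pushed at the end of an epoch), the model can be loaded back for a quick sanity check or for further fine-tuning. A minimal sketch, assuming the model was saved to the local directory used above:

```python
import numpy as np
from transformers import FlaxWav2Vec2ForPreTraining, Wav2Vec2FeatureExtractor

model_dir = "./wav2vec2-base-robust"  # hypothetical local path from the steps above

model = FlaxWav2Vec2ForPreTraining.from_pretrained(model_dir)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_dir)

# One second of silence at 16 kHz, just to check that a forward pass runs.
dummy_audio = [np.zeros(16000, dtype=np.float32)]
inputs = feature_extractor(dummy_audio, sampling_rate=16000, return_tensors="np", padding=True)

outputs = model(inputs["input_values"], attention_mask=inputs.get("attention_mask"))
print(outputs.projected_states.shape)
```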

@@ -0,0 +1,566 @@
#!/usr/bin/env python3
import logging
import sys
import time
from dataclasses import field
from pathlib import Path
from typing import Dict, List, Optional, Union

import numpy as np
from datasets import DatasetDict, load_dataset
from tqdm import tqdm

import flax
import jax
import jax.numpy as jnp
import librosa
import optax
from flax import jax_utils, traverse_util
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from transformers import (
    FlaxWav2Vec2ForPreTraining,
    HfArgumentParser,
    TrainingArguments,
    Wav2Vec2Config,
    Wav2Vec2FeatureExtractor,
    is_tensorboard_available,
)
from transformers.models.wav2vec2.modeling_flax_wav2vec2 import _compute_mask_indices, _sample_negative_indices


logger = logging.getLogger(__name__)


@flax.struct.dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    freeze_feature_extractor: Optional[bool] = field(
        default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
    )
    gradient_checkpointing: Optional[bool] = field(
        default=False, metadata={"help": "Whether to use gradient checkpointing to trade compute for memory."}
    )
    verbose_logging: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to log verbose messages or not."},
    )
    max_gumbel_temperature: Optional[float] = field(
        default=2.0, metadata={"help": "Maximum temperature for gumbel softmax."}
    )
    min_gumbel_temperature: Optional[float] = field(
        default=0.1, metadata={"help": "Minimum temperature for gumbel softmax."}
    )
    gumbel_temperature_decay: Optional[float] = field(
        default=0.999995, metadata={"help": "Decay of gumbel temperature during training."}
    )
    dtype: Optional[str] = field(
        default="float32",
        metadata={"help": "Floating-point format the model weights should be initialized in, e.g. 'float32' or 'bfloat16'."},
    )


@flax.struct.dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.
    """

    dataset_name: str = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_split_name: Optional[str] = field(
        default="train",
        metadata={
            "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
        },
    )
    validation_split_name: Optional[str] = field(
        default="validation",
        metadata={
            "help": "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
        },
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train split used as validation set if the dataset has no validation split."
        },
    )
    speech_file_column: Optional[str] = field(
        default="file",
        metadata={"help": "Column in the dataset that contains speech file path. Defaults to 'file'"},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_duration_in_seconds: Optional[float] = field(
        default=20.0, metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"}
    )
    pad_to_multiple_of: Optional[int] = field(
        default=1024,
        metadata={
            "help": "If set will pad the sequence to a multiple of the provided value. This is important to avoid triggering recompilations on TPU"
        },
    )


@flax.struct.dataclass
class FlaxDataCollatorForWav2Vec2Pretraining:
    """
    Data collator that will dynamically pad the inputs received and prepare masked indices
    for self-supervised pretraining.

    Args:
        model (:class:`~transformers.FlaxWav2Vec2ForPreTraining`):
            The Wav2Vec2 model used for pretraining. The data collator needs to have access
            to config and ``_get_feat_extract_output_lengths`` function for correct padding.
        feature_extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`):
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
    """

    model: FlaxWav2Vec2ForPreTraining
    feature_extractor: Wav2Vec2FeatureExtractor
    padding: Union[bool, str] = "longest"
    pad_to_multiple_of: Optional[int] = None
    max_length: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
        # reformat list to dict and set to numpy format
        batch = self.feature_extractor.pad(
            features,
            max_length=self.max_length,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="np",
        )
        mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])

        # sample randomly masked indices
        batch["mask_time_indices"] = _compute_mask_indices(
            (batch["input_values"].shape[0], mask_indices_seq_length),
            self.model.config.mask_time_prob,
            self.model.config.mask_time_length,
            min_masks=2,
        )

        # sample indices to take for negative vectors
        batch["sampled_negative_indices"] = _sample_negative_indices(
            (batch["mask_time_indices"].shape + (self.model.config.proj_codevector_dim,)),
            self.model.config.num_negatives,
        )

        return batch


def configure_logger(model_args: ModelArguments, training_args: TrainingArguments):
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logging_level = logging.WARNING
    if model_args.verbose_logging:
        logging_level = logging.DEBUG
    logger.setLevel(logging_level)


def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
    summary_writer.scalar("train_time", train_time, step)

    train_metrics = get_metrics(train_metrics)
    for key, vals in train_metrics.items():
        tag = f"train_{key}"
        for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)

    for metric_name, value in eval_metrics.items():
        summary_writer.scalar(f"eval_{metric_name}", value, step)


def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> List[np.ndarray]:
    num_samples = len(samples_idx)
    samples_to_remove = num_samples % batch_size

    # drop the last incomplete batch so that all batches have the same static shape
    if samples_to_remove != 0:
        samples_idx = samples_idx[:-samples_to_remove]
    sections_split = num_samples // batch_size
    batch_idx = np.split(samples_idx, sections_split)
    return batch_idx


def compute_contrastive_loss(
    quantized_features, transformer_features, negative_indices, mask_time_indices, logits_temp, num_negatives
):
    batch_size, sequence_length, hidden_size = quantized_features.shape

    # take negative vectors from sampled indices
    quantized_negatives = quantized_features.reshape(-1, hidden_size)[negative_indices.reshape(-1)]
    quantized_negatives = quantized_negatives.reshape(
        batch_size, sequence_length, num_negatives, hidden_size
    ).transpose(2, 0, 1, 3)

    # the positive (true quantized) vector is prepended, so the correct class index is always 0
    target_features = jnp.concatenate([quantized_features[None, :], quantized_negatives], axis=0)
    loss_logits = optax.cosine_similarity(transformer_features, target_features)
    loss_logits = loss_logits / logits_temp

    neg_is_pos = (quantized_features == quantized_negatives).all(-1)
    neg_is_pos = jnp.concatenate([jnp.full((1,) + loss_logits.shape[1:], False), neg_is_pos], axis=0)

    # make sure incorrectly sampled vectors don't contribute to loss
    loss_logits = jnp.where(neg_is_pos, -1e9, loss_logits)

    # only masked time steps get a valid target (class 0); all others are ignored via -100
    predictions = loss_logits.transpose(2, 1, 0).reshape(-1, loss_logits.shape[0])
    targets = ((1 - mask_time_indices) * -100).transpose(1, 0).flatten()

    target_mask = jnp.where(targets >= 0, 1.0, 0.0)
    contrastive_loss = optax.softmax_cross_entropy(predictions, onehot(targets, predictions.shape[-1])) * target_mask

    contrastive_loss = contrastive_loss.sum()

    return contrastive_loss


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    configure_logger(model_args, training_args)

    # Downloading and loading a dataset from the hub.
    datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)

    if "validation" not in datasets.keys():
        # make sure only "validation" and "train" keys remain
        datasets = DatasetDict()
        datasets["validation"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=f"{data_args.train_split_name}[:{data_args.validation_split_percentage}%]",
            cache_dir=model_args.cache_dir,
        )
        datasets["train"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=f"{data_args.train_split_name}[{data_args.validation_split_percentage}%:]",
            cache_dir=model_args.cache_dir,
        )
    else:
        # make sure only "validation" and "train" keys remain
        datasets = DatasetDict()
        datasets["validation"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split="validation",
            cache_dir=model_args.cache_dir,
        )
        datasets["train"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=f"{data_args.train_split_name}",
            cache_dir=model_args.cache_dir,
        )

    # only normalized-inputs-training is supported
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        model_args.model_name_or_path, cache_dir=model_args.cache_dir, do_normalize=True
    )

    def prepare_dataset(batch):
        # check that all files have the correct sampling rate
        batch["speech"], _ = librosa.load(batch[data_args.speech_file_column], sr=feature_extractor.sampling_rate)
        return batch

    # load audio files into numpy arrays
    vectorized_datasets = datasets.map(
        prepare_dataset, num_proc=data_args.preprocessing_num_workers, remove_columns=datasets["train"].column_names
    )

    # filter audio files that are too long
    vectorized_datasets = vectorized_datasets.filter(
        lambda data: len(data["speech"]) < int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
    )

    def normalize(batch):
        return feature_extractor(batch["speech"], sampling_rate=feature_extractor.sampling_rate)

    # normalize and transform to `BatchFeatures`
    vectorized_datasets = vectorized_datasets.map(
        normalize,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
        remove_columns=vectorized_datasets["train"].column_names,
    )

    # pretraining is only supported for "newer" stable layer norm architecture
    # apply_spec_augment has to be True, mask_feature_prob has to be 0.0
    config = Wav2Vec2Config.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        gradient_checkpointing=model_args.gradient_checkpointing,
    )

    if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
        raise ValueError(
            "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'``"
        )

    model = FlaxWav2Vec2ForPreTraining(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))

    data_collator = FlaxDataCollatorForWav2Vec2Pretraining(
        model=model, feature_extractor=feature_extractor, pad_to_multiple_of=data_args.pad_to_multiple_of
    )

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())
    gumbel_rngs = jax.random.split(rng, jax.local_device_count())

    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()

    num_train_steps = len(vectorized_datasets["train"]) // train_batch_size * num_epochs

    # Create learning rate schedule: linear warmup followed by linear decay to zero
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
    )
    decay_fn = optax.linear_schedule(
        init_value=training_args.learning_rate,
        end_value=0,
        transition_steps=num_train_steps - training_args.warmup_steps,
    )
    linear_decay_lr_schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        flat_mask = {
            path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")])
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )

    # Setup train state and define training hyper-parameters
    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
    num_negatives = model.config.num_negatives
    contrastive_logits_temperature = model.config.contrastive_logits_temperature
    num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
    diversity_loss_weight = model.config.diversity_loss_weight

    # Define gradient update step fn
    def train_step(state, batch, dropout_rng, gumbel_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        gumbel_rng, new_gumbel_rng = jax.random.split(gumbel_rng)

        def loss_fn(params):
            negative_indices = batch.pop("sampled_negative_indices")

            # anneal the gumbel temperature from `max_gumbel_temperature` down to `min_gumbel_temperature`
            gumbel_temperature = jnp.clip(
                model_args.max_gumbel_temperature * model_args.gumbel_temperature_decay ** state.step,
                a_min=model_args.min_gumbel_temperature,
            )

            outputs = state.apply_fn(
                **batch,
                gumbel_temperature=gumbel_temperature,
                params=params,
                dropout_rng=dropout_rng,
                gumbel_rng=gumbel_rng,
                train=True,
            )

            contrastive_loss = compute_contrastive_loss(
                outputs.projected_quantized_states,
                outputs.projected_states,
                negative_indices,
                batch["mask_time_indices"],
                contrastive_logits_temperature,
                num_negatives,
            )

            diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
            loss = contrastive_loss + diversity_loss_weight * diversity_loss

            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)

        metrics = jax.lax.pmean(
            {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
        )

        return new_state, metrics, new_dropout_rng, new_gumbel_rng

    # Create parallel version of the train step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))

    # Define eval fn
    def eval_step(params, batch):
        negative_indices = batch.pop("sampled_negative_indices")

        outputs = model(**batch, params=params, train=False)

        contrastive_loss = compute_contrastive_loss(
            outputs.projected_quantized_states,
            outputs.projected_states,
            negative_indices,
            batch["mask_time_indices"],
            contrastive_logits_temperature,
            num_negatives,
        )

        diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
        loss = contrastive_loss + diversity_loss_weight * diversity_loss

        # summarize metrics
        metrics = {"loss": loss.mean(), "codevector_perplexity": outputs.codevector_perplexity}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return metrics

    # do not donate the parameters here: they are re-used for every evaluation batch
    p_eval_step = jax.pmap(eval_step, "batch")

    # Replicate the train state on each device
    state = jax_utils.replicate(state)

    train_time = 0
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()
        train_metrics = []

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train dataset
        num_train_samples = len(vectorized_datasets["train"])
        train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
        train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)

        # Gather the indexes for creating the batch and do a training step
        for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]
            model_inputs = data_collator(samples)
            model_inputs = shard(model_inputs.data)

            # Model forward
            state, train_metric, dropout_rngs, gumbel_rngs = p_train_step(
                state, model_inputs, dropout_rngs, gumbel_rngs
            )
            train_metrics.append(train_metric)

        train_time += time.time() - train_start

        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
        )

        # ======================== Evaluating ==============================
        num_eval_samples = len(vectorized_datasets["validation"])
        eval_samples_idx = jnp.arange(num_eval_samples)
        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)

        eval_metrics = []
        for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
            samples = [vectorized_datasets["validation"][int(idx)] for idx in batch_idx]
            model_inputs = data_collator(samples)

            # Model forward
            model_inputs = shard(model_inputs.data)
            metrics = p_eval_step(state.params, model_inputs)
            eval_metrics.append(metrics)

        # get eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

        # Update progress bar
        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Perplexity: {eval_metrics['codevector_perplexity']})"
        )

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            cur_step = epoch * (len(vectorized_datasets["train"]) // train_batch_size)
            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(training_args.output_dir, params=params, push_to_hub=training_args.push_to_hub)


if __name__ == "__main__":
    main()

@@ -1643,6 +1643,9 @@ if is_flax_available():
     )
     _import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
     _import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"])
+    _import_structure["models.wav2vec2"].extend(
+        ["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", "FlaxWav2Vec2PreTrainedModel"]
+    )
 else:
     from .utils import dummy_flax_objects

@@ -3023,6 +3026,12 @@ if TYPE_CHECKING:
         )
         from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
         from .models.vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
+        from .models.wav2vec2 import (
+            FlaxWav2Vec2ForCTC,
+            FlaxWav2Vec2ForPreTraining,
+            FlaxWav2Vec2Model,
+            FlaxWav2Vec2PreTrainedModel,
+        )
     else:
         # Import the same objects as dummies to get them in the namespace.
         # They will raise an import error if the user tries to instantiate / use them.

@@ -64,6 +64,7 @@ from ..roberta.modeling_flax_roberta import (
 )
 from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
 from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
+from ..wav2vec2.modeling_flax_wav2vec2 import FlaxWav2Vec2ForPreTraining, FlaxWav2Vec2Model
 from .auto_factory import auto_class_factory
 from .configuration_auto import (
     BartConfig,

@@ -75,6 +76,7 @@ from .configuration_auto import (
     RobertaConfig,
     T5Config,
     ViTConfig,
+    Wav2Vec2Config,
 )

@@ -93,6 +95,7 @@ FLAX_MODEL_MAPPING = OrderedDict(
         (CLIPConfig, FlaxCLIPModel),
         (ViTConfig, FlaxViTModel),
         (T5Config, FlaxT5Model),
+        (Wav2Vec2Config, FlaxWav2Vec2Model),
     ]
 )

@@ -105,6 +108,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
         (BartConfig, FlaxBartForConditionalGeneration),
         (ElectraConfig, FlaxElectraForPreTraining),
         (T5Config, FlaxT5ForConditionalGeneration),
+        (Wav2Vec2Config, FlaxWav2Vec2ForPreTraining),
     ]
 )

@@ -17,7 +17,7 @@
 # limitations under the License.
 from typing import TYPE_CHECKING

-from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available
+from ...file_utils import _BaseLazyModule, is_flax_available, is_tf_available, is_torch_available


 _import_structure = {

@@ -37,7 +37,6 @@ if is_torch_available():
         "Wav2Vec2PreTrainedModel",
     ]

-
 if is_tf_available():
     _import_structure["modeling_tf_wav2vec2"] = [
         "TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",

@@ -46,6 +45,14 @@ if is_tf_available():
         "TFWav2Vec2PreTrainedModel",
     ]

+if is_flax_available():
+    _import_structure["modeling_flax_wav2vec2"] = [
+        "FlaxWav2Vec2ForCTC",
+        "FlaxWav2Vec2ForPreTraining",
+        "FlaxWav2Vec2Model",
+        "FlaxWav2Vec2PreTrainedModel",
+    ]
+
 if TYPE_CHECKING:
     from .configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config

@@ -71,6 +78,14 @@ if TYPE_CHECKING:
             TFWav2Vec2PreTrainedModel,
         )

+    if is_flax_available():
+        from .modeling_flax_wav2vec2 import (
+            FlaxWav2Vec2ForCTC,
+            FlaxWav2Vec2ForPreTraining,
+            FlaxWav2Vec2Model,
+            FlaxWav2Vec2PreTrainedModel,
+        )
+
 else:
     import importlib

(File diff suppressed because it is too large.)
@@ -654,3 +654,31 @@ class FlaxViTPreTrainedModel:
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["flax"])
+
+
+class FlaxWav2Vec2ForCTC:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
+class FlaxWav2Vec2ForPreTraining:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
+class FlaxWav2Vec2Model:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+
+class FlaxWav2Vec2PreTrainedModel:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
@@ -0,0 +1,398 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import math
import unittest

import numpy as np

from transformers import Wav2Vec2Config, is_flax_available
from transformers.testing_utils import require_datasets, require_flax, require_soundfile, slow

from .test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, random_attention_mask


if is_flax_available():
    import jax
    import jax.numpy as jnp
    import optax
    from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Processor
    from transformers.models.wav2vec2.modeling_flax_wav2vec2 import (
        FlaxWav2Vec2ForCTC,
        FlaxWav2Vec2ForPreTraining,
        FlaxWav2Vec2GumbelVectorQuantizer,
        FlaxWav2Vec2Model,
        _compute_mask_indices,
        _sample_negative_indices,
    )


class FlaxWav2Vec2ModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=1024,  # speech is longer
        is_training=False,
        hidden_size=24,
        feat_extract_norm="layer",
        feat_extract_dropout=0.0,
        feat_extract_activation="gelu",
        conv_dim=(32, 32, 32),
        conv_stride=(4, 4, 4),
        conv_kernel=(8, 8, 8),
        conv_bias=False,
        num_conv_pos_embeddings=16,
        num_conv_pos_embedding_groups=2,
        num_hidden_layers=4,
        num_attention_heads=2,
        hidden_dropout_prob=0.1,  # this is most likely not correctly set yet
        intermediate_size=20,
        layer_norm_eps=1e-5,
        hidden_act="gelu",
        initializer_range=0.02,
        vocab_size=32,
        do_stable_layer_norm=True,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.feat_extract_norm = feat_extract_norm
        self.feat_extract_dropout = feat_extract_dropout
        self.feat_extract_activation = feat_extract_activation
        self.conv_dim = conv_dim
        self.conv_stride = conv_stride
        self.conv_kernel = conv_kernel
        self.conv_bias = conv_bias
        self.num_conv_pos_embeddings = num_conv_pos_embeddings
        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_dropout_prob = hidden_dropout_prob
        self.intermediate_size = intermediate_size
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.vocab_size = vocab_size
        self.do_stable_layer_norm = do_stable_layer_norm
        self.scope = scope

        # compute the sequence length after the convolutional feature encoder
        output_seq_length = self.seq_length
        for kernel, stride in zip(self.conv_kernel, self.conv_stride):
            output_seq_length = (output_seq_length - (kernel - 1)) / stride
        self.output_seq_length = int(math.ceil(output_seq_length))
        self.encoder_seq_length = self.output_seq_length

    def prepare_config_and_inputs(self):
        input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
        attention_mask = random_attention_mask([self.batch_size, self.seq_length])

        config = Wav2Vec2Config(
            do_stable_layer_norm=self.do_stable_layer_norm,
            hidden_size=self.hidden_size,
            feat_extract_norm=self.feat_extract_norm,
            feat_extract_dropout=self.feat_extract_dropout,
            feat_extract_activation=self.feat_extract_activation,
            conv_dim=self.conv_dim,
            conv_stride=self.conv_stride,
            conv_kernel=self.conv_kernel,
            conv_bias=self.conv_bias,
            num_conv_pos_embeddings=self.num_conv_pos_embeddings,
            num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            hidden_dropout_prob=self.hidden_dropout_prob,
            intermediate_size=self.intermediate_size,
            layer_norm_eps=self.layer_norm_eps,
            hidden_act=self.hidden_act,
            initializer_range=self.initializer_range,
            vocab_size=self.vocab_size,
        )

        return config, input_values, attention_mask

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_values, attention_mask = config_and_inputs
        inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
        return config, inputs_dict


@require_flax
class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
    all_model_classes = (
        (FlaxWav2Vec2Model, FlaxWav2Vec2ForCTC, FlaxWav2Vec2ForPreTraining) if is_flax_available() else ()
    )

    def setUp(self):
        self.model_tester = FlaxWav2Vec2ModelTester(self)

    def test_train(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        input_values = inputs_dict["input_values"]
        attention_mask = inputs_dict["attention_mask"]

        model = FlaxWav2Vec2ForPreTraining(config)

        features_shape = (
            input_values.shape[0],
            model._get_feat_extract_output_lengths(np.array(input_values.shape[1])),
        )

        batch_size, sequence_length = features_shape[:2]

        mask_prob = 0.5
        mask_length = 4
        mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)

        dropout_rng, gumbel_rng = jax.random.split(jax.random.PRNGKey(0))

        output = model(
            input_values,
            attention_mask=attention_mask,
            mask_time_indices=mask_time_indices,
            train=True,
            dropout_rng=dropout_rng,
            gumbel_rng=gumbel_rng,
        )[0]

        self.assertTrue(output.shape == (batch_size, sequence_length, model.config.proj_codevector_dim))

    # overwrite because of `input_values`
    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.__call__)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["input_values", "attention_mask"]
            self.assertListEqual(arg_names[:2], expected_arg_names)

    # overwrite because of `input_values`
    @slow
    def test_jit_compilation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                @jax.jit
                def model_jitted(input_values, attention_mask=None, **kwargs):
                    return model(input_values=input_values, attention_mask=attention_mask, **kwargs)

                with self.subTest("JIT Enabled"):
                    jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)

    @slow
    def test_model_from_pretrained(self):
        for model_class_name in self.all_model_classes:
            model = model_class_name.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", from_pt=True)
            outputs = model(np.ones((1, 1024), dtype="f4"))
            self.assertIsNotNone(outputs)


@require_flax
class FlaxWav2Vec2UtilsTest(unittest.TestCase):
    def test_compute_mask_indices(self):
        batch_size = 4
        sequence_length = 60
        mask_prob = 0.5
        mask_length = 1

        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)

        self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])

    def test_compute_mask_indices_overlap(self):
        batch_size = 4
        sequence_length = 80
        mask_prob = 0.5
        mask_length = 4

        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)

        # because of overlap, the masks don't have to add up exactly to `mask_prob * sequence_length`, but must be smaller or equal
        for batch_sum in mask.sum(axis=-1):
            self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)

    def test_compute_perplexity(self):
        probs = np.arange(100).reshape(2, 5, 10) / 100

        ppl = FlaxWav2Vec2GumbelVectorQuantizer._compute_perplexity(probs)
        self.assertTrue(abs(ppl.item() - 141.4291) < 1e-3)

        # mask half of the input
        mask = np.ones((2,), dtype=bool)
        mask[0] = 0

        ppl = FlaxWav2Vec2GumbelVectorQuantizer._compute_perplexity(probs, mask)
        self.assertTrue(abs(ppl.item() - 58.6757) < 1e-3)

    def test_sample_negatives(self):
        batch_size = 2
        sequence_length = 10
        hidden_size = 4
        num_negatives = 3

        features = (np.arange(sequence_length * hidden_size) // hidden_size).reshape(
            sequence_length, hidden_size
        )  # each vector consists of the same value
        features = np.broadcast_to(features[None, :], (batch_size, sequence_length, hidden_size))

        negative_indices = _sample_negative_indices(features.shape, num_negatives)

        features = features.reshape(-1, hidden_size)  # BTC => (BxT)C
        # take negative vectors from sampled indices
        sampled_negatives = features[negative_indices.reshape(-1)]
        negatives = sampled_negatives.reshape(batch_size, sequence_length, num_negatives, hidden_size).transpose(
            2, 0, 1, 3
        )

        self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))

        # make sure no negatively sampled vector is actually a positive one
        for negative in negatives:
            self.assertTrue(((negative - features.reshape(negative.shape)) == 0).sum() == 0.0)

        # make sure that full vectors are sampled and not just individual values of vectors
        # => this means that `unique()` yields a single value for `hidden_size` dim
        self.assertEqual(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))


@require_flax
@require_datasets
@require_soundfile
@slow
class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
    def _load_datasamples(self, num_samples):
        from datasets import load_dataset

        import soundfile as sf

        ids = [f"1272-141231-000{i}" for i in range(num_samples)]

        # map files to raw
        def map_to_array(batch):
            speech, _ = sf.read(batch["file"])
            batch["speech"] = speech
            return batch

        ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")

        ds = ds.filter(lambda x: x["id"] in ids).sort("id").map(map_to_array)

        return ds["speech"][:num_samples]

    def test_inference_ctc_robust_batched(self):
        model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", from_pt=True)
        processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)

        input_speech = self._load_datasamples(4)

        inputs = processor(input_speech, return_tensors="np", padding=True, truncation=True)

        input_values = inputs.input_values
        attention_mask = inputs.attention_mask

        logits = model(input_values, attention_mask=attention_mask).logits

        predicted_ids = jnp.argmax(logits, axis=-1)
        predicted_trans = processor.batch_decode(predicted_ids)

        EXPECTED_TRANSCRIPTIONS = [
            "a man said to the universe sir i exist",
            "sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
            "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about",
            "his instant panic was followed by a small sharp blow high on his chest",
        ]
        self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)

    def test_inference_pretrained(self):
        model = FlaxWav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-large-lv60", from_pt=True)
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            "facebook/wav2vec2-large-lv60", return_attention_mask=True
        )
        input_speech = self._load_datasamples(2)

        inputs_dict = feature_extractor(input_speech, return_tensors="np", padding=True)

        features_shape = (
            inputs_dict["input_values"].shape[0],
            model._get_feat_extract_output_lengths(np.array(inputs_dict["input_values"].shape[1])),
        )

        mask_time_indices = _compute_mask_indices(
            features_shape,
            model.config.mask_time_prob,
            model.config.mask_time_length,
            min_masks=2,
        )

        outputs = model(
            inputs_dict.input_values,
            attention_mask=inputs_dict.attention_mask,
            mask_time_indices=mask_time_indices,
        )

        # compute cosine similarity
        cosine_sim = optax.cosine_similarity(
            outputs.projected_states, outputs.projected_quantized_states, epsilon=1e-8
        )

        # retrieve cosine sim of masked features
        cosine_sim_masked = cosine_sim[mask_time_indices]

        # ... now compare to randomly initialized model

        config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-large-lv60")
        model_rand = FlaxWav2Vec2ForPreTraining(config)

        outputs_rand = model_rand(
            inputs_dict.input_values,
            attention_mask=inputs_dict.attention_mask,
            mask_time_indices=mask_time_indices,
        )

        # compute cosine similarity
        cosine_sim_rand = optax.cosine_similarity(
            outputs_rand.projected_states, outputs_rand.projected_quantized_states
        )

        # retrieve cosine sim of masked features
        cosine_sim_masked_rand = cosine_sim_rand[mask_time_indices]

        # a pretrained wav2vec2 model has learned to predict the quantized latent states
        # => the cosine similarity between quantized states and predicted states > 0.5
        # a random wav2vec2 model has not learned to predict the quantized latent states
        # => the cosine similarity between quantized states and predicted states is very likely < 0.1
        self.assertTrue(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item() > 0)

@@ -102,6 +102,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
     "CLIPVisionModel",
     "FlaxCLIPTextModel",
     "FlaxCLIPVisionModel",
+    "FlaxWav2Vec2ForCTC",
     "DetrForSegmentation",
     "DPRReader",
     "FlaubertForQuestionAnswering",