[ASR Examples] Update README for Whisper (#20230)

* [ASR Examples] Update README for seq2seq

* add language info

* add training results

* re-word
This commit is contained in:
Sanchit Gandhi 2022-11-18 11:24:25 +00:00 committed by GitHub
parent 95754b47a6
commit c29a2f7c9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 114 additions and 32 deletions

View File

@ -27,8 +27,8 @@ limitations under the License.
- [Common Voice](#common-voice-ctc)
- [Multilingual Librispeech](#multilingual-librispeech-ctc)
- [Automatic Speech Recognition with Sequence-to-Sequence](#sequence-to-sequence)
- [Single GPU example](#single-gpu-seq2seq)
- [Multi GPU example](#multi-gpu-seq2seq)
- [Whisper Model](#whisper-model)
- [Speech-Encoder-Decoder Model](#warm-started-speech-encoder-decoder-model)
- [Examples](#examples-seq2seq)
- [Librispeech](#librispeech-seq2seq)
@ -246,16 +246,98 @@ they can serve as a baseline to improve upon.
## Sequence to Sequence
The script [`run_speech_recognition_seq2seq.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py) can be used to fine-tune any [Speech Sequence-to-Sequence Model](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForSpeechSeq2Seq) for automatic speech
recognition on one of the [official speech recognition datasets](https://huggingface.co/datasets?task_ids=task_ids:automatic-speech-recognition) or a custom dataset.
recognition on one of the [official speech recognition datasets](https://huggingface.co/datasets?task_ids=task_ids:automatic-speech-recognition) or a custom dataset. This includes the Whisper model from OpenAI or a warm-started Speech-Encoder-Decoder Model, examples for which are included below.
A very common use case is to leverage a pretrained speech [encoding model](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModel),
*e.g.* [Wav2Vec2](https://huggingface.co/transformers/main/model_doc/wav2vec2.html), [HuBERT](https://huggingface.co/transformers/main/model_doc/hubert.html), [XLSR-Wav2Vec2](https://huggingface.co/transformers/main/model_doc/xlsr_wav2vec2.html) with a pretrained [text decoding model](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModel), *e.g.* [Bart](https://huggingface.co/docs/transformers/main/en/model_doc/bart#transformers.BartForCausalLM) to create a [SpeechEnocderDecoderModel](https://huggingface.co/docs/transformers/main/en/model_doc/speechencoderdecoder#speech-encoder-decoder-models).
Consequently, the warm-started Speech-Encoder-Decoder model can be fine-tuned in
this script.
### Whisper Model
We can load all components of the Whisper model directly from the pretrained checkpoint, including the pretrained model weights, feature extractor and tokenizer. We simply have to specify our fine-tuning dataset and training hyperparameters.
As an example, let's instantiate a *Wav2Vec2-2-Bart* model with the `SpeechEnocderDecoderModel` framework:
#### Single GPU Whisper Training
The following example shows how to fine-tune the [Whisper small](https://huggingface.co/openai/whisper-small) checkpoint on the Hindi subset of [Common Voice 11](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0) using a single GPU device in half-precision:
```bash
python run_speech_recognition_seq2seq.py \
--model_name_or_path="openai/whisper-small" \
--dataset_name="mozilla-foundation/common_voice_11_0" \
--dataset_config_name="hi" \
--language="hindi" \
--train_split_name="train+validation" \
--eval_split_name="test" \
--max_steps="5000" \
--output_dir="./whisper-small-hi" \
--per_device_train_batch_size="16" \
--gradient_accumulation_steps="2" \
--per_device_eval_batch_size="16" \
--logging_steps="25" \
--learning_rate="1e-5" \
--warmup_steps="500" \
--evaluation_strategy="steps" \
--eval_steps="1000" \
--save_strategy="steps" \
--save_steps="1000" \
--generation_max_length="225" \
--preprocessing_num_workers="16" \
--length_column_name="input_length" \
--max_duration_in_seconds="30" \
--text_column_name="sentence" \
--freeze_feature_encoder="False" \
--gradient_checkpointing \
--group_by_length \
--fp16 \
--overwrite_output_dir \
--do_train \
--do_eval \
--predict_with_generate \
--use_auth_token
```
On a single V100, training should take approximately 8 hours, with a final cross-entropy loss of **1e-4** and word error rate of **32.6%**.
First create an empty repo on `hf.co`:
If training on a different language, you should be sure to change the `language` argument. The `language` argument should be omitted for English speech recognition.
#### Multi GPU Whisper Training
The following example shows how to fine-tune the [Whisper small](https://huggingface.co/openai/whisper-small) checkpoint on the Hindi subset of [Common Voice 11](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0) using 2 GPU devices in half-precision:
```bash
python -m torch.distributed.launch \
--nproc_per_node 2 run_speech_recognition_seq2seq.py \
--model_name_or_path="openai/whisper-small" \
--dataset_name="mozilla-foundation/common_voice_11_0" \
--dataset_config_name="hi" \
--language="hindi" \
--train_split_name="train+validation" \
--eval_split_name="test" \
--max_steps="5000" \
--output_dir="./whisper-small-hi" \
--per_device_train_batch_size="16" \
--per_device_eval_batch_size="16" \
--logging_steps="25" \
--learning_rate="1e-5" \
--warmup_steps="500" \
--evaluation_strategy="steps" \
--eval_steps="1000" \
--save_strategy="steps" \
--save_steps="1000" \
--generation_max_length="225" \
--preprocessing_num_workers="16" \
--length_column_name="input_length" \
--max_duration_in_seconds="30" \
--text_column_name="sentence" \
--freeze_feature_encoder="False" \
--gradient_checkpointing \
--group_by_length \
--fp16 \
--overwrite_output_dir \
--do_train \
--do_eval \
--predict_with_generate \
--use_auth_token
```
On two V100s, training should take approximately 4 hours, with a final cross-entropy loss of **1e-4** and word error rate of **32.6%**.
### Warm-Started Speech-Encoder-Decoder Model
A very common use case is to leverage a pretrained speech encoder model,
*e.g.* [Wav2Vec2](https://huggingface.co/transformers/main/model_doc/wav2vec2.html), [HuBERT](https://huggingface.co/transformers/main/model_doc/hubert.html) or [XLSR-Wav2Vec2](https://huggingface.co/transformers/main/model_doc/xlsr_wav2vec2.html), with a pretrained text decoder model, *e.g.* [BART](https://huggingface.co/docs/transformers/main/en/model_doc/bart#transformers.BartForCausalLM) or [GPT-2](https://huggingface.co/docs/transformers/main/en/model_doc/gpt2#transformers.GPT2ForCausalLM), to create a [Speech-Encoder-Decoder Model](https://huggingface.co/docs/transformers/main/en/model_doc/speechencoderdecoder#speech-encoder-decoder-models).
By pairing a pretrained speech model with a pretrained text model, the warm-started model has prior knowledge of both the source audio and target text domains. However, the cross-attention weights between the encoder and decoder are randomly initialised. Thus, the model requires fine-tuning to learn the cross-attention weights and align the encoder mapping with that of the decoder. We can perform this very fine-tuning procedure using the example script.
As an example, let's instantiate a *Wav2Vec2-2-Bart* model with the `SpeechEnocderDecoderModel` framework. First create an empty repo on `hf.co`:
```bash
huggingface-cli repo create wav2vec2-2-bart-base
@ -265,7 +347,7 @@ cd wav2vec2-2-bart-base
Next, run the following script **inside** the just cloned repo:
```py
```python
from transformers import SpeechEncoderDecoderModel, AutoFeatureExtractor, AutoTokenizer, Wav2Vec2Processor
# checkpoints to leverage
@ -299,26 +381,26 @@ and link the official `run_speech_recognition_seq2seq.py` script to the folder:
ln -s $(realpath <path/to/transformers>/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py) ./
```
Note that we have added a randomly initialized adapter to `wav2vec2-base` with
`encoder_add_adapter=True` which further samples the output sequence of
`wav2vec2-base` along the time dimension. The reason is that by default a single
output vector of `wav2vec2-base` has a receptive field of *ca.* 25ms (*cf.* with
section *4.2* of the [official Wav2Vec2 paper](https://arxiv.org/pdf/2006.11477.pdf)), which represents a little less a single character. BART on the other hand
makes use of a sentence-piece tokenizer as an input processor so that a single
hidden vector of `bart-base` represents *ca.* 4 characters. To better align
the output of *Wav2Vec2* and *BART*'s hidden vectors for the cross-attention
Note that we have added a randomly initialized _adapter layer_ to `wav2vec2-base` with the argument
`encoder_add_adapter=True`. This adapter sub-samples the output sequence of
`wav2vec2-base` along the time dimension. By default, a single
output vector of `wav2vec2-base` has a receptive field of *ca.* 25ms (*cf.*
Section *4.2* of the [official Wav2Vec2 paper](https://arxiv.org/pdf/2006.11477.pdf)), which represents a little less a single character. On the other hand, BART
makes use of a sentence-piece tokenizer as an input processor, so that a single
hidden vector of `bart-base` represents *ca.* 4 characters. To better align the
receptive field of the *Wav2Vec2* output vectors with *BART*'s hidden-states in the cross-attention
mechanism, we further subsample *Wav2Vec2*'s output by a factor of 8 by
adding a convolution-based adapter.
Having warm-started the speech-encoder-decoder model `<your-user-name>/wav2vec2-2-bart`, we can now fine-tune it on speech recognition.
Having warm-started the speech-encoder-decoder model under `<your-user-name>/wav2vec2-2-bart`, we can now fine-tune it on the task of speech recognition.
In the script [`run_speech_recognition_seq2seq`], we load the warm-started model,
the feature extractor, and the tokenizer, process a speech recognition dataset,
and then make use of the [`Seq2SeqTrainer`](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Seq2SeqTrainer).
Note that it is important to also align the decoder's vocabulary with
the speech transcriptions of the dataset. *E.g.* the [`Librispeech`](https://huggingface.co/datasets/librispeech_asr) has only captilized letters in the transcriptions,
whereas BART was pretrained mostly on normalized text. Thus it is recommended to add
`--do_lower_case` to the fine-tuning script when using a warm-started `SpeechEncoderDecoderModel`. The model is fine-tuned on the standard cross-entropy language modeling
feature extractor, and tokenizer, process a speech recognition dataset,
and subsequently make use of the [`Seq2SeqTrainer`](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Seq2SeqTrainer) to train our system.
Note that it is important to align the target transcriptions with the decoder's vocabulary. For example, the [`Librispeech`](https://huggingface.co/datasets/librispeech_asr) dataset only contains captilized letters in the transcriptions,
whereas BART was pretrained mostly on normalized text. Thus, it is recommended to add the argument
`--do_lower_case` to the fine-tuning script when using a warm-started `SpeechEncoderDecoderModel`.
The model is fine-tuned on the standard cross-entropy language modeling
loss for sequence-to-sequence (just like *T5* or *BART* in natural language processing).
---
@ -331,11 +413,11 @@ you might want to set the environment variable `OMP_NUM_THREADS` to 1 as follows
OMP_NUM_THREADS=1 python run_speech_recognition_ctc ...
```
If the environment variable is not set, the training script might freeze, *i.e.* see: https://github.com/pytorch/audio/issues/1021#issuecomment-726915239
If the environment variable is not set, the training script might freeze, *i.e.* see: https://github.com/pytorch/audio/issues/1021#issuecomment-726915239.
---
### Single GPU Seq2Seq
#### Single GPU Seq2Seq
The following command shows how to fine-tune [XLSR-Wav2Vec2](https://huggingface.co/transformers/main/model_doc/xlsr_wav2vec2.html) on [Common Voice](https://huggingface.co/datasets/common_voice) using a single GPU in half-precision.
@ -376,7 +458,7 @@ python run_speech_recognition_seq2seq.py \
On a single V100 GPU, this script should run in *ca.* 5 hours and yield a
cross-entropy loss of **0.405** and word error rate of **0.0728**.
### Multi GPU Seq2Seq
#### Multi GPU Seq2Seq
The following command shows how to fine-tune [XLSR-Wav2Vec2](https://huggingface.co/transformers/main/model_doc/xlsr_wav2vec2.html) on [Common Voice](https://huggingface.co/datasets/common_voice) using 8 GPUs in half-precision.
@ -421,7 +503,7 @@ On 8 V100 GPUs, this script should run in *ca.* 45 minutes and yield a cross-ent
- [Librispeech](https://huggingface.co/datasets/librispeech_asr)
| Dataset | Dataset Config | Pretrained Model | Word error rate on eval | Phoneme error rate on eval | GPU setup | Training time | Fine-tuned Model & Logs | Command to reproduce |
|-------|------------------------------|-------------|---------------|---------------|----------------------|-------------| -------------| ------- |
| [Librispeech](https://huggingface.co/datasets/librispeech_asr)| `"clean"` - `"train.100"` | [facebook/wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) and [facebook/bart-base](https://huggingface.co/facebook/bart-base) | 0.0728 | - | 8 GPU V100 | 45min | [here](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-base) | [create_model.py](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-base/blob/main/create_model.py) & [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-base/blob/main/run_librispeech.sh) |
| [Librispeech](https://huggingface.co/datasets/librispeech_asr)| `"clean"` - `"train.100"` | [facebook/wav2vec2-large-lv60](https://huggingface.co/facebook/wav2vec2-large-lv60) and [facebook/bart-large](https://huggingface.co/facebook/bart-large) | 0.0486 | - | 8 GPU V100 | 1h20min | [here](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-large) | [create_model.py](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-large/blob/main/create_model.py) & [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-large/blob/main/run_librispeech.sh) |
| Dataset | Dataset Config | Pretrained Model | Word error rate on eval | Phoneme error rate on eval | GPU setup | Training time | Fine-tuned Model & Logs | Command to reproduce |
|----------------------------------------------------------------|---------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------|----------------------------|------------|---------------|-----------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| [Librispeech](https://huggingface.co/datasets/librispeech_asr) | `"clean"` - `"train.100"` | [facebook/wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) and [facebook/bart-base](https://huggingface.co/facebook/bart-base) | 0.0728 | - | 8 GPU V100 | 45min | [here](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-base) | [create_model.py](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-base/blob/main/create_model.py) & [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-base/blob/main/run_librispeech.sh) |
| [Librispeech](https://huggingface.co/datasets/librispeech_asr) | `"clean"` - `"train.100"` | [facebook/wav2vec2-large-lv60](https://huggingface.co/facebook/wav2vec2-large-lv60) and [facebook/bart-large](https://huggingface.co/facebook/bart-large) | 0.0486 | - | 8 GPU V100 | 1h20min | [here](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-large) | [create_model.py](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-large/blob/main/create_model.py) & [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-large/blob/main/run_librispeech.sh) |