[examples] update whisper fine-tuning (#29938)

* [examples] update whisper fine-tuning

* deprecate forced/suppress tokens

* item assignment

* update readme

* final fix
This commit is contained in:
Sanchit Gandhi 2024-04-26 17:06:03 +01:00 committed by GitHub
parent aafa7ce72b
commit 38b53da38a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 39 additions and 17 deletions

View File

@ -368,6 +368,7 @@ python run_speech_recognition_seq2seq.py \
--dataset_name="mozilla-foundation/common_voice_11_0" \
--dataset_config_name="hi" \
--language="hindi" \
--task="transcribe" \
--train_split_name="train+validation" \
--eval_split_name="test" \
--max_steps="5000" \
@ -384,12 +385,10 @@ python run_speech_recognition_seq2seq.py \
--save_steps="1000" \
--generation_max_length="225" \
--preprocessing_num_workers="16" \
--length_column_name="input_length" \
--max_duration_in_seconds="30" \
--text_column_name="sentence" \
--freeze_feature_encoder="False" \
--gradient_checkpointing \
--group_by_length \
--fp16 \
--overwrite_output_dir \
--do_train \
@ -399,7 +398,8 @@ python run_speech_recognition_seq2seq.py \
```
On a single V100, training should take approximately 8 hours, with a final cross-entropy loss of **1e-4** and word error rate of **32.6%**.
If training on a different language, you should be sure to change the `language` argument. The `language` argument should be omitted for English speech recognition.
If training on a different language, you should be sure to change the `language` argument. The `language` and `task`
arguments should be omitted for English speech recognition.
#### Multi GPU Whisper Training
The following example shows how to fine-tune the [Whisper small](https://huggingface.co/openai/whisper-small) checkpoint on the Hindi subset of [Common Voice 11](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0) using 2 GPU devices in half-precision:
@ -410,6 +410,7 @@ torchrun \
--dataset_name="mozilla-foundation/common_voice_11_0" \
--dataset_config_name="hi" \
--language="hindi" \
--task="transcribe" \
--train_split_name="train+validation" \
--eval_split_name="test" \
--max_steps="5000" \
@ -425,12 +426,10 @@ torchrun \
--save_steps="1000" \
--generation_max_length="225" \
--preprocessing_num_workers="16" \
--length_column_name="input_length" \
--max_duration_in_seconds="30" \
--text_column_name="sentence" \
--freeze_feature_encoder="False" \
--gradient_checkpointing \
--group_by_length \
--fp16 \
--overwrite_output_dir \
--do_train \

View File

@ -119,16 +119,15 @@ class ModelArguments:
)
forced_decoder_ids: List[List[int]] = field(
default=None,
metadata={
"help": (
"A list of pairs of integers which indicates a mapping from generation indices to token indices "
"that will be forced before sampling. For example, [[0, 123]] means the first generated token "
"will always be a token of index 123."
)
},
metadata={"help": "Deprecated. Please use the `language` and `task` arguments instead."},
)
suppress_tokens: List[int] = field(
default=None, metadata={"help": "A list of tokens that will be suppressed at generation."}
default=None, metadata={
"help": (
"Deprecated. The use of `suppress_tokens` should not be required for the majority of fine-tuning examples."
"Should you need to use `suppress_tokens`, please manually update them in the fine-tuning script directly."
)
},
)
apply_spec_augment: bool = field(
default=False,
@ -400,8 +399,6 @@ def main():
trust_remote_code=model_args.trust_remote_code,
)
config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})
# SpecAugment for whisper models
if getattr(config, "model_type", None) == "whisper":
config.update({"apply_spec_augment": model_args.apply_spec_augment})
@ -440,9 +437,35 @@ def main():
model.freeze_encoder()
model.model.encoder.gradient_checkpointing = False
if data_args.language is not None:
# We only need to set the task id when the language is specified (i.e. in a multilingual setting)
if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
# We only need to set the language and task ids in a multilingual setting
tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
model.generation_config.update(
**{
"language": data_args.language,
"task": data_args.task,
}
)
elif data_args.language is not None:
raise ValueError(
"Setting language token for an English-only checkpoint is not permitted. The language argument should "
"only be set for multilingual checkpoints."
)
# TODO (Sanchit): deprecate these arguments in v4.41
if model_args.forced_decoder_ids is not None:
logger.warning(
"The use of `forced_decoder_ids` is deprecated and will be removed in v4.41."
"Please use the `language` and `task` arguments instead"
)
model.generation_config.forced_decoder_ids = model_args.forced_decoder_ids
if model_args.suppress_tokens is not None:
logger.warning(
"The use of `suppress_tokens` is deprecated and will be removed in v4.41."
"Should you need `suppress_tokens`, please manually set them in the fine-tuning script."
)
model.generation_config.suppress_tokens = model_args.suppress_tokens
# 6. Resample speech dataset if necessary
dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate