[examples] update whisper fine-tuning (#29938)
* [examples] update whisper fine-tuning * deprecate forced/suppress tokens * item assignment * update readme * final fix
This commit is contained in:
parent
aafa7ce72b
commit
38b53da38a
|
@ -368,6 +368,7 @@ python run_speech_recognition_seq2seq.py \
|
|||
--dataset_name="mozilla-foundation/common_voice_11_0" \
|
||||
--dataset_config_name="hi" \
|
||||
--language="hindi" \
|
||||
--task="transcribe" \
|
||||
--train_split_name="train+validation" \
|
||||
--eval_split_name="test" \
|
||||
--max_steps="5000" \
|
||||
|
@ -384,12 +385,10 @@ python run_speech_recognition_seq2seq.py \
|
|||
--save_steps="1000" \
|
||||
--generation_max_length="225" \
|
||||
--preprocessing_num_workers="16" \
|
||||
--length_column_name="input_length" \
|
||||
--max_duration_in_seconds="30" \
|
||||
--text_column_name="sentence" \
|
||||
--freeze_feature_encoder="False" \
|
||||
--gradient_checkpointing \
|
||||
--group_by_length \
|
||||
--fp16 \
|
||||
--overwrite_output_dir \
|
||||
--do_train \
|
||||
|
@ -399,7 +398,8 @@ python run_speech_recognition_seq2seq.py \
|
|||
```
|
||||
On a single V100, training should take approximately 8 hours, with a final cross-entropy loss of **1e-4** and word error rate of **32.6%**.
|
||||
|
||||
If training on a different language, you should be sure to change the `language` argument. The `language` argument should be omitted for English speech recognition.
|
||||
If training on a different language, you should be sure to change the `language` argument. The `language` and `task`
|
||||
arguments should be omitted for English speech recognition.
|
||||
|
||||
#### Multi GPU Whisper Training
|
||||
The following example shows how to fine-tune the [Whisper small](https://huggingface.co/openai/whisper-small) checkpoint on the Hindi subset of [Common Voice 11](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0) using 2 GPU devices in half-precision:
|
||||
|
@ -410,6 +410,7 @@ torchrun \
|
|||
--dataset_name="mozilla-foundation/common_voice_11_0" \
|
||||
--dataset_config_name="hi" \
|
||||
--language="hindi" \
|
||||
--task="transcribe" \
|
||||
--train_split_name="train+validation" \
|
||||
--eval_split_name="test" \
|
||||
--max_steps="5000" \
|
||||
|
@ -425,12 +426,10 @@ torchrun \
|
|||
--save_steps="1000" \
|
||||
--generation_max_length="225" \
|
||||
--preprocessing_num_workers="16" \
|
||||
--length_column_name="input_length" \
|
||||
--max_duration_in_seconds="30" \
|
||||
--text_column_name="sentence" \
|
||||
--freeze_feature_encoder="False" \
|
||||
--gradient_checkpointing \
|
||||
--group_by_length \
|
||||
--fp16 \
|
||||
--overwrite_output_dir \
|
||||
--do_train \
|
||||
|
|
|
@ -119,16 +119,15 @@ class ModelArguments:
|
|||
)
|
||||
forced_decoder_ids: List[List[int]] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"A list of pairs of integers which indicates a mapping from generation indices to token indices "
|
||||
"that will be forced before sampling. For example, [[0, 123]] means the first generated token "
|
||||
"will always be a token of index 123."
|
||||
)
|
||||
},
|
||||
metadata={"help": "Deprecated. Please use the `language` and `task` arguments instead."},
|
||||
)
|
||||
suppress_tokens: List[int] = field(
|
||||
default=None, metadata={"help": "A list of tokens that will be suppressed at generation."}
|
||||
default=None, metadata={
|
||||
"help": (
|
||||
"Deprecated. The use of `suppress_tokens` should not be required for the majority of fine-tuning examples."
|
||||
"Should you need to use `suppress_tokens`, please manually update them in the fine-tuning script directly."
|
||||
)
|
||||
},
|
||||
)
|
||||
apply_spec_augment: bool = field(
|
||||
default=False,
|
||||
|
@ -400,8 +399,6 @@ def main():
|
|||
trust_remote_code=model_args.trust_remote_code,
|
||||
)
|
||||
|
||||
config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})
|
||||
|
||||
# SpecAugment for whisper models
|
||||
if getattr(config, "model_type", None) == "whisper":
|
||||
config.update({"apply_spec_augment": model_args.apply_spec_augment})
|
||||
|
@ -440,9 +437,35 @@ def main():
|
|||
model.freeze_encoder()
|
||||
model.model.encoder.gradient_checkpointing = False
|
||||
|
||||
if data_args.language is not None:
|
||||
# We only need to set the task id when the language is specified (i.e. in a multilingual setting)
|
||||
if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
|
||||
# We only need to set the language and task ids in a multilingual setting
|
||||
tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
|
||||
model.generation_config.update(
|
||||
**{
|
||||
"language": data_args.language,
|
||||
"task": data_args.task,
|
||||
}
|
||||
)
|
||||
elif data_args.language is not None:
|
||||
raise ValueError(
|
||||
"Setting language token for an English-only checkpoint is not permitted. The language argument should "
|
||||
"only be set for multilingual checkpoints."
|
||||
)
|
||||
|
||||
# TODO (Sanchit): deprecate these arguments in v4.41
|
||||
if model_args.forced_decoder_ids is not None:
|
||||
logger.warning(
|
||||
"The use of `forced_decoder_ids` is deprecated and will be removed in v4.41."
|
||||
"Please use the `language` and `task` arguments instead"
|
||||
)
|
||||
model.generation_config.forced_decoder_ids = model_args.forced_decoder_ids
|
||||
|
||||
if model_args.suppress_tokens is not None:
|
||||
logger.warning(
|
||||
"The use of `suppress_tokens` is deprecated and will be removed in v4.41."
|
||||
"Should you need `suppress_tokens`, please manually set them in the fine-tuning script."
|
||||
)
|
||||
model.generation_config.suppress_tokens = model_args.suppress_tokens
|
||||
|
||||
# 6. Resample speech dataset if necessary
|
||||
dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
|
||||
|
|
Loading…
Reference in New Issue