[Examples] Generalise Seq2Seq ASR to handle Whisper (#19519)

* merge conflicts

* bos and eos in datacollator

* (temp) hardcode removal of attention mask

* freeze encoder

* actually freeze encoder

* set max length / num beams according to gen kwargs

* (temp) fix tests

* don't pop attn mask

* override return attention mask config from Hub

* Hub configs updated 🤗

* final fixes

* update type annotations

* backward compatibility
This commit is contained in:
Sanchit Gandhi 2022-11-14 17:45:46 +00:00 committed by GitHub
parent 7ecb039176
commit af1a7c8ca3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 49 additions and 6 deletions

View File

@ -97,6 +97,22 @@ class ModelArguments:
freeze_feature_encoder: bool = field(
default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
)
freeze_encoder: bool = field(
default=False, metadata={"help": "Whether to freeze the entire encoder of the seq2seq model."}
)
forced_decoder_ids: List[List[int]] = field(
default=None,
metadata={
"help": (
"A list of pairs of integers which indicates a mapping from generation indices to token indices "
"that will be forced before sampling. For example, [[0, 123]] means the first generated token "
"will always be a token of index 123."
)
},
)
suppress_tokens: List[int] = field(
default=None, metadata={"help": "A list of tokens that will be suppressed at generation."}
)
@dataclass
@ -187,6 +203,19 @@ class DataTrainingArguments:
default=True,
metadata={"help": "Whether the target text should be lower cased."},
)
language: str = field(
default=None,
metadata={
"help": (
"Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
"only. For English speech recognition, it should be set to `None`."
)
},
)
task: str = field(
default="transcribe",
metadata={"help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."},
)
@dataclass
@ -194,7 +223,7 @@ class DataCollatorSpeechSeq2SeqWithPadding:
"""
Data collator that will dynamically pad the inputs received.
Args:
processor ([`Wav2Vec2Processor`])
processor ([`WhisperProcessor`])
The processor used for processing the data.
decoder_start_token_id (`int`)
The begin-of-sentence of the decoder.
@ -206,7 +235,8 @@ class DataCollatorSpeechSeq2SeqWithPadding:
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# split inputs and labels since they have to be of different lengths and need
# different padding methods
input_features = [{"input_values": feature["input_values"]} for feature in features]
model_input_name = self.processor.model_input_names[0]
input_features = [{model_input_name: feature[model_input_name]} for feature in features]
label_features = [{"input_ids": feature["labels"]} for feature in features]
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
@ -333,6 +363,8 @@ def main():
use_auth_token=True if model_args.use_auth_token else None,
)
config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})
feature_extractor = AutoFeatureExtractor.from_pretrained(
model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
@ -360,6 +392,14 @@ def main():
if model_args.freeze_feature_encoder:
model.freeze_feature_encoder()
if model_args.freeze_encoder:
model.freeze_encoder()
model.model.encoder.gradient_checkpointing = False
if data_args.language is not None:
# We only need to set the task id when the language is specified (i.e. in a multilingual setting)
tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
# 6. Resample speech dataset if necessary
dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
if dataset_sampling_rate != feature_extractor.sampling_rate:
@ -388,8 +428,8 @@ def main():
sample = batch[audio_column_name]
inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
# process audio length
batch[model_input_name] = inputs.input_values[0]
batch["input_length"] = len(batch["input_values"])
batch[model_input_name] = inputs.get(model_input_name)[0]
batch["input_length"] = len(sample["array"])
# process targets
input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
@ -452,7 +492,8 @@ def main():
# 10. Define data collator
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
processor=processor, decoder_start_token_id=model.config.decoder_start_token_id
processor=processor,
decoder_start_token_id=model.config.decoder_start_token_id,
)
# 11. Initialize Trainer
@ -492,7 +533,9 @@ def main():
if training_args.do_eval:
logger.info("*** Evaluate ***")
metrics = trainer.evaluate(
metric_key_prefix="eval", max_length=model.config.max_length, num_beams=model.config.num_beams
metric_key_prefix="eval",
max_length=training_args.generation_max_length,
num_beams=training_args.generation_num_beams,
)
max_eval_samples = (
data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])