[Examples] Generalise Seq2Seq ASR to handle Whisper (#19519)
* merge conflicts
* bos and eos in datacollator
* (temp) hardcode removal of attention mask
* freeze encoder
* actually freeze encoder
* set max length / num beams according to gen kwargs
* (temp) fix tests
* don't pop attn mask
* override return attention mask config from Hub
* Hub configs updated 🤗
* final fixes
* update type annotations
* backward compatibility
commit af1a7c8ca3 (parent 7ecb039176)
@@ -97,6 +97,22 @@ class ModelArguments:
     freeze_feature_encoder: bool = field(
         default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
     )
+    freeze_encoder: bool = field(
+        default=False, metadata={"help": "Whether to freeze the entire encoder of the seq2seq model."}
+    )
+    forced_decoder_ids: List[List[int]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "A list of pairs of integers which indicates a mapping from generation indices to token indices "
+                "that will be forced before sampling. For example, [[0, 123]] means the first generated token "
+                "will always be a token of index 123."
+            )
+        },
+    )
+    suppress_tokens: List[int] = field(
+        default=None, metadata={"help": "A list of tokens that will be suppressed at generation."}
+    )
 
 
 @dataclass
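
For context on what the two new generation arguments do at inference time, here is a minimal sketch. It assumes a loaded Whisper model and extracted input_features; all token ids below are illustrative, not taken from this script:

# Positions 1-3 are forced to fixed ids before sampling (e.g. the language,
# task and no-timestamps tokens for Whisper); suppress_tokens sets the listed
# ids' logits to -inf at every decoding step. Ids are illustrative only.
forced_decoder_ids = [[1, 50259], [2, 50359], [3, 50363]]
suppress_tokens = [220, 50257]

generated_ids = model.generate(
    input_features,
    forced_decoder_ids=forced_decoder_ids,
    suppress_tokens=suppress_tokens,
)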
@@ -187,6 +203,19 @@ class DataTrainingArguments:
         default=True,
         metadata={"help": "Whether the target text should be lower cased."},
     )
+    language: str = field(
+        default=None,
+        metadata={
+            "help": (
+                "Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
+                "only. For English speech recognition, it should be set to `None`."
+            )
+        },
+    )
+    task: str = field(
+        default="transcribe",
+        metadata={"help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."},
+    )
 
 
 @dataclass
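
A hedged sketch of how the two new data arguments are meant to be combined (field values only; the other dataclass fields are assumed to keep their defaults):

# English-only speech recognition: leave the language unset.
DataTrainingArguments(language=None, task="transcribe")
# Multilingual speech recognition, e.g. French audio -> French text.
DataTrainingArguments(language="french", task="transcribe")
# Speech translation, e.g. French audio -> English text.
DataTrainingArguments(language="french", task="translate")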
@@ -194,7 +223,7 @@ class DataCollatorSpeechSeq2SeqWithPadding:
     """
     Data collator that will dynamically pad the inputs received.
     Args:
-        processor ([`Wav2Vec2Processor`])
+        processor ([`WhisperProcessor`])
             The processor used for processing the data.
         decoder_start_token_id (`int`)
             The begin-of-sentence of the decoder.
@@ -206,7 +235,8 @@ class DataCollatorSpeechSeq2SeqWithPadding:
     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
         # split inputs and labels since they have to be of different lengths and need
         # different padding methods
-        input_features = [{"input_values": feature["input_values"]} for feature in features]
+        model_input_name = self.processor.model_input_names[0]
+        input_features = [{model_input_name: feature[model_input_name]} for feature in features]
         label_features = [{"input_ids": feature["labels"]} for feature in features]
 
         batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
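
The model_input_names[0] lookup is what makes the collator model-agnostic. A quick sketch of what it resolves to for the two model families the script now covers:

from transformers import AutoProcessor

# Raw-waveform models (Wav2Vec2-style) call their primary input "input_values";
# log-Mel spectrogram models (Whisper) call it "input_features".
AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h").model_input_names[0]  # "input_values"
AutoProcessor.from_pretrained("openai/whisper-small").model_input_names[0]  # "input_features"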
@@ -333,6 +363,8 @@ def main():
         use_auth_token=True if model_args.use_auth_token else None,
     )
 
+    config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})
+
     feature_extractor = AutoFeatureExtractor.from_pretrained(
         model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
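
PretrainedConfig.update simply setattrs each key, so this line is equivalent to the sketch below. In particular, when neither flag is passed (both default to None) it actively clears any forced_decoder_ids/suppress_tokens stored in the checkpoint's config on the Hub rather than keeping them:

config.forced_decoder_ids = model_args.forced_decoder_ids  # None unless --forced_decoder_ids is set
config.suppress_tokens = model_args.suppress_tokens  # None unless --suppress_tokens is set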
@@ -360,6 +392,14 @@ def main():
     if model_args.freeze_feature_encoder:
         model.freeze_feature_encoder()
 
+    if model_args.freeze_encoder:
+        model.freeze_encoder()
+        model.model.encoder.gradient_checkpointing = False
+
+    if data_args.language is not None:
+        # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
+        tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
+
     # 6. Resample speech dataset if necessary
     dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
     if dataset_sampling_rate != feature_extractor.sampling_rate:
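
Two notes on this hunk. freeze_encoder stops gradient flow through the whole encoder (checkpointing is switched off on it since the frozen encoder takes no part in the backward pass), and set_prefix_tokens controls the special tokens the tokenizer prepends when building labels. An illustrative sketch of the latter:

# Exact special tokens come from the Whisper tokenizer; shown for intuition only.
tokenizer.set_prefix_tokens(language="french", task="transcribe")
# Labels built by tokenizer("bonjour") now decode to roughly:
#   <|startoftranscript|><|fr|><|transcribe|><|notimestamps|>bonjour<|endoftext|>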
@@ -388,8 +428,8 @@ def main():
         sample = batch[audio_column_name]
         inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
         # process audio length
-        batch[model_input_name] = inputs.input_values[0]
-        batch["input_length"] = len(batch["input_values"])
+        batch[model_input_name] = inputs.get(model_input_name)[0]
+        batch["input_length"] = len(sample["array"])
 
         # process targets
         input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
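
The switch to len(sample["array"]) matters for Whisper: its feature extractor pads or truncates every clip to a fixed 30-second log-Mel spectrogram, so the feature length no longer reflects the true audio duration. A sketch of the assumed shapes:

inputs = feature_extractor(sample["array"], sampling_rate=16_000)
# Whisper: inputs.input_features[0].shape == (80, 3000) for any clip up to 30 s,
# whereas len(sample["array"]) scales with duration, which is what the
# min/max-duration filtering further down the script needs.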
@@ -452,7 +492,8 @@ def main():
 
     # 10. Define data collator
     data_collator = DataCollatorSpeechSeq2SeqWithPadding(
-        processor=processor, decoder_start_token_id=model.config.decoder_start_token_id
+        processor=processor,
+        decoder_start_token_id=model.config.decoder_start_token_id,
     )
 
     # 11. Initialize Trainer
@@ -492,7 +533,9 @@ def main():
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
         metrics = trainer.evaluate(
-            metric_key_prefix="eval", max_length=model.config.max_length, num_beams=model.config.num_beams
+            metric_key_prefix="eval",
+            max_length=training_args.generation_max_length,
+            num_beams=training_args.generation_num_beams,
         )
         max_eval_samples = (
             data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
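
The evaluation generation settings now come from Seq2SeqTrainingArguments instead of the model config. A minimal sketch of the corresponding arguments (values illustrative):

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-finetuned",  # hypothetical path
    predict_with_generate=True,
    generation_max_length=225,
    generation_num_beams=1,
)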