Fix Whisper Conversion Script: Correct decoder_attention_heads and _download function (#26834)
* Fix error in convert_openai_to_hf.py: "_download() missing 1 required positional argument: root" * Fix error in convert_openai_to_hf.py: "TypeError: byte indices must be integers or slices, not str" * Fix decoder_attention_heads value in convert_openai_to_hf.py. Correct the assignment for `decoder_attention_heads` in the conversion script for the Whisper model. * Black reformat convert_openai_to_hf.py file. * Fix Whisper model configuration defaults (for Tiny). - Correct encoder/decoder layers and attention heads count. - Update model width (`d_model`) to 384. * Add docstring to the convert_openai_to_hf.py script with a doctest * Add shebang and +x permission to the convert_openai_to_hf.py * convert_openai_to_hf.py: reuse the read model_bytes in the _download() function * Move convert_openai_to_hf.py doctest example to whisper.md * whisper.md: Add an inference example to the Conversion section. * whisper.md: remove `model.config.forced_decoder_ids` from examples (deprecated) * whisper.md: Remove "## Format Conversion" section; not used by users * whisper.md: Use librispeech_asr_dummy dataset and load_dataset()
This commit is contained in:
parent
90b4adc1f1
commit
606d90845f
|
@ -34,6 +34,42 @@ The original code can be found [here](https://github.com/openai/whisper).
|
|||
- Inference is currently only implemented for short-form i.e. audio is pre-segmented into <=30s segments. Long-form (including timestamps) will be implemented in a future release.
|
||||
- One can use [`WhisperProcessor`] to prepare audio for the model, and decode the predicted ID's back into text.
|
||||
|
||||
This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ). The Tensorflow version of this model was contributed by [amyeroberts](https://huggingface.co/amyeroberts).
|
||||
The original code can be found [here](https://github.com/openai/whisper).
|
||||
|
||||
## Inference
|
||||
|
||||
Here is a step-by-step guide to transcribing an audio sample using a pre-trained Whisper model:
|
||||
|
||||
```python
|
||||
>>> from datasets import load_dataset
|
||||
>>> from transformers import WhisperProcessor, WhisperForConditionalGeneration
|
||||
|
||||
>>> # Select an audio file and read it:
|
||||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
>>> audio_sample = ds[0]["audio"]
|
||||
>>> waveform = audio_sample["array"]
|
||||
>>> sampling_rate = audio_sample["sampling_rate"]
|
||||
|
||||
>>> # Load the Whisper model in Hugging Face format:
|
||||
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
||||
|
||||
>>> # Use the model and processor to transcribe the audio:
|
||||
>>> input_features = processor(
|
||||
... waveform, sampling_rate=sampling_rate, return_tensors="pt"
|
||||
... ).input_features
|
||||
|
||||
>>> # Generate token ids
|
||||
>>> predicted_ids = model.generate(input_features)
|
||||
|
||||
>>> # Decode token ids to text
|
||||
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
|
||||
|
||||
>>> transcription[0]
|
||||
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
|
||||
```
|
||||
|
||||
## WhisperConfig
|
||||
|
||||
[[autodoc]] WhisperConfig
|
||||
|
|
|
@ -77,13 +77,13 @@ class WhisperConfig(PretrainedConfig):
|
|||
num_mel_bins (`int`, *optional*, defaults to 80):
|
||||
Number of mel features used per input features. Should correspond to the value used in the
|
||||
`WhisperProcessor` class.
|
||||
encoder_layers (`int`, *optional*, defaults to 6):
|
||||
encoder_layers (`int`, *optional*, defaults to 4):
|
||||
Number of encoder layers.
|
||||
decoder_layers (`int`, *optional*, defaults to 6):
|
||||
decoder_layers (`int`, *optional*, defaults to 4):
|
||||
Number of decoder layers.
|
||||
encoder_attention_heads (`int`, *optional*, defaults to 4):
|
||||
encoder_attention_heads (`int`, *optional*, defaults to 6):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
decoder_attention_heads (`int`, *optional*, defaults to 4):
|
||||
decoder_attention_heads (`int`, *optional*, defaults to 6):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
encoder_ffn_dim (`int`, *optional*, defaults to 1536):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
|
||||
|
@ -106,7 +106,7 @@ class WhisperConfig(PretrainedConfig):
|
|||
activation_function (`str`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
||||
d_model (`int`, *optional*, defaults to 256):
|
||||
d_model (`int`, *optional*, defaults to 384):
|
||||
Dimensionality of the layers.
|
||||
dropout (`float`, *optional*, defaults to 0.1):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
|
@ -197,10 +197,10 @@ class WhisperConfig(PretrainedConfig):
|
|||
self,
|
||||
vocab_size=51865,
|
||||
num_mel_bins=80,
|
||||
encoder_layers=6,
|
||||
encoder_attention_heads=4,
|
||||
decoder_layers=6,
|
||||
decoder_attention_heads=4,
|
||||
encoder_layers=4,
|
||||
encoder_attention_heads=6,
|
||||
decoder_layers=4,
|
||||
decoder_attention_heads=6,
|
||||
decoder_ffn_dim=1536,
|
||||
encoder_ffn_dim=1536,
|
||||
encoder_layerdrop=0.0,
|
||||
|
@ -209,7 +209,7 @@ class WhisperConfig(PretrainedConfig):
|
|||
use_cache=True,
|
||||
is_encoder_decoder=True,
|
||||
activation_function="gelu",
|
||||
d_model=256,
|
||||
d_model=384,
|
||||
dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
activation_dropout=0.0,
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
#!/usr/bin/env python
|
||||
"""Converts a Whisper model in OpenAI format to Hugging Face format."""
|
||||
# Copyright 2022 The HuggingFace Inc. team and the OpenAI team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -14,6 +16,7 @@
|
|||
|
||||
import argparse
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import urllib
|
||||
import warnings
|
||||
|
@ -90,7 +93,7 @@ def make_linear_from_emb(emb):
|
|||
return lin_layer
|
||||
|
||||
|
||||
def _download(url: str, root: str) -> bytes:
|
||||
def _download(url: str, root: str) -> io.BytesIO:
|
||||
os.makedirs(root, exist_ok=True)
|
||||
filename = os.path.basename(url)
|
||||
|
||||
|
@ -103,7 +106,7 @@ def _download(url: str, root: str) -> bytes:
|
|||
if os.path.isfile(download_target):
|
||||
model_bytes = open(download_target, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
||||
return model_bytes
|
||||
return torch.load(io.BytesIO(model_bytes))
|
||||
else:
|
||||
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||
|
||||
|
@ -125,12 +128,13 @@ def _download(url: str, root: str) -> bytes:
|
|||
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
||||
)
|
||||
|
||||
return model_bytes
|
||||
return torch.load(io.BytesIO(model_bytes))
|
||||
|
||||
|
||||
def convert_openai_whisper_to_tfms(checkpoint_path, pytorch_dump_folder_path):
|
||||
if ".pt" not in checkpoint_path:
|
||||
original_checkpoint = _download(_MODELS[checkpoint_path])
|
||||
root = os.path.dirname(pytorch_dump_folder_path) or "."
|
||||
original_checkpoint = _download(_MODELS[checkpoint_path], root)
|
||||
else:
|
||||
original_checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
||||
dimensions = original_checkpoint["dims"]
|
||||
|
@ -151,7 +155,7 @@ def convert_openai_whisper_to_tfms(checkpoint_path, pytorch_dump_folder_path):
|
|||
encoder_layers=dimensions["n_audio_layer"],
|
||||
encoder_attention_heads=dimensions["n_audio_head"],
|
||||
decoder_layers=dimensions["n_text_layer"],
|
||||
decoder_attention_heads=dimensions["n_text_state"],
|
||||
decoder_attention_heads=dimensions["n_text_head"],
|
||||
max_source_positions=dimensions["n_audio_ctx"],
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue