Fix whisper doc (#19608)

* update feature extractor params

* update attention mask handling

* fix doc and pipeline test

* add warning when skipping test

* add whisper translation and transcription test

* fix build doc test

* Correct whisper processor

* make fix copies

* remove sample docstring as it does not fit whisper model

* Update src/transformers/models/whisper/modeling_whisper.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix, doctests are passing

* Nit

* last nit

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Arthur 2022-10-14 18:12:32 +02:00 committed by GitHub
parent 66dd80213c
commit 614f7d28a8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 20 additions and 17 deletions

View File

@ -32,13 +32,7 @@ from ...modeling_outputs import (
Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_whisper import WhisperConfig
@ -46,8 +40,6 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "WhisperConfig"
_CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
_PROCESSOR_FOR_DOC = "WhisperProcessor"
_EXPECTED_OUTPUT_SHAPE = [1, 2, 512]
WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [
@ -1005,14 +997,7 @@ class WhisperModel(WhisperPreTrainedModel):
self.encoder._freeze_parameters()
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_PROCESSOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Seq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_EXPECTED_OUTPUT_SHAPE,
modality="audio",
)
@replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_features=None,
@ -1029,7 +1014,25 @@ class WhisperModel(WhisperPreTrainedModel):
output_hidden_states=None,
return_dict=None,
):
r"""
Returns:
Example:
```python
>>> import torch
>>> from transformers import WhisperFeatureExtractor, WhisperModel
>>> from datasets import load_dataset
>>> model = WhisperModel.from_pretrained("openai/whisper-base")
>>> feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
>>> input_features = inputs.input_features
>>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
>>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
>>> list(last_hidden_state.shape)
[1, 2, 512]
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states