Fix whisper doc (#19608)
* update feature extractor params
* update attention mask handling
* fix doc and pipeline test
* add warning when skipping test
* add whisper translation and transcription test
* fix build doc test
* Correct whisper processor
* make fix copies
* remove sample docstring as it does not fit whisper model
* Update src/transformers/models/whisper/modeling_whisper.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix, doctests are passing
* Nit
* last nit

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 66dd80213c
commit 614f7d28a8
src/transformers/models/whisper/modeling_whisper.py

@@ -32,13 +32,7 @@ from ...modeling_outputs import (
     Seq2SeqModelOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...utils import (
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    logging,
-    replace_return_docstrings,
-)
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_whisper import WhisperConfig
 
 
@@ -46,8 +40,6 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "WhisperConfig"
 _CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
-_PROCESSOR_FOR_DOC = "WhisperProcessor"
-_EXPECTED_OUTPUT_SHAPE = [1, 2, 512]
 
 
 WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [
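Worth noting: the constants removed here only fed the auto-generated code sample, and they were mutually inconsistent anyway. `_CHECKPOINT_FOR_DOC` pointed at whisper-tiny, whose hidden size is 384, while `_EXPECTED_OUTPUT_SHAPE` ended in 512, the hidden size of whisper-base. A quick check of that mismatch (not part of the commit; needs Hub access):

```python
from transformers import WhisperConfig

# Compare hidden sizes of the two checkpoints involved in the old doc constants.
tiny = WhisperConfig.from_pretrained("openai/whisper-tiny")
base = WhisperConfig.from_pretrained("openai/whisper-base")
print(tiny.d_model, base.d_model)  # 384 512 -> [1, 2, 512] matches base, not tiny
```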
@@ -1005,14 +997,7 @@ class WhisperModel(WhisperPreTrainedModel):
         self.encoder._freeze_parameters()
 
     @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
-    @add_code_sample_docstrings(
-        processor_class=_PROCESSOR_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
-        output_type=Seq2SeqModelOutput,
-        config_class=_CONFIG_FOR_DOC,
-        expected_output=_EXPECTED_OUTPUT_SHAPE,
-        modality="audio",
-    )
+    @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_features=None,
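This swap is the core of the fix: `add_code_sample_docstrings` appends a templated usage example (built from the processor/checkpoint/expected-shape constants) to the method's docstring, and that generic template presumably does not fit an encoder-decoder speech model whose bare forward pass needs `decoder_input_ids`. `replace_return_docstrings` is much narrower: it only expands the bare `Returns:` placeholder from the output type, leaving the example to be written by hand. Roughly in this style (an illustrative sketch, not the transformers implementation):

```python
from typing import Callable, Type


def replace_return_docstrings_sketch(output_type: Type, config_class: str) -> Callable:
    """Illustrative stand-in for transformers' replace_return_docstrings decorator."""

    def decorator(fn: Callable) -> Callable:
        # Expand the bare "Returns:" placeholder into a description of the output
        # class; the real helper derives this text from output_type's docstring.
        returns_doc = (
            "Returns:\n"
            f"    [`{output_type.__name__}`] or `tuple(torch.FloatTensor)` "
            f"(controlled by `return_dict` and [`{config_class}`])."
        )
        fn.__doc__ = (fn.__doc__ or "").replace("Returns:", returns_doc, 1)
        return fn

    return decorator
```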
@@ -1029,7 +1014,25 @@ class WhisperModel(WhisperPreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
     ):
-
+        r"""
+        Returns:
+
+        Example:
+        ```python
+        >>> import torch
+        >>> from transformers import WhisperFeatureExtractor, WhisperModel
+        >>> from datasets import load_dataset
+
+        >>> model = WhisperModel.from_pretrained("openai/whisper-base")
+        >>> feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")
+        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+        >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
+        >>> input_features = inputs.input_features
+        >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
+        >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
+        >>> list(last_hidden_state.shape)
+        [1, 2, 512]
+        ```"""
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
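For reference, the new docstring example also works as a standalone script. Below is a slight variant (not from the commit) that goes through `WhisperProcessor`, which the commit message also touches ("Correct whisper processor") and which wraps the feature extractor and tokenizer; it assumes Hub access for the checkpoint and the dummy dataset:

```python
import torch
from datasets import load_dataset

from transformers import WhisperModel, WhisperProcessor

# whisper-base has d_model=512, which is where the [1, 2, 512] shape comes from.
model = WhisperModel.from_pretrained("openai/whisper-base")
processor = WhisperProcessor.from_pretrained("openai/whisper-base")

# The dummy LibriSpeech split is 16 kHz audio, Whisper's expected sampling rate.
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
inputs = processor(ds[0]["audio"]["array"], sampling_rate=16000, return_tensors="pt")

# Whisper is an encoder-decoder: the bare model produces decoder hidden states
# only when given decoder_input_ids, hence the hand-written example above.
decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
outputs = model(inputs.input_features, decoder_input_ids=decoder_input_ids)
print(list(outputs.last_hidden_state.shape))  # [1, 2, 512]
```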