diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index b4e99b6112..08204c51d8 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -32,13 +32,7 @@ from ...modeling_outputs import (
     Seq2SeqModelOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...utils import (
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    logging,
-    replace_return_docstrings,
-)
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_whisper import WhisperConfig
 
 
@@ -46,8 +40,6 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "WhisperConfig"
 _CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
-_PROCESSOR_FOR_DOC = "WhisperProcessor"
-_EXPECTED_OUTPUT_SHAPE = [1, 2, 512]
 
 
 WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [
@@ -1005,14 +997,7 @@ class WhisperModel(WhisperPreTrainedModel):
         self.encoder._freeze_parameters()
 
     @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
-    @add_code_sample_docstrings(
-        processor_class=_PROCESSOR_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
-        output_type=Seq2SeqModelOutput,
-        config_class=_CONFIG_FOR_DOC,
-        expected_output=_EXPECTED_OUTPUT_SHAPE,
-        modality="audio",
-    )
+    @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_features=None,
@@ -1029,7 +1014,25 @@ class WhisperModel(WhisperPreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
     ):
+        r"""
+        Returns:
+        Example:
+        ```python
+        >>> import torch
+        >>> from transformers import WhisperFeatureExtractor, WhisperModel
+        >>> from datasets import load_dataset
+
+        >>> model = WhisperModel.from_pretrained("openai/whisper-base")
+        >>> feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")
+        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+        >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
+        >>> input_features = inputs.input_features
+        >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
+        >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
+        >>> list(last_hidden_state.shape)
+        [1, 2, 512]
+        ```"""
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
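
For reference, the doctest introduced by this patch can also be run as a plain script, which is a quick way to verify it before it lands in the docstring. The sketch below mirrors the example added to `WhisperModel.forward`; the `openai/whisper-base` checkpoint and the `hf-internal-testing/librispeech_asr_dummy` dataset are taken directly from the diff, and the expected shape `[1, 2, 512]` assumes that checkpoint's hidden size of 512.

```python
# Standalone version of the doctest added to WhisperModel.forward in this patch.
import torch
from datasets import load_dataset
from transformers import WhisperFeatureExtractor, WhisperModel

model = WhisperModel.from_pretrained("openai/whisper-base")
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")

# Dummy LibriSpeech split used throughout the transformers test suite.
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")

# Two decoder start tokens -> decoder sequence length of 2 in the output.
decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id

last_hidden_state = model(inputs.input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
print(list(last_hidden_state.shape))  # expected: [1, 2, 512] for whisper-base (d_model=512)
```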