[Whisper] Block language/task args for English-only (#27322)

* [Whisper] Block language/task args for English-only

* Update src/transformers/models/whisper/modeling_whisper.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Sanchit Gandhi 2023-11-07 10:04:23 +00:00 committed by GitHub
parent 9beb2737d7
commit da7ea9a4e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 54 additions and 0 deletions

View File

@ -1841,6 +1841,22 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
else:
generation_config.return_timestamps = False
if is_multilingual is not None:
if not hasattr(generation_config, "is_multilingual"):
raise ValueError(
"The generation config is outdated and is thus not compatible with the `is_multilingual` argument "
"to `generate`. Please update the generation config as per the instructions "
"https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
)
generation_config.is_multilingual = is_multilingual
if hasattr(generation_config, "is_multilingual") and not generation_config.is_multilingual:
if task is not None or language is not None:
raise ValueError(
"Cannot specify `task` or `language` for an English-only model. If the model is intended to be "
"multilingual, pass `is_multilingual=True` to generate, or update the generation config."
)
if language is not None:
if not hasattr(generation_config, "lang_to_id"):
raise ValueError(

View File

@ -852,6 +852,44 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
output_3 = speech_translator(filename)
self.assertEqual(output_3, {"text": " Un uomo ha detto all'universo, Sir, esiste."})
@slow
@require_torch
def test_whisper_language(self):
    """Check how Whisper ASR pipelines handle the `language` generate argument.

    Three scenarios are exercised on the same audio sample:
      1. An English-only checkpoint transcribes correctly without `language`.
      2. An English-only checkpoint raises a `ValueError` when `language` is given.
      3. A multilingual checkpoint accepts `language` and transcribes correctly.
    """
    speech_recognizer = pipeline(
        task="automatic-speech-recognition",
        model="openai/whisper-tiny.en",
        framework="pt",
    )
    ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    filename = ds[0]["file"]

    # 1. English-only model compatible with no language argument
    output = speech_recognizer(filename)
    self.assertEqual(
        output,
        {"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."},
    )

    # 2. English-only Whisper does not accept the language argument.
    # The pattern must mirror the message raised by `generate` exactly (note the
    # spelling of `language`), otherwise assertRaisesRegex fails on a mismatch.
    with self.assertRaisesRegex(
        ValueError,
        "Cannot specify `task` or `language` for an English-only model. If the model is intended to be multilingual, "
        "pass `is_multilingual=True` to generate, or update the generation config.",
    ):
        _ = speech_recognizer(filename, generate_kwargs={"language": "en"})

    # 3. Multilingual model accepts language argument
    speech_recognizer = pipeline(
        task="automatic-speech-recognition",
        model="openai/whisper-tiny",
        framework="pt",
    )
    output = speech_recognizer(filename, generate_kwargs={"language": "en"})
    self.assertEqual(
        output,
        {"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."},
    )
@slow
@require_torch
@require_torchaudio