[Whisper] Block language/task args for English-only (#27322)
* [Whisper] Block language/task args for English-only * Update src/transformers/models/whisper/modeling_whisper.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
9beb2737d7
commit
da7ea9a4e3
|
@ -1841,6 +1841,22 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||
else:
|
||||
generation_config.return_timestamps = False
|
||||
|
||||
if is_multilingual is not None:
|
||||
if not hasattr(generation_config, "is_multilingual"):
|
||||
raise ValueError(
|
||||
"The generation config is outdated and is thus not compatible with the `is_multilingual` argument "
|
||||
"to `generate`. Please update the generation config as per the instructions "
|
||||
"https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
|
||||
)
|
||||
generation_config.is_multilingual = is_multilingual
|
||||
|
||||
if hasattr(generation_config, "is_multilingual") and not generation_config.is_multilingual:
|
||||
if task is not None or language is not None:
|
||||
raise ValueError(
|
||||
"Cannot specify `task` or `language` for an English-only model. If the model is intended to be "
|
||||
"multilingual, pass `is_multilingual=True` to generate, or update the generation config."
|
||||
)
|
||||
|
||||
if language is not None:
|
||||
if not hasattr(generation_config, "lang_to_id"):
|
||||
raise ValueError(
|
||||
|
|
|
@ -852,6 +852,44 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||
output_3 = speech_translator(filename)
|
||||
self.assertEqual(output_3, {"text": " Un uomo ha detto all'universo, Sir, esiste."})
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_whisper_language(self):
|
||||
speech_recognizer = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model="openai/whisper-tiny.en",
|
||||
framework="pt",
|
||||
)
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
filename = ds[0]["file"]
|
||||
|
||||
# 1. English-only model compatible with no language argument
|
||||
output = speech_recognizer(filename)
|
||||
self.assertEqual(
|
||||
output,
|
||||
{"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."},
|
||||
)
|
||||
|
||||
# 2. English-only Whisper does not accept the language argument
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Cannot specify `task` or `langauge` for an English-only model. If the model is intended to be multilingual, "
|
||||
"pass `is_multilingual=True` to generate, or update the generation config.",
|
||||
):
|
||||
_ = speech_recognizer(filename, generate_kwargs={"language": "en"})
|
||||
|
||||
# 3. Multilingual model accepts language argument
|
||||
speech_recognizer = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model="openai/whisper-tiny",
|
||||
framework="pt",
|
||||
)
|
||||
output = speech_recognizer(filename, generate_kwargs={"language": "en"})
|
||||
self.assertEqual(
|
||||
output,
|
||||
{"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."},
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
|
|
Loading…
Reference in New Issue