From 861d69120667a99e583f36259582d55a76c12e66 Mon Sep 17 00:00:00 2001
From: sanchit-gandhi
Date: Tue, 9 Apr 2024 22:47:59 +0100
Subject: [PATCH] more fixes

---
 tests/models/whisper/test_modeling_whisper.py | 51 ++++++++-----------
 1 file changed, 22 insertions(+), 29 deletions(-)

diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index 2dbbf2f212..bc25d15f2c 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -1568,7 +1568,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         model.to(torch_device)
         input_speech = self._load_datasamples(1)
         feature_extractor = WhisperFeatureExtractor()
-        input_features = feature_extractor(input_speech, return_tensors="pt").input_features
+        input_features = feature_extractor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
 
         with torch.no_grad():
             logits = model(
@@ -1695,9 +1695,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         model.config.decoder_start_token_id = 50257
 
         input_speech = self._load_datasamples(1)
-        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
-            torch_device
-        )
+        input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
+        input_features = input_features.to(torch_device)
 
         generated_ids = model.generate(input_features, num_beams=5, max_length=20)
         transcript = processor.tokenizer.batch_decode(generated_ids)[0]
@@ -1717,9 +1716,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         model.to(torch_device)
 
         input_speech = self._load_datasamples(1)
-        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
-            torch_device
-        )
+        input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
+        input_features = input_features.to(torch_device)
 
         generated_ids = model.generate(input_features, num_beams=5, max_length=20)
         transcript = processor.tokenizer.decode(generated_ids[0])
@@ -1739,9 +1737,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         model.to(torch_device)
 
         input_speech = self._load_datasamples(1)
-        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
-            torch_device
-        )
+        input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
+        input_features = input_features.to(torch_device)
 
         generated_ids = model.generate(
             input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
@@ -1764,9 +1761,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
 
         input_speech = next(iter(ds))["audio"]["array"]
-        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
-            torch_device
-        )
+        input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
+        input_features = input_features.to(torch_device)
 
         generated_ids = model.generate(
             input_features, do_sample=False, max_length=20, language="<|ja|>", task="transcribe"
@@ -1799,7 +1795,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
 
         input_speech = self._load_datasamples(4)
-        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features
+        input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
+        input_features = input_features.to(torch_device)
         generated_ids = model.generate(input_features, max_length=20, task="translate")
 
         # fmt: off
@@ -1835,9 +1832,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         model.to(torch_device)
 
         input_speech = self._load_datasamples(4)
-        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
-            torch_device
-        )
+        input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
+        input_features = input_features.to(torch_device)
         generated_ids = model.generate(input_features, max_length=20).to("cpu")
 
         # fmt: off
@@ -1874,9 +1870,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         model.to(torch_device)
 
         input_speech = np.concatenate(self._load_datasamples(4))
-        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
-            torch_device
-        )
+        input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
+        input_features = input_features.to(torch_device)
 
         generated_ids = model.generate(input_features, max_length=448, return_timestamps=True).to("cpu")
 
@@ -1939,9 +1934,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
 
         input_speech = self._load_datasamples(4)
-        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
-            torch_device
-        )
+        input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
+        input_features = input_features.to(torch_device)
 
         generate_outputs = model.generate(
             input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
@@ -1972,9 +1966,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         num_return_sequences = 2
 
         input_speech = self._load_datasamples(num_samples)
-        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
-            torch_device
-        )
+        input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
+        input_features = input_features.to(torch_device)
 
         generate_outputs = model.generate(
             input_features,
@@ -2000,8 +1993,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
 
         input_speech = self._load_datasamples(5)
         long_input_speech = np.concatenate(input_speech, dtype=np.float32)
-        inputs = processor.feature_extractor(
-            raw_speech=long_input_speech,
+        inputs = processor(
+            long_input_speech,
             return_tensors="pt",
             truncation=False,  # False so the audio isn't truncated and whole audio is sent to the model
             return_attention_mask=True,
@@ -2051,7 +2044,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         model.to(torch_device)
         input_speech = self._load_datasamples(1)
         feature_extractor = WhisperFeatureExtractor()
-        input_features = feature_extractor(input_speech, return_tensors="pt").input_features
+        input_features = feature_extractor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
 
         with torch.no_grad():
             logits = model(
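
Note on the pattern this patch applies throughout: each test now calls the WhisperProcessor directly (its __call__ forwards the audio to the underlying feature extractor) instead of processor.feature_extractor(raw_speech=...), and passes sampling_rate=16_000 explicitly so the feature extractor can validate the audio against the rate Whisper expects instead of warning and assuming it. A minimal sketch of the call pattern follows; openai/whisper-tiny and the zero-filled dummy waveform are stand-ins chosen for illustration, not part of the patched tests:

    import numpy as np
    from transformers import WhisperProcessor

    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")

    # Stand-in for the tests' _load_datasamples output: one second of
    # silence as a 1-D float32 waveform sampled at 16 kHz.
    audio = np.zeros(16_000, dtype=np.float32)

    # Call the processor itself (not processor.feature_extractor) and pass
    # sampling_rate explicitly so the rate check can run instead of warning.
    inputs = processor(audio, return_tensors="pt", sampling_rate=16_000)
    print(inputs.input_features.shape)  # log-mel features, (1, 80, 3000)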