more fixes
This commit is contained in:
parent
fd1b7e2a6c
commit
861d691206
|
@ -1568,7 +1568,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||
model.to(torch_device)
|
||||
input_speech = self._load_datasamples(1)
|
||||
feature_extractor = WhisperFeatureExtractor()
|
||||
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
|
||||
input_features = feature_extractor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(
|
||||
|
@ -1695,9 +1695,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||
model.config.decoder_start_token_id = 50257
|
||||
|
||||
input_speech = self._load_datasamples(1)
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
|
||||
torch_device
|
||||
)
|
||||
input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
|
||||
input_features = input_features.to(torch_device)
|
||||
|
||||
generated_ids = model.generate(input_features, num_beams=5, max_length=20)
|
||||
transcript = processor.tokenizer.batch_decode(generated_ids)[0]
|
||||
|
@ -1717,9 +1716,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||
model.to(torch_device)
|
||||
|
||||
input_speech = self._load_datasamples(1)
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
|
||||
torch_device
|
||||
)
|
||||
input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
|
||||
input_features = input_features.to(torch_device)
|
||||
|
||||
generated_ids = model.generate(input_features, num_beams=5, max_length=20)
|
||||
transcript = processor.tokenizer.decode(generated_ids[0])
|
||||
|
@ -1739,9 +1737,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||
model.to(torch_device)
|
||||
|
||||
input_speech = self._load_datasamples(1)
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
|
||||
torch_device
|
||||
)
|
||||
input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
|
||||
input_features = input_features.to(torch_device)
|
||||
|
||||
generated_ids = model.generate(
|
||||
input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
|
||||
|
@ -1764,9 +1761,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
|
||||
|
||||
input_speech = next(iter(ds))["audio"]["array"]
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
|
||||
torch_device
|
||||
)
|
||||
input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
|
||||
input_features = input_features.to(torch_device)
|
||||
|
||||
generated_ids = model.generate(
|
||||
input_features, do_sample=False, max_length=20, language="<|ja|>", task="transcribe"
|
||||
|
@ -1799,7 +1795,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
|
||||
|
||||
input_speech = self._load_datasamples(4)
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features
|
||||
input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
|
||||
input_features = input_features.to(torch_device)
|
||||
generated_ids = model.generate(input_features, max_length=20, task="translate")
|
||||
|
||||
# fmt: off
|
||||
|
@ -1835,9 +1832,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||
model.to(torch_device)
|
||||
|
||||
input_speech = self._load_datasamples(4)
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
|
||||
torch_device
|
||||
)
|
||||
input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
|
||||
input_features = input_features.to(torch_device)
|
||||
generated_ids = model.generate(input_features, max_length=20).to("cpu")
|
||||
|
||||
# fmt: off
|
||||
|
@ -1874,9 +1870,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||
model.to(torch_device)
|
||||
|
||||
input_speech = np.concatenate(self._load_datasamples(4))
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
|
||||
torch_device
|
||||
)
|
||||
input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
|
||||
input_features = input_features.to(torch_device)
|
||||
|
||||
generated_ids = model.generate(input_features, max_length=448, return_timestamps=True).to("cpu")
|
||||
|
||||
|
@ -1939,9 +1934,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||
model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
|
||||
|
||||
input_speech = self._load_datasamples(4)
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
|
||||
torch_device
|
||||
)
|
||||
input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
|
||||
input_features = input_features.to(torch_device)
|
||||
|
||||
generate_outputs = model.generate(
|
||||
input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
|
||||
|
@ -1972,9 +1966,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||
num_return_sequences = 2
|
||||
|
||||
input_speech = self._load_datasamples(num_samples)
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
|
||||
torch_device
|
||||
)
|
||||
input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
|
||||
input_features = input_features.to(torch_device)
|
||||
|
||||
generate_outputs = model.generate(
|
||||
input_features,
|
||||
|
@ -2000,8 +1993,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||
|
||||
input_speech = self._load_datasamples(5)
|
||||
long_input_speech = np.concatenate(input_speech, dtype=np.float32)
|
||||
inputs = processor.feature_extractor(
|
||||
raw_speech=long_input_speech,
|
||||
inputs = processor(
|
||||
long_input_speech,
|
||||
return_tensors="pt",
|
||||
truncation=False, # False so the audio isn't truncated and whole audio is sent to the model
|
||||
return_attention_mask=True,
|
||||
|
@ -2051,7 +2044,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||
model.to(torch_device)
|
||||
input_speech = self._load_datasamples(1)
|
||||
feature_extractor = WhisperFeatureExtractor()
|
||||
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
|
||||
input_features = feature_extractor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(
|
||||
|
|
Loading…
Reference in New Issue