[Whisper] Computing features on GPU in batch mode for whisper feature extractor. (#29900)
* Add `_torch_extract_fbank_features_batch` function to `feature_extraction_whisper.py`
* Reformat `feature_extraction_whisper.py`
* Handle batching in a single function
* Add GPU test and documentation
* Add batch test and `device` argument in each `__call__`
* Add `device` argument to the docstring

---------

Co-authored-by: vaibhav.aggarwal <vaibhav.aggarwal@sprinklr.com>
This commit is contained in:
parent
1fc34aa666
commit
1ed93be48a
|
@ -94,11 +94,19 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
||||||
mel_scale="slaney",
|
mel_scale="slaney",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
|
def _np_extract_fbank_features(self, waveform_batch: np.array, device: str) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Compute the log-mel spectrogram of the provided audio, gives similar results to Whisper's original torch
|
Compute the log-mel spectrogram of the provided audio, gives similar results to Whisper's original torch
|
||||||
implementation with 1e-5 tolerance.
|
implementation with 1e-5 tolerance.
|
||||||
"""
|
"""
|
||||||
|
if device != "cpu":
|
||||||
|
raise ValueError(
|
||||||
|
f"Got device `{device}` for feature extraction, but feature extraction on CUDA accelerator "
|
||||||
|
"devices requires torch, which is not installed. Either set `device='cpu'`, or "
|
||||||
|
"install torch according to the official instructions: https://pytorch.org/get-started/locally/"
|
||||||
|
)
|
||||||
|
log_spec_batch = []
|
||||||
|
for waveform in waveform_batch:
|
||||||
log_spec = spectrogram(
|
log_spec = spectrogram(
|
||||||
waveform,
|
waveform,
|
||||||
window_function(self.n_fft, "hann"),
|
window_function(self.n_fft, "hann"),
|
||||||
|
@ -111,24 +119,38 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
||||||
log_spec = log_spec[:, :-1]
|
log_spec = log_spec[:, :-1]
|
||||||
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
|
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
|
||||||
log_spec = (log_spec + 4.0) / 4.0
|
log_spec = (log_spec + 4.0) / 4.0
|
||||||
return log_spec
|
log_spec_batch.append(log_spec)
|
||||||
|
log_spec_batch = np.array(log_spec_batch)
|
||||||
|
return log_spec_batch
|
||||||
|
|
||||||
def _torch_extract_fbank_features(self, waveform: np.ndarray, device: str = "cpu") -> np.ndarray:
    """
    Compute the log-mel spectrogram of the provided audio using PyTorch's STFT implementation.

    The same code path handles a single waveform and a batch of waveforms; the computation can optionally be
    run on an accelerator by passing `device`.

    Args:
        waveform (`np.ndarray`):
            Audio samples. Either a single waveform of shape `(num_samples,)` or a batch of equal-length
            waveforms of shape `(batch_size, num_samples)`. Values are converted to `float32`.
        device (`str`, *optional*, defaults to `"cpu"`):
            Torch device on which the STFT and mel projection are computed (e.g. `"cpu"`, `"cuda"`).

    Returns:
        `np.ndarray`: `float32` log-mel features living on the CPU. The last two axes are
        `(self.mel_filters.shape[1], num_frames - 1)`; a leading batch axis is kept when the input was batched.
    """
    # `.to(...)` is a no-op when the tensor already has the requested device/dtype, so no explicit
    # `device != "cpu"` guard is needed.
    waveform = torch.from_numpy(waveform).to(device=device, dtype=torch.float32)
    window = torch.hann_window(self.n_fft).to(device)

    stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
    # Drop the last STFT frame and take the power spectrum.
    magnitudes = stft[..., :-1].abs() ** 2

    mel_filters = torch.from_numpy(self.mel_filters).to(device=device, dtype=torch.float32)
    mel_spec = mel_filters.T @ magnitudes

    # Clamp before the log so that silent bins do not produce -inf.
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()

    # Limit the dynamic range to 8 (log10 units) below the maximum of *each example*. Reducing over the
    # last two axes gives the global maximum for an un-batched input and a per-example maximum for a
    # batched one, so both cases share one code path.
    max_val = log_spec.amax(dim=(-2, -1), keepdim=True)
    log_spec = torch.maximum(log_spec, max_val - 8.0)
    log_spec = (log_spec + 4.0) / 4.0

    # Always hand back a CPU numpy array, whatever device was used for the computation.
    return log_spec.detach().cpu().numpy()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -165,6 +187,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
||||||
max_length: Optional[int] = None,
|
max_length: Optional[int] = None,
|
||||||
sampling_rate: Optional[int] = None,
|
sampling_rate: Optional[int] = None,
|
||||||
do_normalize: Optional[bool] = None,
|
do_normalize: Optional[bool] = None,
|
||||||
|
device: Optional[str] = "cpu",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BatchFeature:
|
) -> BatchFeature:
|
||||||
"""
|
"""
|
||||||
|
@ -211,6 +234,9 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
||||||
do_normalize (`bool`, *optional*, defaults to `False`):
|
do_normalize (`bool`, *optional*, defaults to `False`):
|
||||||
Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
|
Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
|
||||||
improve the performance of the model.
|
improve the performance of the model.
|
||||||
|
device (`str`, *optional*, defaults to `'cpu'`):
|
||||||
|
Specifies the device for computation of the log-mel spectrogram of audio signals in the
|
||||||
|
`_torch_extract_fbank_features` method. (e.g., "cpu", "cuda")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if sampling_rate is not None:
|
if sampling_rate is not None:
|
||||||
|
@ -272,7 +298,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
||||||
extract_fbank_features = (
|
extract_fbank_features = (
|
||||||
self._torch_extract_fbank_features if is_torch_available() else self._np_extract_fbank_features
|
self._torch_extract_fbank_features if is_torch_available() else self._np_extract_fbank_features
|
||||||
)
|
)
|
||||||
input_features = [extract_fbank_features(waveform) for waveform in input_features[0]]
|
input_features = extract_fbank_features(input_features[0], device)
|
||||||
|
|
||||||
if isinstance(input_features[0], List):
|
if isinstance(input_features[0], List):
|
||||||
padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
|
padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
|
||||||
|
|
|
@ -24,7 +24,7 @@ import numpy as np
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
from transformers import WhisperFeatureExtractor
|
from transformers import WhisperFeatureExtractor
|
||||||
from transformers.testing_utils import check_json_file_has_correct_format, require_torch
|
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torch_gpu
|
||||||
from transformers.utils.import_utils import is_torch_available
|
from transformers.utils.import_utils import is_torch_available
|
||||||
|
|
||||||
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
|
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
|
||||||
|
@ -207,6 +207,7 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||||
|
|
||||||
return [x["array"] for x in speech_samples]
|
return [x["array"] for x in speech_samples]
|
||||||
|
|
||||||
|
@require_torch_gpu
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_torch_integration(self):
|
def test_torch_integration(self):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
|
@ -223,6 +224,7 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||||
input_speech = self._load_datasamples(1)
|
input_speech = self._load_datasamples(1)
|
||||||
feature_extractor = WhisperFeatureExtractor()
|
feature_extractor = WhisperFeatureExtractor()
|
||||||
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
|
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
|
||||||
|
|
||||||
self.assertEqual(input_features.shape, (1, 80, 3000))
|
self.assertEqual(input_features.shape, (1, 80, 3000))
|
||||||
self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
|
self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
|
||||||
|
|
||||||
|
@ -253,3 +255,37 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||||
|
|
||||||
self.assertTrue(np.all(np.mean(audio) < 1e-3))
|
self.assertTrue(np.all(np.mean(audio) < 1e-3))
|
||||||
self.assertTrue(np.all(np.abs(np.var(audio) - 1) < 1e-3))
|
self.assertTrue(np.all(np.abs(np.var(audio) - 1) < 1e-3))
|
||||||
|
|
||||||
|
@require_torch_gpu
@require_torch
def test_torch_integration_batch(self):
    """
    Run the feature extractor on a batch of three dataset samples and compare a slice of the output with
    hard-coded reference values.
    """
    # NOTE(review): the test is gated on `require_torch_gpu`, but the extractor is called without a `device`
    # argument, so it appears to run with the default device — confirm whether `device="cuda"` was intended.
    # fmt: off
    EXPECTED_INPUT_FEATURES = torch.tensor(
        [
            [
                0.1193, -0.0946, -0.1098, -0.0196, 0.0225, -0.0690, -0.1736, 0.0951,
                0.0971, -0.0817, -0.0702, 0.0162, 0.0260, 0.0017, -0.0192, -0.1678,
                0.0709, -0.1867, -0.0655, -0.0274, -0.0234, -0.1884, -0.0516, -0.0554,
                -0.0274, -0.1425, -0.1423, 0.0837, 0.0377, -0.0854
            ],
            [
                -0.4696, -0.0751, 0.0276, -0.0312, -0.0540, -0.0383, 0.1295, 0.0568,
                -0.2071, -0.0548, 0.0389, -0.0316, -0.2346, -0.1068, -0.0322, 0.0475,
                -0.1709, -0.0041, 0.0872, 0.0537, 0.0075, -0.0392, 0.0371, 0.0189,
                -0.1522, -0.0270, 0.0744, 0.0738, -0.0245, -0.0667
            ],
            [
                -0.2337, -0.0060, -0.0063, -0.2353, -0.0431, 0.1102, -0.1492, -0.0292,
                0.0787, -0.0608, 0.0143, 0.0582, 0.0072, 0.0101, -0.0444, -0.1701,
                -0.0064, -0.0027, -0.0826, -0.0730, -0.0099, -0.0762, -0.0170, 0.0446,
                -0.1153, 0.0960, -0.0361, 0.0652, 0.1207, 0.0277
            ]
        ]
    )
    # fmt: on

    # Three samples, so the extractor receives a batch rather than a single waveform.
    input_speech = self._load_datasamples(3)
    feature_extractor = WhisperFeatureExtractor()
    input_features = feature_extractor(input_speech, return_tensors="pt").input_features
    self.assertEqual(input_features.shape, (3, 80, 3000))
    # Only index 0 of the second axis and the first 30 positions of the last axis are checked, per sample.
    self.assertTrue(torch.allclose(input_features[:, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
|
||||||
|
|
Loading…
Reference in New Issue