[Whisper] Computing features on GPU in batch mode for whisper feature extractor. (#29900)

* add _torch_extract_fbank_features_batch function in feature_extraction_whisper

* reformat feature_extraction_whisper.py file

* handle batching in single function

* add gpu test & doc

* add batch test & device in each __call__

* add device arg in docstring

---------

Co-authored-by: vaibhav.aggarwal <vaibhav.aggarwal@sprinklr.com>
Authored by vaibhavagg303 on 2024-04-08 14:06:25 +05:30; committed by GitHub.
parent 1fc34aa666 · commit 1ed93be48a
2 changed files with 81 additions and 19 deletions
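
At a glance, the change lets callers pass a `device` argument through `WhisperFeatureExtractor.__call__` so the log-mel features for a whole batch are computed in one torch call, optionally on GPU. A minimal usage sketch, assuming torch is installed and a CUDA device is available (the dataset choice is illustrative; the call signature mirrors the diff and tests below):

    from datasets import load_dataset
    from transformers import WhisperFeatureExtractor

    # A few short speech clips at 16 kHz (dataset choice is illustrative only).
    ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    speech = [x["array"] for x in ds[:3]["audio"]]

    feature_extractor = WhisperFeatureExtractor()
    # The whole batch of log-mel spectrograms is computed in one call on the GPU.
    inputs = feature_extractor(speech, sampling_rate=16000, return_tensors="pt", device="cuda")
    print(inputs.input_features.shape)  # torch.Size([3, 80, 3000])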

src/transformers/models/whisper/feature_extraction_whisper.py

@@ -94,41 +94,63 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
             mel_scale="slaney",
         )
 
-    def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
+    def _np_extract_fbank_features(self, waveform_batch: np.array, device: str) -> np.ndarray:
         """
         Compute the log-mel spectrogram of the provided audio, gives similar results to Whisper's original torch
         implementation with 1e-5 tolerance.
         """
-        log_spec = spectrogram(
-            waveform,
-            window_function(self.n_fft, "hann"),
-            frame_length=self.n_fft,
-            hop_length=self.hop_length,
-            power=2.0,
-            mel_filters=self.mel_filters,
-            log_mel="log10",
-        )
-        log_spec = log_spec[:, :-1]
-        log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
-        log_spec = (log_spec + 4.0) / 4.0
-        return log_spec
+        if device != "cpu":
+            raise ValueError(
+                f"Got device `{device}` for feature extraction, but feature extraction on CUDA accelerator "
+                "devices requires torch, which is not installed. Either set `device='cpu'`, or "
+                "install torch according to the official instructions: https://pytorch.org/get-started/locally/"
+            )
+        log_spec_batch = []
+        for waveform in waveform_batch:
+            log_spec = spectrogram(
+                waveform,
+                window_function(self.n_fft, "hann"),
+                frame_length=self.n_fft,
+                hop_length=self.hop_length,
+                power=2.0,
+                mel_filters=self.mel_filters,
+                log_mel="log10",
+            )
+            log_spec = log_spec[:, :-1]
+            log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
+            log_spec = (log_spec + 4.0) / 4.0
+            log_spec_batch.append(log_spec)
+        log_spec_batch = np.array(log_spec_batch)
+        return log_spec_batch
 
-    def _torch_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
+    def _torch_extract_fbank_features(self, waveform: np.array, device: str = "cpu") -> np.ndarray:
         """
-        Compute the log-mel spectrogram of the provided audio using the PyTorch STFT implementation.
+        Compute the log-mel spectrogram of the audio using PyTorch's GPU-accelerated STFT implementation with batching,
+        yielding results similar to cpu computing with 1e-5 tolerance.
         """
         waveform = torch.from_numpy(waveform).type(torch.float32)
         window = torch.hann_window(self.n_fft)
+        if device != "cpu":
+            waveform = waveform.to(device)
+            window = window.to(device)
 
         stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
         magnitudes = stft[..., :-1].abs() ** 2
 
         mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
+        if device != "cpu":
+            mel_filters = mel_filters.to(device)
         mel_spec = mel_filters.T @ magnitudes
 
         log_spec = torch.clamp(mel_spec, min=1e-10).log10()
-        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+        if waveform.dim() == 2:
+            max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
+            log_spec = torch.maximum(log_spec, max_val - 8.0)
+        else:
+            log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
         log_spec = (log_spec + 4.0) / 4.0
+        if device != "cpu":
+            log_spec = log_spec.detach().cpu()
         return log_spec.numpy()
 
     @staticmethod
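
The `waveform.dim() == 2` branch above is the one genuinely new piece of numerics: in batch mode each sample must be clamped against its own spectrogram maximum, not the batch-wide one, or a loud sample would shift the dynamic-range floor of a quiet one. A tiny sketch with toy numbers (not from the diff) showing the difference:

    import torch

    # Toy (batch, n_mels, frames) log-spectrogram; sample 0 is much louder than sample 1.
    log_spec = torch.tensor([[[0.0, -20.0]], [[-5.0, -30.0]]])

    # Per-sample maximum, as in the batched branch above.
    max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
    per_sample = torch.maximum(log_spec, max_val - 8.0)

    # A naive global maximum would clamp the quiet sample at the loud sample's floor.
    global_clamp = torch.maximum(log_spec, log_spec.max() - 8.0)

    print(per_sample.squeeze())    # tensor([[ 0., -8.], [-5., -13.]])
    print(global_clamp.squeeze())  # tensor([[ 0., -8.], [-5.,  -8.]])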
@@ -165,6 +187,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
         max_length: Optional[int] = None,
         sampling_rate: Optional[int] = None,
         do_normalize: Optional[bool] = None,
+        device: Optional[str] = "cpu",
         **kwargs,
     ) -> BatchFeature:
         """
@@ -211,6 +234,9 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
             do_normalize (`bool`, *optional*, defaults to `False`):
                 Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
                 improve the performance of the model.
+            device (`str`, *optional*, defaults to `'cpu'`):
+                Specifies the device for computation of the log-mel spectrogram of audio signals in the
+                `_torch_extract_fbank_features` method. (e.g., "cpu", "cuda")
         """
 
         if sampling_rate is not None:
@@ -272,7 +298,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
         extract_fbank_features = (
             self._torch_extract_fbank_features if is_torch_available() else self._np_extract_fbank_features
         )
-        input_features = [extract_fbank_features(waveform) for waveform in input_features[0]]
+        input_features = extract_fbank_features(input_features[0], device)
 
         if isinstance(input_features[0], List):
             padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
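
Since both extraction paths now take a batch and a `device`, the 1e-5 agreement claimed in the docstrings can be spot-checked directly. A sketch, assuming torch is installed; note these are private helpers, poked at here only to illustrate the claim:

    import numpy as np
    from transformers import WhisperFeatureExtractor

    fe = WhisperFeatureExtractor()
    # Two random 30-second waveforms at 16 kHz, the length Whisper pads/truncates to.
    batch = np.random.default_rng(0).standard_normal((2, 480000)).astype(np.float32)

    np_feats = fe._np_extract_fbank_features(batch, device="cpu")
    pt_feats = fe._torch_extract_fbank_features(batch, device="cpu")
    print(np.abs(np_feats - pt_feats).max())  # expected on the order of 1e-5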

tests/models/whisper/test_feature_extraction_whisper.py

@@ -24,7 +24,7 @@ import numpy as np
 from datasets import load_dataset
 
 from transformers import WhisperFeatureExtractor
-from transformers.testing_utils import check_json_file_has_correct_format, require_torch
+from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torch_gpu
 from transformers.utils.import_utils import is_torch_available
 
 from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
@@ -207,6 +207,7 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
         return [x["array"] for x in speech_samples]
 
+    @require_torch_gpu
     @require_torch
     def test_torch_integration(self):
         # fmt: off
@@ -223,6 +224,7 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
         input_speech = self._load_datasamples(1)
         feature_extractor = WhisperFeatureExtractor()
         input_features = feature_extractor(input_speech, return_tensors="pt").input_features
+
         self.assertEqual(input_features.shape, (1, 80, 3000))
         self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
@@ -253,3 +255,37 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
         self.assertTrue(np.all(np.mean(audio) < 1e-3))
         self.assertTrue(np.all(np.abs(np.var(audio) - 1) < 1e-3))
+
+    @require_torch_gpu
+    @require_torch
+    def test_torch_integration_batch(self):
+        # fmt: off
+        EXPECTED_INPUT_FEATURES = torch.tensor(
+            [
+                [
+                    0.1193, -0.0946, -0.1098, -0.0196, 0.0225, -0.0690, -0.1736, 0.0951,
+                    0.0971, -0.0817, -0.0702, 0.0162, 0.0260, 0.0017, -0.0192, -0.1678,
+                    0.0709, -0.1867, -0.0655, -0.0274, -0.0234, -0.1884, -0.0516, -0.0554,
+                    -0.0274, -0.1425, -0.1423, 0.0837, 0.0377, -0.0854
+                ],
+                [
+                    -0.4696, -0.0751, 0.0276, -0.0312, -0.0540, -0.0383, 0.1295, 0.0568,
+                    -0.2071, -0.0548, 0.0389, -0.0316, -0.2346, -0.1068, -0.0322, 0.0475,
+                    -0.1709, -0.0041, 0.0872, 0.0537, 0.0075, -0.0392, 0.0371, 0.0189,
+                    -0.1522, -0.0270, 0.0744, 0.0738, -0.0245, -0.0667
+                ],
+                [
+                    -0.2337, -0.0060, -0.0063, -0.2353, -0.0431, 0.1102, -0.1492, -0.0292,
+                    0.0787, -0.0608, 0.0143, 0.0582, 0.0072, 0.0101, -0.0444, -0.1701,
+                    -0.0064, -0.0027, -0.0826, -0.0730, -0.0099, -0.0762, -0.0170, 0.0446,
+                    -0.1153, 0.0960, -0.0361, 0.0652, 0.1207, 0.0277
+                ]
+            ]
+        )
+        # fmt: on
+
+        input_speech = self._load_datasamples(3)
+        feature_extractor = WhisperFeatureExtractor()
+        input_features = feature_extractor(input_speech, return_tensors="pt").input_features
+        self.assertEqual(input_features.shape, (3, 80, 3000))
+        self.assertTrue(torch.allclose(input_features[:, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
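
One property the batch test leaves implicit: batched extraction should match one-sample-at-a-time extraction, since each clip is padded to 30 seconds independently. A CPU-only sketch of that check, assuming torch is installed (random waveforms stand in for real speech):

    import numpy as np
    from transformers import WhisperFeatureExtractor

    fe = WhisperFeatureExtractor()
    rng = np.random.default_rng(0)
    speech = [rng.standard_normal(16000 * n).astype(np.float32) for n in (3, 5, 7)]

    # Extract features for the whole batch at once, then one clip at a time.
    batched = fe(speech, sampling_rate=16000, return_tensors="np").input_features
    singles = np.concatenate(
        [fe([clip], sampling_rate=16000, return_tensors="np").input_features for clip in speech]
    )
    print(np.allclose(batched, singles, atol=1e-5))  # expected: True

On a CUDA machine, the two GPU-gated integration tests above can be selected with `pytest tests/models/whisper/test_feature_extraction_whisper.py -k torch_integration`.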