parent
0b4c3a1a53
commit
dbac8899fe
|
@ -290,7 +290,7 @@ jobs:
|
|||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt -y update && apt install -y libsndfile1-dev git
|
||||
apt -y update && apt install -y libsndfile1-dev git espeak-ng
|
||||
pip install --upgrade pip
|
||||
pip install .[sklearn,testing,onnx,sentencepiece,tf-speech,vision]
|
||||
pip install https://github.com/kpu/kenlm/archive/master.zip
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
|
||||
import copy
|
||||
import glob
|
||||
import inspect
|
||||
import math
|
||||
import unittest
|
||||
|
@ -23,6 +24,7 @@ import numpy as np
|
|||
import pytest
|
||||
from datasets import load_dataset
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import Wav2Vec2Config, is_tf_available
|
||||
from transformers.file_utils import is_librosa_available, is_pyctcdecode_available
|
||||
from transformers.testing_utils import require_librosa, require_pyctcdecode, require_tf, slow
|
||||
|
@ -485,8 +487,6 @@ class TFWav2Vec2UtilsTest(unittest.TestCase):
|
|||
@slow
|
||||
class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
def _load_datasamples(self, num_samples):
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
# automatic decoding with librispeech
|
||||
speech_samples = ds.sort("id").filter(
|
||||
|
@ -556,18 +556,17 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||
@require_pyctcdecode
|
||||
@require_librosa
|
||||
def test_wav2vec2_with_lm(self):
|
||||
ds = load_dataset("common_voice", "es", split="test", streaming=True)
|
||||
sample = next(iter(ds))
|
||||
|
||||
resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)
|
||||
downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
|
||||
file_path = glob.glob(downloaded_folder + "/*")[0]
|
||||
sample = librosa.load(file_path, sr=16_000)[0]
|
||||
|
||||
model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
||||
|
||||
input_values = processor(resampled_audio, return_tensors="tf").input_values
|
||||
input_values = processor(sample, return_tensors="tf").input_values
|
||||
|
||||
logits = model(input_values).logits
|
||||
|
||||
transcription = processor.batch_decode(logits.numpy()).text
|
||||
|
||||
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch WavLM model. """
|
||||
|
||||
import copy
|
||||
import math
|
||||
import unittest
|
||||
|
||||
|
@ -452,30 +451,9 @@ class WavLMModelTest(ModelTesterMixin, unittest.TestCase):
|
|||
if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
|
||||
module.masked_spec_embed.data.fill_(3)
|
||||
|
||||
# overwrite from test_modeling_common
|
||||
# as WavLM is not very precise
|
||||
@unittest.skip(reason="Feed forward chunking is not implemented for WavLM")
|
||||
def test_feed_forward_chunking(self):
|
||||
(
|
||||
original_config,
|
||||
inputs_dict,
|
||||
) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
torch.manual_seed(0)
|
||||
config = copy.deepcopy(original_config)
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
hidden_states_no_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
|
||||
|
||||
torch.manual_seed(0)
|
||||
config.chunk_size_feed_forward = 1
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
hidden_states_with_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
|
||||
self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-2))
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
|
@ -528,7 +506,7 @@ class WavLMModelIntegrationTest(unittest.TestCase):
|
|||
def test_inference_large(self):
|
||||
model = WavLMModel.from_pretrained("microsoft/wavlm-large").to(torch_device)
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||
"microsoft/wavlm-base-plus", return_attention_mask=True
|
||||
"microsoft/wavlm-large", return_attention_mask=True
|
||||
)
|
||||
|
||||
input_speech = self._load_datasamples(2)
|
||||
|
@ -544,8 +522,9 @@ class WavLMModelIntegrationTest(unittest.TestCase):
|
|||
)
|
||||
|
||||
EXPECTED_HIDDEN_STATES_SLICE = torch.tensor(
|
||||
[[[0.1612, 0.4314], [0.1690, 0.4344]], [[0.2086, 0.1396], [0.3014, 0.0903]]]
|
||||
[[[0.2122, 0.0500], [0.2118, 0.0563]], [[0.1353, 0.1818], [0.2453, 0.0595]]]
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, rtol=5e-2))
|
||||
|
||||
def test_inference_diarization(self):
|
||||
|
|
Loading…
Reference in New Issue