[Tests] Correct Wav2Vec2 & WavLM tests (#15015)

* up

* up

* up
Patrick von Platen 2022-01-03 20:19:04 +01:00 committed by GitHub
parent 0b4c3a1a53
commit dbac8899fe
3 changed files with 13 additions and 35 deletions

@@ -290,7 +290,7 @@ jobs:
     - name: Install dependencies
       run: |
-        apt -y update && apt install -y libsndfile1-dev git
+        apt -y update && apt install -y libsndfile1-dev git espeak-ng
         pip install --upgrade pip
         pip install .[sklearn,testing,onnx,sentencepiece,tf-speech,vision]
         pip install https://github.com/kpu/kenlm/archive/master.zip
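Note: espeak-ng is a system package, most likely pulled in as the backend for the `phonemizer` library that Wav2Vec2's phoneme tokenizer depends on; the diff itself does not state the reason, so treat this as an assumption. A minimal smoke test for the backend:

    # Hypothetical check, assuming `phonemizer` is pip-installed alongside the
    # espeak-ng system package added above.
    from phonemizer import phonemize

    # prints a phonemization of the Spanish input via the espeak backend
    print(phonemize("hola mundo", language="es", backend="espeak"))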

@@ -15,6 +15,7 @@

 import copy
+import glob
 import inspect
 import math
 import unittest

@@ -23,6 +24,7 @@ import numpy as np
 import pytest

 from datasets import load_dataset
+from huggingface_hub import snapshot_download
 from transformers import Wav2Vec2Config, is_tf_available
 from transformers.file_utils import is_librosa_available, is_pyctcdecode_available
 from transformers.testing_utils import require_librosa, require_pyctcdecode, require_tf, slow
@@ -485,8 +487,6 @@ class TFWav2Vec2UtilsTest(unittest.TestCase):
 @slow
 class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
     def _load_datasamples(self, num_samples):
-        from datasets import load_dataset
-
         ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
         # automatic decoding with librispeech
         speech_samples = ds.sort("id").filter(
@@ -556,18 +556,17 @@ class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
     @require_pyctcdecode
     @require_librosa
     def test_wav2vec2_with_lm(self):
-        ds = load_dataset("common_voice", "es", split="test", streaming=True)
-        sample = next(iter(ds))
-
-        resampled_audio = librosa.resample(sample["audio"]["array"], 48_000, 16_000)
+        downloaded_folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
+        file_path = glob.glob(downloaded_folder + "/*")[0]
+        sample = librosa.load(file_path, sr=16_000)[0]

         model = TFWav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
         processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")

-        input_values = processor(resampled_audio, return_tensors="tf").input_values
+        input_values = processor(sample, return_tensors="tf").input_values

         logits = model(input_values).logits

         transcription = processor.batch_decode(logits.numpy()).text
-        self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
+        self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")

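The rewrite replaces a streaming `common_voice` download, presumably because the first streamed sample (and thus the expected transcription) was not stable, with a pinned audio file from a Hub snapshot. Standalone, the new loading path looks roughly like this, a sketch assuming the snapshot repo contains exactly one audio file:

    import glob

    import librosa
    from huggingface_hub import snapshot_download

    folder = snapshot_download("patrickvonplaten/common_voice_es_sample")
    file_path = glob.glob(folder + "/*")[0]
    # librosa.load resamples while decoding, so the separate
    # librosa.resample(audio, 48_000, 16_000) step is no longer needed
    speech = librosa.load(file_path, sr=16_000)[0]  # 16 kHz float32 waveform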

@@ -14,7 +14,6 @@
 # limitations under the License.
 """ Testing suite for the PyTorch WavLM model. """

-import copy
 import math
 import unittest
@@ -452,30 +451,9 @@ class WavLMModelTest(ModelTesterMixin, unittest.TestCase):
         if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
             module.masked_spec_embed.data.fill_(3)

-    # overwrite from test_modeling_common
-    # as WavLM is not very precise
+    @unittest.skip(reason="Feed forward chunking is not implemented for WavLM")
     def test_feed_forward_chunking(self):
-        (
-            original_config,
-            inputs_dict,
-        ) = self.model_tester.prepare_config_and_inputs_for_common()
-        for model_class in self.all_model_classes:
-            torch.manual_seed(0)
-            config = copy.deepcopy(original_config)
-            model = model_class(config)
-            model.to(torch_device)
-            model.eval()
-
-            hidden_states_no_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
-
-            torch.manual_seed(0)
-            config.chunk_size_feed_forward = 1
-            model = model_class(config)
-            model.to(torch_device)
-            model.eval()
-
-            hidden_states_with_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
-            self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-2))
+        pass

     @slow
     def test_model_from_pretrained(self):
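For context, the deleted body came from `test_modeling_common` and compared hidden states with and without feed-forward chunking; the new skip reason says WavLM never implements chunking, in which case the inherited test compared two identical runs. A minimal sketch of what the mechanism does, using the `apply_chunking_to_forward` helper that shipped in `transformers.modeling_utils` at the time of this commit:

    import torch
    from transformers.modeling_utils import apply_chunking_to_forward

    ffn = torch.nn.Linear(8, 8)            # stand-in feed-forward layer
    hidden_states = torch.randn(2, 16, 8)  # (batch, seq_len, hidden)

    # chunk_size=1 along chunk_dim=1 (the sequence axis): mathematically the
    # same result, but the FFN only sees one position at a time, lowering
    # peak memory for long sequences
    chunked = apply_chunking_to_forward(ffn, 1, 1, hidden_states)
    assert torch.allclose(ffn(hidden_states), chunked, atol=1e-6)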
@@ -528,7 +506,7 @@ class WavLMModelIntegrationTest(unittest.TestCase):
     def test_inference_large(self):
         model = WavLMModel.from_pretrained("microsoft/wavlm-large").to(torch_device)
         feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
-            "microsoft/wavlm-base-plus", return_attention_mask=True
+            "microsoft/wavlm-large", return_attention_mask=True
         )

         input_speech = self._load_datasamples(2)
@@ -544,8 +522,9 @@ class WavLMModelIntegrationTest(unittest.TestCase):
         )

         EXPECTED_HIDDEN_STATES_SLICE = torch.tensor(
-            [[[0.1612, 0.4314], [0.1690, 0.4344]], [[0.2086, 0.1396], [0.3014, 0.0903]]]
+            [[[0.2122, 0.0500], [0.2118, 0.0563]], [[0.1353, 0.1818], [0.2453, 0.0595]]]
         )
+
         self.assertTrue(torch.allclose(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, rtol=5e-2))

     def test_inference_diarization(self):
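The reference values change, presumably because the test now pairs `microsoft/wavlm-large` with its own feature extractor instead of the `microsoft/wavlm-base-plus` one, so the preprocessed inputs (and hence the activations) differ. Note the check is essentially relative: `torch.allclose(a, b, rtol=5e-2)` tests |a - b| <= atol + rtol * |b| elementwise (default atol is 1e-8), so roughly 5% relative deviation is tolerated:

    import torch

    a = torch.tensor([0.2122, 0.0500])
    b = a * 1.04                                 # 4% relative error
    assert torch.allclose(a, b, rtol=5e-2)       # within the 5% budget
    assert not torch.allclose(a, b, rtol=1e-2)   # a 1% budget would fail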