[ASR pipeline] correct with lm pipeline (#15200)

* [ASR pipeline] correct with lm pipeline

* improve error
Patrick von Platen 2022-01-18 15:36:22 +01:00 committed by GitHub
parent 1144d336b6
commit 497346d07e
6 changed files with 36 additions and 9 deletions
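
For context, here is a minimal usage sketch of the pipeline path this commit fixes. It is not part of the diff; it mirrors the new `test_with_lm_fast` test below and assumes `pyctcdecode` and `kenlm` are installed so the language-model decoder can actually be loaded.

# Minimal sketch: ASR pipeline with a pyctcdecode language model.
# Assumes `pip install pyctcdecode` and the kenlm wheel are available;
# the checkpoint name is the one used by the new test in this commit.
import numpy as np

from transformers import pipeline

speech_recognizer = pipeline(
    task="automatic-speech-recognition",
    model="hf-internal-testing/processor_with_lm",
    framework="pt",
)

# One second of silence stands in for real 16 kHz mono audio.
dummy_audio = np.zeros(16000, dtype=np.float32)
print(speech_recognizer(dummy_audio))  # -> {"text": "..."}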

setup.py

@@ -152,7 +152,7 @@ _deps = [
     "tokenizers>=0.10.1,!=0.11.3",
     "torch>=1.0",
     "torchaudio",
-    "pyctcdecode>=0.2.0",
+    "pyctcdecode>=0.3.0",
     "tqdm>=4.27",
     "unidic>=1.0.2",
     "unidic_lite>=1.0.7",

src/transformers/dependency_versions_table.py

@@ -62,7 +62,7 @@ deps = {
     "tokenizers": "tokenizers>=0.10.1,!=0.11.3",
     "torch": "torch>=1.0",
     "torchaudio": "torchaudio",
-    "pyctcdecode": "pyctcdecode>=0.2.0",
+    "pyctcdecode": "pyctcdecode>=0.3.0",
     "tqdm": "tqdm>=4.27",
     "unidic": "unidic>=1.0.2",
     "unidic_lite": "unidic_lite>=1.0.7",

src/transformers/feature_extraction_utils.py

@@ -489,8 +489,9 @@ class FeatureExtractionMixin:
         # make sure private name "_processor_class" is correctly
         # saved as "processor_class"
-        if dictionary.get("_processor_class", None) is not None:
-            dictionary["processor_class"] = dictionary.pop("_processor_class")
+        _processor_class = dictionary.pop("_processor_class", None)
+        if _processor_class is not None:
+            dictionary["processor_class"] = _processor_class
 
         return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
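
A small self-contained sketch (not from the diff) of what the `to_json_string` change above fixes: with the old logic a private "_processor_class" entry whose value is None stayed in the serialized JSON, while the new pop-with-default always drops it, which is exactly what the new test assertion below checks.

# Illustrative only: old vs. new handling of the private "_processor_class" key.
import json


def old_to_json(dictionary):
    if dictionary.get("_processor_class", None) is not None:
        dictionary["processor_class"] = dictionary.pop("_processor_class")
    return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"


def new_to_json(dictionary):
    _processor_class = dictionary.pop("_processor_class", None)
    if _processor_class is not None:
        dictionary["processor_class"] = _processor_class
    return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"


sample = {"_processor_class": None, "feature_size": 1}
print('"_processor_class"' in old_to_json(dict(sample)))  # True: the private key leaks into the JSON
print('"_processor_class"' in new_to_json(dict(sample)))  # False: the private key is always removed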

src/transformers/pipelines/__init__.py

@@ -4,6 +4,7 @@
 import io
 import json
+import os
 
 # coding=utf-8
 # Copyright 2018 The HuggingFace Inc. team.
@@ -617,17 +618,16 @@ def pipeline(
         and isinstance(model_name, str)
     ):
         try:
+            import kenlm  # to trigger `ImportError` if not installed
             from pyctcdecode import BeamSearchDecoderCTC
 
             language_model_glob = os.path.join(BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*")
             alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME
             allow_regex = [language_model_glob, alphabet_filename]
-            decoder = BeamSearchDecoderCTC.load_from_hf_hub(
-                pretrained_model_name_or_path, allow_regex=allow_regex
-            )
+            decoder = BeamSearchDecoderCTC.load_from_hf_hub(model_name, allow_regex=allow_regex)
 
             kwargs["decoder"] = decoder
-        except Exception as e:
+        except ImportError as e:
             logger.warning(
                 f"Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Try to install `pyctcdecode` and `kenlm`: (`pip install pyctcdecode`, `pip install https://github.com/kpu/kenlm/archive/master.zip`): Error: {e}"
             )
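
A standalone sketch of the fixed loading path above; it simply mirrors the new code and assumes `pyctcdecode>=0.3.0`, `kenlm`, and a Hub repo that ships a serialized decoder (the repo name is the one used by the new test).

import os

import kenlm  # noqa: F401 -- imported only so a missing kenlm fails fast with ImportError
from pyctcdecode import BeamSearchDecoderCTC

model_name = "hf-internal-testing/processor_with_lm"

# Download only the decoder files: the serialized language-model directory and the alphabet file.
language_model_glob = os.path.join(BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*")
alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME
allow_regex = [language_model_glob, alphabet_filename]

decoder = BeamSearchDecoderCTC.load_from_hf_hub(model_name, allow_regex=allow_regex)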

tests/test_feature_extraction_auto.py

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import os
 import tempfile
 import unittest
@@ -42,8 +43,9 @@ class AutoFeatureExtractorTest(unittest.TestCase):
             # remove feature_extractor_type to make sure config.json alone is enough to load feature processor locally
             config_dict = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR).to_dict()
             config_dict.pop("feature_extractor_type")
-            config = Wav2Vec2FeatureExtractor(config_dict)
+            config = Wav2Vec2FeatureExtractor(**config_dict)
 
             # save in new folder
             model_config.save_pretrained(tmpdirname)
@@ -51,6 +53,10 @@ class AutoFeatureExtractorTest(unittest.TestCase):
             config = AutoFeatureExtractor.from_pretrained(tmpdirname)
 
+            # make sure private variable is not incorrectly saved
+            dict_as_saved = json.loads(config.to_json_string())
+            self.assertTrue("_processor_class" not in dict_as_saved)
+
             self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
 
     def test_feature_extractor_from_local_file(self):

tests/test_pipelines_automatic_speech_recognition.py

@@ -295,6 +295,26 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         self.assertEqual(output, [{"text": ANY(str)}])
         self.assertEqual(output[0]["text"][:6], "ZBT ZC")
 
+    @require_torch
+    @require_pyctcdecode
+    def test_with_lm_fast(self):
+        speech_recognizer = pipeline(
+            task="automatic-speech-recognition",
+            model="hf-internal-testing/processor_with_lm",
+            framework="pt",
+        )
+        self.assertEqual(speech_recognizer.type, "ctc_with_lm")
+
+        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
+        audio = ds[40]["audio"]["array"]
+
+        n_repeats = 2
+        audio_tiled = np.tile(audio, n_repeats)
+
+        output = speech_recognizer([audio_tiled], batch_size=2)
+        self.assertEqual(output, [{"text": ANY(str)}])
+        self.assertEqual(output[0]["text"][:6], "<s> <s")
+
     @require_torch
     @slow
     def test_chunking(self):