[ASR pipeline] correct with lm pipeline (#15200)
* [ASR pipeline] correct with lm pipeline * improve error
This commit is contained in:
parent
1144d336b6
commit
497346d07e
2
setup.py
2
setup.py
|
@ -152,7 +152,7 @@ _deps = [
|
|||
"tokenizers>=0.10.1,!=0.11.3",
|
||||
"torch>=1.0",
|
||||
"torchaudio",
|
||||
"pyctcdecode>=0.2.0",
|
||||
"pyctcdecode>=0.3.0",
|
||||
"tqdm>=4.27",
|
||||
"unidic>=1.0.2",
|
||||
"unidic_lite>=1.0.7",
|
||||
|
|
|
@ -62,7 +62,7 @@ deps = {
|
|||
"tokenizers": "tokenizers>=0.10.1,!=0.11.3",
|
||||
"torch": "torch>=1.0",
|
||||
"torchaudio": "torchaudio",
|
||||
"pyctcdecode": "pyctcdecode>=0.2.0",
|
||||
"pyctcdecode": "pyctcdecode>=0.3.0",
|
||||
"tqdm": "tqdm>=4.27",
|
||||
"unidic": "unidic>=1.0.2",
|
||||
"unidic_lite": "unidic_lite>=1.0.7",
|
||||
|
|
|
@ -489,8 +489,9 @@ class FeatureExtractionMixin:
|
|||
|
||||
# make sure private name "_processor_class" is correctly
|
||||
# saved as "processor_class"
|
||||
if dictionary.get("_processor_class", None) is not None:
|
||||
dictionary["processor_class"] = dictionary.pop("_processor_class")
|
||||
_processor_class = dictionary.pop("_processor_class", None)
|
||||
if _processor_class is not None:
|
||||
dictionary["processor_class"] = _processor_class
|
||||
|
||||
return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
|
@ -617,17 +618,16 @@ def pipeline(
|
|||
and isinstance(model_name, str)
|
||||
):
|
||||
try:
|
||||
import kenlm # to trigger `ImportError` if not installed
|
||||
from pyctcdecode import BeamSearchDecoderCTC
|
||||
|
||||
language_model_glob = os.path.join(BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*")
|
||||
alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME
|
||||
allow_regex = [language_model_glob, alphabet_filename]
|
||||
|
||||
decoder = BeamSearchDecoderCTC.load_from_hf_hub(
|
||||
pretrained_model_name_or_path, allow_regex=allow_regex
|
||||
)
|
||||
decoder = BeamSearchDecoderCTC.load_from_hf_hub(model_name, allow_regex=allow_regex)
|
||||
kwargs["decoder"] = decoder
|
||||
except Exception as e:
|
||||
except ImportError as e:
|
||||
logger.warning(
|
||||
"Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Try to install `pyctcdecode` and `kenlm`: (`pip install pyctcdecode`, `pip install https://github.com/kpu/kenlm/archive/master.zip`): Error: {e}"
|
||||
)
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
@ -42,8 +43,9 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
|||
|
||||
# remove feature_extractor_type to make sure config.json alone is enough to load feature processor locally
|
||||
config_dict = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR).to_dict()
|
||||
|
||||
config_dict.pop("feature_extractor_type")
|
||||
config = Wav2Vec2FeatureExtractor(config_dict)
|
||||
config = Wav2Vec2FeatureExtractor(**config_dict)
|
||||
|
||||
# save in new folder
|
||||
model_config.save_pretrained(tmpdirname)
|
||||
|
@ -51,6 +53,10 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
|||
|
||||
config = AutoFeatureExtractor.from_pretrained(tmpdirname)
|
||||
|
||||
# make sure private variable is not incorrectly saved
|
||||
dict_as_saved = json.loads(config.to_json_string())
|
||||
self.assertTrue("_processor_class" not in dict_as_saved)
|
||||
|
||||
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
|
||||
|
||||
def test_feature_extractor_from_local_file(self):
|
||||
|
|
|
@ -295,6 +295,26 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||
self.assertEqual(output, [{"text": ANY(str)}])
|
||||
self.assertEqual(output[0]["text"][:6], "ZBT ZC")
|
||||
|
||||
@require_torch
|
||||
@require_pyctcdecode
|
||||
def test_with_lm_fast(self):
|
||||
speech_recognizer = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model="hf-internal-testing/processor_with_lm",
|
||||
framework="pt",
|
||||
)
|
||||
self.assertEqual(speech_recognizer.type, "ctc_with_lm")
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
audio = ds[40]["audio"]["array"]
|
||||
|
||||
n_repeats = 2
|
||||
audio_tiled = np.tile(audio, n_repeats)
|
||||
output = speech_recognizer([audio_tiled], batch_size=2)
|
||||
|
||||
self.assertEqual(output, [{"text": ANY(str)}])
|
||||
self.assertEqual(output[0]["text"][:6], "<s> <s")
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_chunking(self):
|
||||
|
|
Loading…
Reference in New Issue