Add generate kwargs to `AutomaticSpeechRecognitionPipeline` (#20952)

* Add generate kwargs to AutomaticSpeechRecognitionPipeline

* Add test for generation kwargs
bofeng huang 2022-12-31 07:13:28 +01:00 committed by GitHub
parent 9e6da0a7ed
commit 47c9b22d08
2 changed files with 66 additions and 17 deletions
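
In practice, the change lets callers route arguments to `model.generate` through the pipeline call. A minimal usage sketch, mirroring the test added below (the tiny checkpoint and the synthetic waveform are taken from that test; any seq2seq speech checkpoint would behave the same way):

    import numpy as np
    from transformers import pipeline

    # Build an ASR pipeline around a seq2seq speech model.
    asr = pipeline(
        model="hf-internal-testing/tiny-random-speech-encoder-decoder",
        framework="pt",
    )

    # Synthetic audio: a float32 ramp tiled 34 times, as in the test below.
    waveform = np.tile(np.arange(1000, dtype=np.float32), 34)

    # `max_new_tokens` is accepted directly; any other `generate()` argument
    # (e.g. `num_beams`) goes through `generate_kwargs`.
    output = asr(waveform, max_new_tokens=10, generate_kwargs={"num_beams": 2})
    print(output["text"])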

src/transformers/pipelines/automatic_speech_recognition.py

@@ -24,6 +24,8 @@ from .base import ChunkPipeline

 if TYPE_CHECKING:
+    from pyctcdecode import BeamSearchDecoderCTC
+
     from ...feature_extraction_sequence_utils import SequenceFeatureExtractor

 logger = logging.get_logger(__name__)
@@ -169,8 +171,14 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
     """

-    def __init__(self, feature_extractor: Union["SequenceFeatureExtractor", str], *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(
+        self,
+        feature_extractor: Union["SequenceFeatureExtractor", str],
+        *,
+        decoder: Optional[Union["BeamSearchDecoderCTC", str]] = None,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
         self.feature_extractor = feature_extractor

         if self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
@@ -178,9 +186,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         elif (
             feature_extractor._processor_class
             and feature_extractor._processor_class.endswith("WithLM")
-            and kwargs.get("decoder", None) is not None
+            and decoder is not None
         ):
-            self.decoder = kwargs["decoder"]
+            self.decoder = decoder
             self.type = "ctc_with_lm"
         else:
             self.type = "ctc"
@@ -221,6 +229,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                 `timestamps` along the text for every word in the text. For instance, if you get `[{"text": "hi ",
                 "timestamps": (0.5, 0.9)}, {"text": "there", "timestamps": (1.0, 1.5)}]`, then it means the model
                 predicts that the word "hi" was pronounced after `0.5` and before `0.9` seconds.
+            generate_kwargs (`dict`, *optional*):
+                The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
+                complete overview of generate, check the [following
+                guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
+            max_new_tokens (`int`, *optional*):
+                The maximum number of tokens to generate, ignoring the number of tokens in the prompt.

         Return:
             `Dict`: A dictionary with the following keys:
@@ -233,23 +247,43 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         """
         return super().__call__(inputs, **kwargs)

-    def _sanitize_parameters(self, **kwargs):
+    def _sanitize_parameters(
+        self,
+        chunk_length_s=None,
+        stride_length_s=None,
+        ignore_warning=None,
+        decoder_kwargs=None,
+        return_timestamps=None,
+        generate_kwargs=None,
+        max_new_tokens=None,
+    ):
         # No parameters on this pipeline right now
         preprocess_params = {}
-        if "chunk_length_s" in kwargs:
-            preprocess_params["chunk_length_s"] = kwargs["chunk_length_s"]
-        if "stride_length_s" in kwargs:
-            preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
-        if "ignore_warning" in kwargs:
-            preprocess_params["ignore_warning"] = kwargs["ignore_warning"]
+        if chunk_length_s is not None:
+            preprocess_params["chunk_length_s"] = chunk_length_s
+        if stride_length_s is not None:
+            preprocess_params["stride_length_s"] = stride_length_s
+        if ignore_warning is not None:
+            preprocess_params["ignore_warning"] = ignore_warning
+
+        forward_params = {"generate_kwargs": {}}
+        if max_new_tokens is not None:
+            forward_params["generate_kwargs"]["max_new_tokens"] = max_new_tokens
+        if generate_kwargs is not None:
+            if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
+                raise ValueError(
+                    "`max_new_tokens` is defined both as an argument and inside `generate_kwargs` argument, please use"
+                    " only 1 version"
+                )
+            forward_params["generate_kwargs"].update(generate_kwargs)

         postprocess_params = {}
-        if "decoder_kwargs" in kwargs:
-            postprocess_params["decoder_kwargs"] = kwargs["decoder_kwargs"]
-        if "return_timestamps" in kwargs:
-            postprocess_params["return_timestamps"] = kwargs["return_timestamps"]
+        if decoder_kwargs is not None:
+            postprocess_params["decoder_kwargs"] = decoder_kwargs
+        if return_timestamps is not None:
+            postprocess_params["return_timestamps"] = return_timestamps

-        return preprocess_params, {}, postprocess_params
+        return preprocess_params, forward_params, postprocess_params

     def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False):
         if isinstance(inputs, str):
@@ -351,7 +385,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                 processed["stride"] = stride
             yield {"is_last": True, **processed, **extra}

-    def _forward(self, model_inputs):
+    def _forward(self, model_inputs, generate_kwargs=None):
+        if generate_kwargs is None:
+            generate_kwargs = {}
+
         is_last = model_inputs.pop("is_last")
         if self.type == "seq2seq":
             encoder = self.model.get_encoder()
@@ -376,6 +413,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             tokens = self.model.generate(
                 encoder_outputs=encoder(inputs, attention_mask=attention_mask),
                 attention_mask=attention_mask,
+                **generate_kwargs,
             )
             out = {"tokens": tokens}
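A note on the routing above: `_sanitize_parameters` now returns a non-empty `forward_params`, and as the `_forward` hunk shows, only the seq2seq branch splats `generate_kwargs` into `model.generate`; the CTC branches never call generate. Passing `max_new_tokens` both directly and inside `generate_kwargs` raises a `ValueError`. A sketch of that failure mode, reusing the `asr` pipeline and `waveform` from the example at the top:

    # Both routes try to set max_new_tokens, so _sanitize_parameters
    # raises ValueError before any forward pass runs.
    asr(waveform, max_new_tokens=10, generate_kwargs={"max_new_tokens": 5})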

tests/pipelines/test_pipelines_automatic_speech_recognition.py

@@ -169,6 +169,17 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         output = speech_recognizer(waveform)
         self.assertEqual(output, {"text": "あл ش 湯 清 ه ܬ া लᆨしث ल eか u w 全 u"})

+    @require_torch
+    def test_small_model_pt_seq2seq_gen_kwargs(self):
+        speech_recognizer = pipeline(
+            model="hf-internal-testing/tiny-random-speech-encoder-decoder",
+            framework="pt",
+        )
+
+        waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
+
+        output = speech_recognizer(waveform, max_new_tokens=10, generate_kwargs={"num_beams": 2})
+        self.assertEqual(output, {"text": "あл † γ ت ב オ 束 泣 足"})
+
     @slow
     @require_torch
     @require_pyctcdecode