Add generate kwargs to `AutomaticSpeechRecognitionPipeline` (#20952)
* Add generate kwargs to AutomaticSpeechRecognitionPipeline * Add test for generation kwargs
This commit is contained in:
parent
9e6da0a7ed
commit
47c9b22d08
|
@ -24,6 +24,8 @@ from .base import ChunkPipeline
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pyctcdecode import BeamSearchDecoderCTC
|
||||
|
||||
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
@ -169,8 +171,14 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, feature_extractor: Union["SequenceFeatureExtractor", str], *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(
|
||||
self,
|
||||
feature_extractor: Union["SequenceFeatureExtractor", str],
|
||||
*,
|
||||
decoder: Optional[Union["BeamSearchDecoderCTC", str]] = None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.feature_extractor = feature_extractor
|
||||
|
||||
if self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
|
||||
|
@ -178,9 +186,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||
elif (
|
||||
feature_extractor._processor_class
|
||||
and feature_extractor._processor_class.endswith("WithLM")
|
||||
and kwargs.get("decoder", None) is not None
|
||||
and decoder is not None
|
||||
):
|
||||
self.decoder = kwargs["decoder"]
|
||||
self.decoder = decoder
|
||||
self.type = "ctc_with_lm"
|
||||
else:
|
||||
self.type = "ctc"
|
||||
|
@ -221,6 +229,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||
`timestamps` along the text for every word in the text. For instance if you get `[{"text": "hi ",
|
||||
"timestamps": (0.5,0.9), {"text": "there", "timestamps": (1.0, .1.5)}]`, then it means the model
|
||||
predicts that the word "hi" was pronounced after `0.5` and before `0.9` seconds.
|
||||
generate_kwargs (`dict`, *optional*):
|
||||
The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
|
||||
complete overview of generate, check the [following
|
||||
guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
|
||||
max_new_tokens (`int`, *optional*):
|
||||
The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
|
||||
|
||||
Return:
|
||||
`Dict`: A dictionary with the following keys:
|
||||
|
@ -233,23 +247,43 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||
"""
|
||||
return super().__call__(inputs, **kwargs)
|
||||
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
def _sanitize_parameters(
|
||||
self,
|
||||
chunk_length_s=None,
|
||||
stride_length_s=None,
|
||||
ignore_warning=None,
|
||||
decoder_kwargs=None,
|
||||
return_timestamps=None,
|
||||
generate_kwargs=None,
|
||||
max_new_tokens=None,
|
||||
):
|
||||
# No parameters on this pipeline right now
|
||||
preprocess_params = {}
|
||||
if "chunk_length_s" in kwargs:
|
||||
preprocess_params["chunk_length_s"] = kwargs["chunk_length_s"]
|
||||
if "stride_length_s" in kwargs:
|
||||
preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
|
||||
if "ignore_warning" in kwargs:
|
||||
preprocess_params["ignore_warning"] = kwargs["ignore_warning"]
|
||||
if chunk_length_s is not None:
|
||||
preprocess_params["chunk_length_s"] = chunk_length_s
|
||||
if stride_length_s is not None:
|
||||
preprocess_params["stride_length_s"] = stride_length_s
|
||||
if ignore_warning is not None:
|
||||
preprocess_params["ignore_warning"] = ignore_warning
|
||||
|
||||
forward_params = {"generate_kwargs": {}}
|
||||
if max_new_tokens is not None:
|
||||
forward_params["generate_kwargs"]["max_new_tokens"] = max_new_tokens
|
||||
if generate_kwargs is not None:
|
||||
if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
|
||||
raise ValueError(
|
||||
"`max_new_tokens` is defined both as an argument and inside `generate_kwargs` argument, please use"
|
||||
" only 1 version"
|
||||
)
|
||||
forward_params["generate_kwargs"].update(generate_kwargs)
|
||||
|
||||
postprocess_params = {}
|
||||
if "decoder_kwargs" in kwargs:
|
||||
postprocess_params["decoder_kwargs"] = kwargs["decoder_kwargs"]
|
||||
if "return_timestamps" in kwargs:
|
||||
postprocess_params["return_timestamps"] = kwargs["return_timestamps"]
|
||||
if decoder_kwargs is not None:
|
||||
postprocess_params["decoder_kwargs"] = decoder_kwargs
|
||||
if return_timestamps is not None:
|
||||
postprocess_params["return_timestamps"] = return_timestamps
|
||||
|
||||
return preprocess_params, {}, postprocess_params
|
||||
return preprocess_params, forward_params, postprocess_params
|
||||
|
||||
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False):
|
||||
if isinstance(inputs, str):
|
||||
|
@ -351,7 +385,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||
processed["stride"] = stride
|
||||
yield {"is_last": True, **processed, **extra}
|
||||
|
||||
def _forward(self, model_inputs):
|
||||
def _forward(self, model_inputs, generate_kwargs=None):
|
||||
if generate_kwargs is None:
|
||||
generate_kwargs = {}
|
||||
|
||||
is_last = model_inputs.pop("is_last")
|
||||
if self.type == "seq2seq":
|
||||
encoder = self.model.get_encoder()
|
||||
|
@ -376,6 +413,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||
tokens = self.model.generate(
|
||||
encoder_outputs=encoder(inputs, attention_mask=attention_mask),
|
||||
attention_mask=attention_mask,
|
||||
**generate_kwargs,
|
||||
)
|
||||
|
||||
out = {"tokens": tokens}
|
||||
|
|
|
@ -169,6 +169,17 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||
output = speech_recognizer(waveform)
|
||||
self.assertEqual(output, {"text": "あл ش 湯 清 ه ܬ া लᆨしث ल eか u w 全 u"})
|
||||
|
||||
@require_torch
|
||||
def test_small_model_pt_seq2seq_gen_kwargs(self):
|
||||
speech_recognizer = pipeline(
|
||||
model="hf-internal-testing/tiny-random-speech-encoder-decoder",
|
||||
framework="pt",
|
||||
)
|
||||
|
||||
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
|
||||
output = speech_recognizer(waveform, max_new_tokens=10, generate_kwargs={"num_beams": 2})
|
||||
self.assertEqual(output, {"text": "あл † γ ت ב オ 束 泣 足"})
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_pyctcdecode
|
||||
|
|
Loading…
Reference in New Issue