Avoid extra chunk in speech recognition (#29539)

This commit is contained in:
Jonatan Kłosko 2024-05-22 15:07:51 +02:00 committed by Ita Zaporozhets
parent a778108a3c
commit c337d55988
2 changed files with 5 additions and 6 deletions

View File

@ -67,8 +67,7 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,
if dtype is not None:
processed = processed.to(dtype=dtype)
_stride_left = 0 if chunk_start_idx == 0 else stride_left
# all right strides must be full, otherwise it is the last item
is_last = chunk_end_idx > inputs_len if stride_right > 0 else chunk_end_idx >= inputs_len
is_last = chunk_end_idx >= inputs_len
_stride_right = 0 if is_last else stride_right
chunk_len = chunk.shape[0]

View File

@ -1569,10 +1569,10 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
"input_values"
]
outs = list(chunk_iter(inputs, feature_extractor, 100, 20, 10))
self.assertEqual(len(outs), 2)
self.assertEqual([o["stride"] for o in outs], [(100, 0, 10), (30, 20, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 100), (1, 30)])
self.assertEqual([o["is_last"] for o in outs], [False, True])
self.assertEqual(len(outs), 1)
self.assertEqual([o["stride"] for o in outs], [(100, 0, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 100)])
self.assertEqual([o["is_last"] for o in outs], [True])
outs = list(chunk_iter(inputs, feature_extractor, 80, 20, 10))
self.assertEqual(len(outs), 2)