Avoid extra chunk in speech recognition (#29539)
This commit is contained in:
parent
a778108a3c
commit
c337d55988
|
@ -67,8 +67,7 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,
|
|||
if dtype is not None:
|
||||
processed = processed.to(dtype=dtype)
|
||||
_stride_left = 0 if chunk_start_idx == 0 else stride_left
|
||||
# all right strides must be full, otherwise it is the last item
|
||||
is_last = chunk_end_idx > inputs_len if stride_right > 0 else chunk_end_idx >= inputs_len
|
||||
is_last = chunk_end_idx >= inputs_len
|
||||
_stride_right = 0 if is_last else stride_right
|
||||
|
||||
chunk_len = chunk.shape[0]
|
||||
|
|
|
@ -1569,10 +1569,10 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||
"input_values"
|
||||
]
|
||||
outs = list(chunk_iter(inputs, feature_extractor, 100, 20, 10))
|
||||
self.assertEqual(len(outs), 2)
|
||||
self.assertEqual([o["stride"] for o in outs], [(100, 0, 10), (30, 20, 0)])
|
||||
self.assertEqual([o["input_values"].shape for o in outs], [(1, 100), (1, 30)])
|
||||
self.assertEqual([o["is_last"] for o in outs], [False, True])
|
||||
self.assertEqual(len(outs), 1)
|
||||
self.assertEqual([o["stride"] for o in outs], [(100, 0, 0)])
|
||||
self.assertEqual([o["input_values"].shape for o in outs], [(1, 100)])
|
||||
self.assertEqual([o["is_last"] for o in outs], [True])
|
||||
|
||||
outs = list(chunk_iter(inputs, feature_extractor, 80, 20, 10))
|
||||
self.assertEqual(len(outs), 2)
|
||||
|
|
Loading…
Reference in New Issue