Correct wav2vec2-bert inputs_to_logits_ratio (#28821)

* Correct wav2vec2-bert inputs_to_logits_ratio

* correct ratio

* correct ratio, clean asr pipeline

* refactor on one line
This commit is contained in:
Yoach Lacombe 2024-02-05 13:14:47 +00:00 committed by GitHub
parent 3f9f749325
commit 7addc9346c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 16 additions and 25 deletions

View File

@ -14,8 +14,6 @@
# limitations under the License.
""" Wav2Vec2Bert model configuration"""
import functools
import operator
from ...configuration_utils import PretrainedConfig
from ...utils import logging
@ -311,4 +309,7 @@ class Wav2Vec2BertConfig(PretrainedConfig):
@property
def inputs_to_logits_ratio(self):
return functools.reduce(operator.mul, self.conv_stride, 1)
ratio = self.feature_projection_input_dim * 2
if self.add_adapter:
ratio = ratio * (self.adapter_stride**self.num_adapter_layers)
return ratio

View File

@ -57,7 +57,7 @@ def rescale_stride(stride, ratio):
return new_strides
def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, rescale=True, dtype=None):
def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, dtype=None):
inputs_len = inputs.shape[0]
step = chunk_len - stride_left - stride_right
for chunk_start_idx in range(0, inputs_len, step):
@ -73,13 +73,6 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,
chunk_len = chunk.shape[0]
stride = (chunk_len, _stride_left, _stride_right)
if "input_features" in processed:
processed_len = processed["input_features"].shape[-1]
elif "input_values" in processed:
processed_len = processed["input_values"].shape[-1]
if processed_len != chunk.shape[-1] and rescale:
ratio = processed_len / chunk_len
stride = rescale_stride([stride], ratio)[0]
if chunk.shape[0] > _stride_left:
yield {"is_last": is_last, "stride": stride, **processed}
if is_last:
@ -436,10 +429,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if chunk_len < stride_left + stride_right:
raise ValueError("Chunk length must be superior to stride length")
rescale = self.type != "seq2seq_whisper"
# make sure that
for item in chunk_iter(
inputs, self.feature_extractor, chunk_len, stride_left, stride_right, rescale, self.torch_dtype
inputs, self.feature_extractor, chunk_len, stride_left, stride_right, self.torch_dtype
):
yield item
else:

View File

@ -1334,22 +1334,22 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
def test_chunk_iterator(self):
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
inputs = torch.arange(100).long()
ratio = 1
outs = list(chunk_iter(inputs, feature_extractor, 100, 0, 0, ratio))
outs = list(chunk_iter(inputs, feature_extractor, 100, 0, 0))
self.assertEqual(len(outs), 1)
self.assertEqual([o["stride"] for o in outs], [(100, 0, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 100)])
self.assertEqual([o["is_last"] for o in outs], [True])
# two chunks no stride
outs = list(chunk_iter(inputs, feature_extractor, 50, 0, 0, ratio))
outs = list(chunk_iter(inputs, feature_extractor, 50, 0, 0))
self.assertEqual(len(outs), 2)
self.assertEqual([o["stride"] for o in outs], [(50, 0, 0), (50, 0, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 50), (1, 50)])
self.assertEqual([o["is_last"] for o in outs], [False, True])
# two chunks incomplete last
outs = list(chunk_iter(inputs, feature_extractor, 80, 0, 0, ratio))
outs = list(chunk_iter(inputs, feature_extractor, 80, 0, 0))
self.assertEqual(len(outs), 2)
self.assertEqual([o["stride"] for o in outs], [(80, 0, 0), (20, 0, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 80), (1, 20)])
@ -1360,7 +1360,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
# This test is specifically crafted to trigger a bug if next chunk
# would be ignored by the fact that all the data would be
# contained in the strided left data.
outs = list(chunk_iter(inputs, feature_extractor, 105, 5, 5, ratio))
outs = list(chunk_iter(inputs, feature_extractor, 105, 5, 5))
self.assertEqual(len(outs), 1)
self.assertEqual([o["stride"] for o in outs], [(100, 0, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 100)])
@ -1373,25 +1373,24 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
input_values = feature_extractor(inputs, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")[
"input_values"
]
ratio = 1
outs = list(chunk_iter(inputs, feature_extractor, 100, 20, 10, ratio))
outs = list(chunk_iter(inputs, feature_extractor, 100, 20, 10))
self.assertEqual(len(outs), 2)
self.assertEqual([o["stride"] for o in outs], [(100, 0, 10), (30, 20, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 100), (1, 30)])
self.assertEqual([o["is_last"] for o in outs], [False, True])
outs = list(chunk_iter(inputs, feature_extractor, 80, 20, 10, ratio))
outs = list(chunk_iter(inputs, feature_extractor, 80, 20, 10))
self.assertEqual(len(outs), 2)
self.assertEqual([o["stride"] for o in outs], [(80, 0, 10), (50, 20, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 80), (1, 50)])
self.assertEqual([o["is_last"] for o in outs], [False, True])
outs = list(chunk_iter(inputs, feature_extractor, 90, 20, 0, ratio))
outs = list(chunk_iter(inputs, feature_extractor, 90, 20, 0))
self.assertEqual(len(outs), 2)
self.assertEqual([o["stride"] for o in outs], [(90, 0, 0), (30, 20, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 90), (1, 30)])
outs = list(chunk_iter(inputs, feature_extractor, 36, 6, 6, ratio))
outs = list(chunk_iter(inputs, feature_extractor, 36, 6, 6))
self.assertEqual(len(outs), 4)
self.assertEqual([o["stride"] for o in outs], [(36, 0, 6), (36, 6, 6), (36, 6, 6), (28, 6, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 36), (1, 36), (1, 36), (1, 28)])
@ -1400,7 +1399,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
input_values = feature_extractor(inputs, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")[
"input_values"
]
outs = list(chunk_iter(inputs, feature_extractor, 30, 5, 5, ratio))
outs = list(chunk_iter(inputs, feature_extractor, 30, 5, 5))
self.assertEqual(len(outs), 5)
self.assertEqual([o["stride"] for o in outs], [(30, 0, 5), (30, 5, 5), (30, 5, 5), (30, 5, 5), (20, 5, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 30), (1, 30), (1, 30), (1, 30), (1, 20)])