Correct wav2vec2-bert inputs_to_logits_ratio (#28821)
* Correct wav2vec2-bert inputs_to_logits_ratio * Correct the ratio * Correct the ratio and clean up the ASR pipeline * Refactor onto one line
This commit is contained in:
parent
3f9f749325
commit
7addc9346c
|
@ -14,8 +14,6 @@
|
|||
# limitations under the License.
|
||||
""" Wav2Vec2Bert model configuration"""
|
||||
|
||||
import functools
|
||||
import operator
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
@ -311,4 +309,7 @@ class Wav2Vec2BertConfig(PretrainedConfig):
|
|||
|
||||
@property
|
||||
def inputs_to_logits_ratio(self):
|
||||
return functools.reduce(operator.mul, self.conv_stride, 1)
|
||||
ratio = self.feature_projection_input_dim * 2
|
||||
if self.add_adapter:
|
||||
ratio = ratio * (self.adapter_stride**self.num_adapter_layers)
|
||||
return ratio
|
||||
|
|
|
@ -57,7 +57,7 @@ def rescale_stride(stride, ratio):
|
|||
return new_strides
|
||||
|
||||
|
||||
def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, rescale=True, dtype=None):
|
||||
def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, dtype=None):
|
||||
inputs_len = inputs.shape[0]
|
||||
step = chunk_len - stride_left - stride_right
|
||||
for chunk_start_idx in range(0, inputs_len, step):
|
||||
|
@ -73,13 +73,6 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,
|
|||
|
||||
chunk_len = chunk.shape[0]
|
||||
stride = (chunk_len, _stride_left, _stride_right)
|
||||
if "input_features" in processed:
|
||||
processed_len = processed["input_features"].shape[-1]
|
||||
elif "input_values" in processed:
|
||||
processed_len = processed["input_values"].shape[-1]
|
||||
if processed_len != chunk.shape[-1] and rescale:
|
||||
ratio = processed_len / chunk_len
|
||||
stride = rescale_stride([stride], ratio)[0]
|
||||
if chunk.shape[0] > _stride_left:
|
||||
yield {"is_last": is_last, "stride": stride, **processed}
|
||||
if is_last:
|
||||
|
@ -436,10 +429,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||
if chunk_len < stride_left + stride_right:
|
||||
raise ValueError("Chunk length must be superior to stride length")
|
||||
|
||||
rescale = self.type != "seq2seq_whisper"
|
||||
# make sure that
|
||||
for item in chunk_iter(
|
||||
inputs, self.feature_extractor, chunk_len, stride_left, stride_right, rescale, self.torch_dtype
|
||||
inputs, self.feature_extractor, chunk_len, stride_left, stride_right, self.torch_dtype
|
||||
):
|
||||
yield item
|
||||
else:
|
||||
|
|
|
@ -1334,22 +1334,22 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||
def test_chunk_iterator(self):
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
inputs = torch.arange(100).long()
|
||||
ratio = 1
|
||||
outs = list(chunk_iter(inputs, feature_extractor, 100, 0, 0, ratio))
|
||||
outs = list(chunk_iter(inputs, feature_extractor, 100, 0, 0))
|
||||
|
||||
self.assertEqual(len(outs), 1)
|
||||
self.assertEqual([o["stride"] for o in outs], [(100, 0, 0)])
|
||||
self.assertEqual([o["input_values"].shape for o in outs], [(1, 100)])
|
||||
self.assertEqual([o["is_last"] for o in outs], [True])
|
||||
|
||||
# two chunks no stride
|
||||
outs = list(chunk_iter(inputs, feature_extractor, 50, 0, 0, ratio))
|
||||
outs = list(chunk_iter(inputs, feature_extractor, 50, 0, 0))
|
||||
self.assertEqual(len(outs), 2)
|
||||
self.assertEqual([o["stride"] for o in outs], [(50, 0, 0), (50, 0, 0)])
|
||||
self.assertEqual([o["input_values"].shape for o in outs], [(1, 50), (1, 50)])
|
||||
self.assertEqual([o["is_last"] for o in outs], [False, True])
|
||||
|
||||
# two chunks incomplete last
|
||||
outs = list(chunk_iter(inputs, feature_extractor, 80, 0, 0, ratio))
|
||||
outs = list(chunk_iter(inputs, feature_extractor, 80, 0, 0))
|
||||
self.assertEqual(len(outs), 2)
|
||||
self.assertEqual([o["stride"] for o in outs], [(80, 0, 0), (20, 0, 0)])
|
||||
self.assertEqual([o["input_values"].shape for o in outs], [(1, 80), (1, 20)])
|
||||
|
@ -1360,7 +1360,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||
# This test is specifically crafted to trigger a bug if next chunk
|
||||
# would be ignored by the fact that all the data would be
|
||||
# contained in the strided left data.
|
||||
outs = list(chunk_iter(inputs, feature_extractor, 105, 5, 5, ratio))
|
||||
outs = list(chunk_iter(inputs, feature_extractor, 105, 5, 5))
|
||||
self.assertEqual(len(outs), 1)
|
||||
self.assertEqual([o["stride"] for o in outs], [(100, 0, 0)])
|
||||
self.assertEqual([o["input_values"].shape for o in outs], [(1, 100)])
|
||||
|
@ -1373,25 +1373,24 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||
input_values = feature_extractor(inputs, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")[
|
||||
"input_values"
|
||||
]
|
||||
ratio = 1
|
||||
outs = list(chunk_iter(inputs, feature_extractor, 100, 20, 10, ratio))
|
||||
outs = list(chunk_iter(inputs, feature_extractor, 100, 20, 10))
|
||||
self.assertEqual(len(outs), 2)
|
||||
self.assertEqual([o["stride"] for o in outs], [(100, 0, 10), (30, 20, 0)])
|
||||
self.assertEqual([o["input_values"].shape for o in outs], [(1, 100), (1, 30)])
|
||||
self.assertEqual([o["is_last"] for o in outs], [False, True])
|
||||
|
||||
outs = list(chunk_iter(inputs, feature_extractor, 80, 20, 10, ratio))
|
||||
outs = list(chunk_iter(inputs, feature_extractor, 80, 20, 10))
|
||||
self.assertEqual(len(outs), 2)
|
||||
self.assertEqual([o["stride"] for o in outs], [(80, 0, 10), (50, 20, 0)])
|
||||
self.assertEqual([o["input_values"].shape for o in outs], [(1, 80), (1, 50)])
|
||||
self.assertEqual([o["is_last"] for o in outs], [False, True])
|
||||
|
||||
outs = list(chunk_iter(inputs, feature_extractor, 90, 20, 0, ratio))
|
||||
outs = list(chunk_iter(inputs, feature_extractor, 90, 20, 0))
|
||||
self.assertEqual(len(outs), 2)
|
||||
self.assertEqual([o["stride"] for o in outs], [(90, 0, 0), (30, 20, 0)])
|
||||
self.assertEqual([o["input_values"].shape for o in outs], [(1, 90), (1, 30)])
|
||||
|
||||
outs = list(chunk_iter(inputs, feature_extractor, 36, 6, 6, ratio))
|
||||
outs = list(chunk_iter(inputs, feature_extractor, 36, 6, 6))
|
||||
self.assertEqual(len(outs), 4)
|
||||
self.assertEqual([o["stride"] for o in outs], [(36, 0, 6), (36, 6, 6), (36, 6, 6), (28, 6, 0)])
|
||||
self.assertEqual([o["input_values"].shape for o in outs], [(1, 36), (1, 36), (1, 36), (1, 28)])
|
||||
|
@ -1400,7 +1399,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||
input_values = feature_extractor(inputs, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")[
|
||||
"input_values"
|
||||
]
|
||||
outs = list(chunk_iter(inputs, feature_extractor, 30, 5, 5, ratio))
|
||||
outs = list(chunk_iter(inputs, feature_extractor, 30, 5, 5))
|
||||
self.assertEqual(len(outs), 5)
|
||||
self.assertEqual([o["stride"] for o in outs], [(30, 0, 5), (30, 5, 5), (30, 5, 5), (30, 5, 5), (20, 5, 0)])
|
||||
self.assertEqual([o["input_values"].shape for o in outs], [(1, 30), (1, 30), (1, 30), (1, 30), (1, 20)])
|
||||
|
|
Loading…
Reference in New Issue