return attention mask in int32 (#13543)

commit 5c14fceac0 (parent 149c833b75)
Author: Patrick von Platen (committed by GitHub)
Date:   2021-09-13 14:02:23 +02:00
2 changed files with 4 additions and 4 deletions

src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py

@@ -240,12 +240,12 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
         attention_mask = padded_inputs.get("attention_mask")
         if attention_mask is not None:
-            padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.bool) for array in attention_mask]
+            padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask]

         # Utterance-level cepstral mean and variance normalization
         if self.do_ceptral_normalize:
             attention_mask = (
-                np.array(attention_mask, dtype=np.bool)
+                np.array(attention_mask, dtype=np.int32)
                 if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
                 else None
             )
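
For context: np.bool is a deprecated alias for the builtin bool since NumPy 1.20 (removed entirely in NumPy 1.24), and models consume attention_mask as an integer tensor, so returning int32 avoids both the deprecation warning and an extra cast downstream. A minimal sketch of the dtype change in isolation (the mask values are made up):

    import numpy as np

    attention_mask = [[1, 1, 1, 0], [1, 1, 0, 0]]

    # Before this commit: boolean arrays (np.bool warns or errors on modern NumPy).
    # After this commit: int32 arrays, matching what the models expect.
    int_mask = [np.asarray(array, dtype=np.int32) for array in attention_mask]

    print(int_mask[0])        # [1 1 1 0]
    print(int_mask[0].dtype)  # int32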

src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py

@@ -86,7 +86,7 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
         Every array in the list is normalized to have zero mean and unit variance
         """
         if attention_mask is not None:
-            attention_mask = np.array(attention_mask, np.bool)
+            attention_mask = np.array(attention_mask, np.int32)

             normed_input_values = []
             for vector, length in zip(input_values, attention_mask.sum(-1)):
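
For context, a minimal sketch (not part of the diff; the inputs are made up) of the utterance-level normalization this loop performs: summing the mask over the last axis gives the number of valid samples per utterance, and only those samples are normalized to zero mean and unit variance. Since True sums the same as 1, the bool-to-int32 change does not alter the computed lengths.

    import numpy as np

    input_values = [np.array([0.2, 0.4, 0.6, 0.0], dtype=np.float32)]
    attention_mask = np.array([[1, 1, 1, 0]], dtype=np.int32)

    normed_input_values = []
    for vector, length in zip(input_values, attention_mask.sum(-1)):
        normed = np.zeros_like(vector)
        # Normalize only the valid (non-padded) samples; padding stays zero.
        valid = vector[:length]
        normed[:length] = (valid - valid.mean()) / np.sqrt(valid.var() + 1e-7)
        normed_input_values.append(normed)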
@@ -216,7 +216,7 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
         # convert attention_mask to correct format
         attention_mask = padded_inputs.get("attention_mask")
         if attention_mask is not None:
-            padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.bool) for array in attention_mask]
+            padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask]

         # zero-mean and unit-variance normalization
         if self.do_normalize:
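
For completeness, a hedged end-to-end sketch (not from the commit; the audio and padding settings are illustrative) checking the returned mask dtype after this change:

    import numpy as np
    from transformers import Wav2Vec2FeatureExtractor

    extractor = Wav2Vec2FeatureExtractor(do_normalize=True, return_attention_mask=True)

    # Two fake raw-audio utterances of different lengths, sampled at 16 kHz.
    audio = [np.random.randn(16000).astype(np.float32), np.random.randn(8000).astype(np.float32)]

    inputs = extractor(audio, sampling_rate=16000, padding=True, return_tensors="np")
    assert inputs["attention_mask"].dtype == np.int32  # was bool before this commit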