return attention mask in int32 (#13543)
This commit is contained in:
parent
149c833b75
commit
5c14fceac0
|
@ -240,12 +240,12 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
|
|||
|
||||
attention_mask = padded_inputs.get("attention_mask")
|
||||
if attention_mask is not None:
|
||||
padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.bool) for array in attention_mask]
|
||||
padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask]
|
||||
|
||||
# Utterance-level cepstral mean and variance normalization
|
||||
if self.do_ceptral_normalize:
|
||||
attention_mask = (
|
||||
np.array(attention_mask, dtype=np.bool)
|
||||
np.array(attention_mask, dtype=np.int32)
|
||||
if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
|
||||
else None
|
||||
)
|
||||
|
|
|
@ -86,7 +86,7 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
|
|||
Every array in the list is normalized to have zero mean and unit variance
|
||||
"""
|
||||
if attention_mask is not None:
|
||||
attention_mask = np.array(attention_mask, np.bool)
|
||||
attention_mask = np.array(attention_mask, np.int32)
|
||||
normed_input_values = []
|
||||
|
||||
for vector, length in zip(input_values, attention_mask.sum(-1)):
|
||||
|
@ -216,7 +216,7 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
|
|||
# convert attention_mask to correct format
|
||||
attention_mask = padded_inputs.get("attention_mask")
|
||||
if attention_mask is not None:
|
||||
padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.bool) for array in attention_mask]
|
||||
padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask]
|
||||
|
||||
# zero-mean and unit-variance normalization
|
||||
if self.do_normalize:
|
||||
|
|
Loading…
Reference in New Issue