[Wav2Vec2] Improve Tokenizer & Model for batched inference (#10117)
* save intermediate * finish batch the same as fairseq * add normalization * fix batched input * add better comment * Update src/transformers/models/wav2vec2/modeling_wav2vec2.py * add nice docstring * add tokenizer tests * make all slow tests pass * finish PR * correct import
This commit is contained in:
parent
2f3b5f4dcc
commit
495c157d6f
|
@ -36,7 +36,10 @@ logger = logging.get_logger(__name__)
|
|||
_CONFIG_FOR_DOC = "Wav2Vec2Config"
|
||||
|
||||
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"facebook/wav2vec2-base-960h"
|
||||
"facebook/wav2vec2-base-960h",
|
||||
"facebook/wav2vec2-large-960h",
|
||||
"facebook/wav2vec2-large-960h-lv60",
|
||||
"facebook/wav2vec2-large-960h-lv60-self",
|
||||
# See all Wav2Vec2 models at https://huggingface.co/models?filter=wav2vec2
|
||||
]
|
||||
|
||||
|
@ -191,7 +194,6 @@ class Wav2Vec2FeatureProjection(nn.Module):
|
|||
self.dropout = nn.Dropout(config.feat_extract_dropout)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = hidden_states.transpose(1, 2)
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
hidden_states = self.projection(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
@ -387,9 +389,11 @@ class Wav2Vec2EncoderLayer(nn.Module):
|
|||
self.feed_forward = Wav2Vec2FeedForward(config)
|
||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(self, hidden_states, output_attentions=False):
|
||||
def forward(self, hidden_states, attention_mask=None, output_attentions=False):
|
||||
attn_residual = hidden_states
|
||||
hidden_states, attn_weights, _ = self.attention(hidden_states, output_attentions=output_attentions)
|
||||
hidden_states, attn_weights, _ = self.attention(
|
||||
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
||||
)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = attn_residual + hidden_states
|
||||
|
||||
|
@ -414,10 +418,12 @@ class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module):
|
|||
self.feed_forward = Wav2Vec2FeedForward(config)
|
||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(self, hidden_states, output_attentions=False):
|
||||
def forward(self, hidden_states, attention_mask=None, output_attentions=False):
|
||||
attn_residual = hidden_states
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
hidden_states, attn_weights, _ = self.attention(hidden_states, output_attentions=output_attentions)
|
||||
hidden_states, attn_weights, _ = self.attention(
|
||||
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
||||
)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = attn_residual + hidden_states
|
||||
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
|
||||
|
@ -438,6 +444,7 @@ class Wav2Vec2Encoder(nn.Module):
|
|||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
|
@ -445,6 +452,16 @@ class Wav2Vec2Encoder(nn.Module):
|
|||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
|
||||
if attention_mask is not None:
|
||||
# make sure padded tokens output 0
|
||||
hidden_states[~attention_mask] = 0.0
|
||||
|
||||
# extend attention_mask
|
||||
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.expand(
|
||||
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
|
||||
)
|
||||
|
||||
position_embeddings = self.pos_conv_embed(hidden_states)
|
||||
hidden_states = hidden_states + position_embeddings
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
|
@ -454,7 +471,9 @@ class Wav2Vec2Encoder(nn.Module):
|
|||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
hidden_states, attn_weights = layer(hidden_states, output_attentions=output_attentions)
|
||||
hidden_states, attn_weights = layer(
|
||||
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
||||
)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (attn_weights,)
|
||||
|
@ -486,6 +505,7 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
|
|||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
|
@ -493,6 +513,16 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
|
|||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
|
||||
if attention_mask is not None:
|
||||
# make sure padded tokens are not attended to
|
||||
hidden_states[~attention_mask] = 0
|
||||
|
||||
# extend attention_mask
|
||||
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.expand(
|
||||
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
|
||||
)
|
||||
|
||||
position_embeddings = self.pos_conv_embed(hidden_states)
|
||||
hidden_states = hidden_states + position_embeddings
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
@ -501,7 +531,9 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
|
|||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
hidden_states, attn_weights = layer(hidden_states, output_attentions=output_attentions)
|
||||
hidden_states, attn_weights = layer(
|
||||
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
||||
)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (attn_weights,)
|
||||
|
@ -544,6 +576,21 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
|||
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
|
||||
"""
|
||||
Computes the output length of the convolutional layers
|
||||
"""
|
||||
|
||||
def _conv_out_length(input_length, kernel_size, stride):
|
||||
# 1D convolutional layer output length formula taken
|
||||
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
||||
return torch.floor((input_length - kernel_size) / stride + 1)
|
||||
|
||||
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
|
||||
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
|
||||
|
||||
return input_lengths.to(torch.long)
|
||||
|
||||
|
||||
WAV_2_VEC_2_START_DOCSTRING = r"""
|
||||
Wav2Vec2 was proposed in `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
|
||||
|
@ -572,6 +619,24 @@ WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
|
|||
soundfile`). To prepare the array into `input_values`, the :class:`~transformers.Wav2Vec2Tokenizer` should
|
||||
be used for padding and conversion into a tensor of type `torch.FloatTensor`. See
|
||||
:meth:`transformers.Wav2Vec2Tokenizer.__call__` for details.
|
||||
attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in ``[0,
|
||||
1]``:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
|
||||
.. warning::
|
||||
:obj:`attention_mask` should only be passed if the corresponding tokenizer has
|
||||
``config.return_attention_mask == True``. For all models whose tokenizer has
|
||||
``config.return_attention_mask == False``, such as `wav2vec2-base
|
||||
<https://huggingface.co/facebook/wav2vec2-base-960h>`__, :obj:`attention_mask` should **not** be passed
|
||||
to avoid degraded performance when doing batched inference. For such models :obj:`input_values` should
|
||||
simply be padded with 0 and passed without :obj:`attention_mask`. Be aware that these models also yield
|
||||
slightly different results depending on whether :obj:`input_values` is padded or not.
|
||||
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||
tensors for more detail.
|
||||
|
@ -606,6 +671,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
|||
def forward(
|
||||
self,
|
||||
input_values,
|
||||
attention_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
|
@ -641,14 +707,33 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
|||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
hidden_states = self.feature_extractor(input_values)
|
||||
hidden_states = hidden_states.transpose(1, 2)
|
||||
|
||||
if attention_mask is not None:
|
||||
# compute real output lengths according to convolution formula
|
||||
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
|
||||
|
||||
attention_mask = torch.zeros(
|
||||
hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device
|
||||
)
|
||||
|
||||
# these two operations makes sure that all values
|
||||
# before the output lengths indices are attended to
|
||||
attention_mask[
|
||||
(torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)
|
||||
] = 1
|
||||
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
||||
|
||||
hidden_states = self.feature_projection(hidden_states)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = encoder_outputs[0]
|
||||
|
||||
if not return_dict:
|
||||
|
@ -681,6 +766,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
|
|||
def forward(
|
||||
self,
|
||||
input_values,
|
||||
attention_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
|
@ -755,6 +841,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
|||
def forward(
|
||||
self,
|
||||
input_values,
|
||||
attention_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
|
@ -795,6 +882,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
|||
|
||||
outputs = self.wav2vec2(
|
||||
input_values,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
|
@ -802,6 +890,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
|||
|
||||
hidden_states = outputs[0]
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
|
|
|
@ -87,6 +87,26 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
|||
The token used for padding, for example when batching sequences of different lengths.
|
||||
word_delimiter_token (:obj:`str`, `optional`, defaults to :obj:`"|"`):
|
||||
The token used for defining the end of a word.
|
||||
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to lowercase the output when decoding.
|
||||
do_normalize (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
|
||||
improve the performance for some models, *e.g.*, `wav2vec2-lv60
|
||||
<https://huggingface.co/models?search=lv60>`__.
|
||||
return_attention_mask (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not :meth:`~transformers.Wav2Vec2Tokenizer.__call__` should return :obj:`attention_mask`.
|
||||
|
||||
.. note::
|
||||
|
||||
Wav2Vec2 models that have set ``config.feat_extract_norm == "group"``, such as `wav2vec2-base
|
||||
<https://huggingface.co/facebook/wav2vec2-base-960h>`__, have **not** been trained using
|
||||
:obj:`attention_mask`. For such models, :obj:`input_values` should simply be padded with 0 and no
|
||||
:obj:`attention_mask` should be passed.
|
||||
|
||||
For Wav2Vec2 models that have set ``config.feat_extract_norm == "layer"``, such as `wav2vec2-lv60
|
||||
<https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self>`__, :obj:`attention_mask` should be
|
||||
passed for batched inference.
|
||||
|
||||
**kwargs
|
||||
Additional keyword arguments passed along to :class:`~transformers.PreTrainedTokenizer`
|
||||
"""
|
||||
|
@ -100,7 +120,7 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
|||
"facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/tokenizer.json",
|
||||
},
|
||||
}
|
||||
model_input_names = ["input_values"]
|
||||
model_input_names = ["input_values", "attention_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -111,6 +131,8 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
|||
pad_token="<pad>",
|
||||
word_delimiter_token="|",
|
||||
do_lower_case=False,
|
||||
do_normalize=False,
|
||||
return_attention_mask=False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
|
@ -119,11 +141,16 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
|||
eos_token=eos_token,
|
||||
pad_token=pad_token,
|
||||
do_lower_case=do_lower_case,
|
||||
do_normalize=do_normalize,
|
||||
return_attention_mask=return_attention_mask,
|
||||
word_delimiter_token=word_delimiter_token,
|
||||
**kwargs,
|
||||
)
|
||||
self._word_delimiter_token = word_delimiter_token
|
||||
|
||||
self.do_lower_case = do_lower_case
|
||||
self.return_attention_mask = return_attention_mask
|
||||
self.do_normalize = do_normalize
|
||||
|
||||
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
||||
self.encoder = json.load(vocab_handle)
|
||||
|
@ -193,6 +220,10 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
|||
if not is_batched:
|
||||
raw_speech = [raw_speech]
|
||||
|
||||
# zero-mean and unit-variance normalization
|
||||
if self.do_normalize:
|
||||
raw_speech = [(x - np.mean(x)) / np.sqrt(np.var(x) + 1e-5) for x in raw_speech]
|
||||
|
||||
# convert into correct format for padding
|
||||
encoded_inputs = BatchEncoding({"input_values": raw_speech})
|
||||
|
||||
|
@ -201,7 +232,7 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
|||
padding=padding,
|
||||
max_length=max_length,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_attention_mask=False,
|
||||
return_attention_mask=self.return_attention_mask,
|
||||
return_tensors=return_tensors,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
import math
|
||||
import unittest
|
||||
|
||||
from tests.test_modeling_common import floats_tensor
|
||||
from tests.test_modeling_common import floats_tensor, random_attention_mask
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
|
||||
|
||||
|
@ -93,6 +93,7 @@ class Wav2Vec2ModelTester:
|
|||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||
|
||||
config = Wav2Vec2Config(
|
||||
hidden_size=self.hidden_size,
|
||||
|
@ -115,20 +116,48 @@ class Wav2Vec2ModelTester:
|
|||
vocab_size=self.vocab_size,
|
||||
)
|
||||
|
||||
return config, input_values
|
||||
return config, input_values, attention_mask
|
||||
|
||||
def create_and_check_model(self, config, input_values):
|
||||
def create_and_check_model(self, config, input_values, attention_mask):
|
||||
model = Wav2Vec2Model(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_values)
|
||||
result = model(input_values, attention_mask=attention_mask)
|
||||
self.parent.assertEqual(
|
||||
result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
|
||||
)
|
||||
|
||||
def create_and_check_batch_inference(self, config, input_values, *args):
|
||||
# Not sure how to make this test pass at the moment. Batched input yields
|
||||
# same results as official fairseq implementation, but gives different results
|
||||
# depending on whether batched input is used or not
|
||||
# check: https://github.com/pytorch/fairseq/issues/3227
|
||||
model = Wav2Vec2Model(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
input_values = input_values[:3]
|
||||
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.bool)
|
||||
|
||||
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
|
||||
|
||||
# pad input
|
||||
for i in range(len(input_lengths)):
|
||||
input_values[i, input_lengths[i] :] = 0.0
|
||||
attention_mask[i, input_lengths[i] :] = 0.0
|
||||
|
||||
batch_outputs = model(input_values, attention_mask=attention_mask).last_hidden_state
|
||||
|
||||
for i in range(input_values.shape[0]):
|
||||
input_slice = input_values[i : i + 1, : input_lengths[i]]
|
||||
output = model(input_slice).last_hidden_state
|
||||
|
||||
batch_output = batch_outputs[i : i + 1, : output.shape[1]]
|
||||
self.parent.assertTrue(torch.allclose(output, batch_output, atol=1e-3))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config, input_values = self.prepare_config_and_inputs()
|
||||
inputs_dict = {"input_values": input_values}
|
||||
config, input_values, attention_mask = self.prepare_config_and_inputs()
|
||||
inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
|
@ -222,6 +251,10 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
|||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_batched_inference(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_batch_inference(*config_and_inputs)
|
||||
|
||||
# Wav2Vec2 has no inputs_embeds
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
@ -288,7 +321,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||
|
||||
return ds["speech"][:num_samples]
|
||||
|
||||
def test_inference_masked_lm_normal(self):
|
||||
def test_inference_ctc_normal(self):
|
||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
model.to(torch_device)
|
||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
||||
|
@ -306,16 +339,16 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||
EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"]
|
||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||
|
||||
def test_inference_masked_lm_normal_batched(self):
|
||||
def test_inference_ctc_normal_batched(self):
|
||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
model.to(torch_device)
|
||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
||||
|
||||
input_speech = self._load_datasamples(2)
|
||||
|
||||
input_values = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True).input_values.to(
|
||||
torch_device
|
||||
)
|
||||
inputs = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True)
|
||||
|
||||
input_values = inputs.input_values.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(input_values).logits
|
||||
|
@ -329,18 +362,19 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||
]
|
||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||
|
||||
def test_inference_masked_lm_robust_batched(self):
|
||||
def test_inference_ctc_robust_batched(self):
|
||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
|
||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)
|
||||
|
||||
input_speech = self._load_datasamples(4)
|
||||
|
||||
input_values = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True).input_values.to(
|
||||
torch_device
|
||||
)
|
||||
inputs = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True)
|
||||
|
||||
input_values = inputs.input_values.to(torch_device)
|
||||
attention_mask = inputs.attention_mask.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(input_values).logits
|
||||
logits = model(input_values, attention_mask=attention_mask).logits
|
||||
|
||||
predicted_ids = torch.argmax(logits, dim=-1)
|
||||
predicted_trans = tokenizer.batch_decode(predicted_ids)
|
||||
|
|
|
@ -23,7 +23,10 @@ import unittest
|
|||
|
||||
import numpy as np
|
||||
|
||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES, Wav2Vec2Tokenizer
|
||||
from transformers import WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
from transformers.models.wav2vec2 import Wav2Vec2Config, Wav2Vec2Tokenizer
|
||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
||||
from transformers.testing_utils import slow
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
@ -299,3 +302,46 @@ class Wav2Vec2TokenizerTest(unittest.TestCase):
|
|||
for parameter_name, parameter in signature.parameters.items():
|
||||
if parameter.default != inspect.Parameter.empty:
|
||||
self.assertIn(parameter_name, tokenizer.init_kwargs)
|
||||
|
||||
def test_zero_mean_unit_variance_normalization(self):
|
||||
tokenizer = self.get_tokenizer(do_normalize=True)
|
||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||
processed = tokenizer(speech_inputs, padding="longest")
|
||||
input_values = processed.input_values
|
||||
|
||||
def _check_zero_mean_unit_variance(input_vector):
|
||||
self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
|
||||
self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
|
||||
|
||||
_check_zero_mean_unit_variance(input_values[0, :800])
|
||||
_check_zero_mean_unit_variance(input_values[1, :1000])
|
||||
_check_zero_mean_unit_variance(input_values[2])
|
||||
|
||||
def test_return_attention_mask(self):
|
||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||
|
||||
# default case -> no attention_mask is returned
|
||||
tokenizer = self.get_tokenizer()
|
||||
processed = tokenizer(speech_inputs)
|
||||
self.assertNotIn("attention_mask", processed)
|
||||
|
||||
# wav2vec2-lv60 -> return attention_mask
|
||||
tokenizer = self.get_tokenizer(return_attention_mask=True)
|
||||
processed = tokenizer(speech_inputs, padding="longest")
|
||||
|
||||
self.assertIn("attention_mask", processed)
|
||||
self.assertListEqual(list(processed.attention_mask.shape), list(processed.input_values.shape))
|
||||
self.assertListEqual(processed.attention_mask.sum(-1).tolist(), [800, 1000, 1200])
|
||||
|
||||
@slow
|
||||
def test_pretrained_checkpoints_are_set_correctly(self):
|
||||
# this test makes sure that models that are using
|
||||
# group norm don't have their tokenizer return the
|
||||
# attention_mask
|
||||
for model_id in WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST:
|
||||
config = Wav2Vec2Config.from_pretrained(model_id)
|
||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_id)
|
||||
|
||||
# only "layer" feature extraction norm should make use of
|
||||
# attention_mask
|
||||
self.assertEqual(tokenizer.return_attention_mask, config.feat_extract_norm == "layer")
|
||||
|
|
Loading…
Reference in New Issue