[Wav2Vec2] Improve Tokenizer & Model for batched inference (#10117)

* save intermediate

* finish batch the same as fairseq

* add normalization

* fix batched input

* add better comment

* Update src/transformers/models/wav2vec2/modeling_wav2vec2.py

* add nice docstring

* add tokenizer tests

* make all slow tests pass

* finish PR

* correct import
Patrick von Platen 2021-02-11 15:40:54 +03:00 committed by GitHub
parent 2f3b5f4dcc
commit 495c157d6f
4 changed files with 227 additions and 27 deletions

src/transformers/models/wav2vec2/modeling_wav2vec2.py

@@ -36,7 +36,10 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Wav2Vec2Config"
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/wav2vec2-base-960h"
"facebook/wav2vec2-base-960h",
"facebook/wav2vec2-large-960h",
"facebook/wav2vec2-large-960h-lv60",
"facebook/wav2vec2-large-960h-lv60-self",
# See all Wav2Vec2 models at https://huggingface.co/models?filter=wav2vec2
]
@@ -191,7 +194,6 @@ class Wav2Vec2FeatureProjection(nn.Module):
self.dropout = nn.Dropout(config.feat_extract_dropout)
def forward(self, hidden_states):
hidden_states = hidden_states.transpose(1, 2)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.projection(hidden_states)
hidden_states = self.dropout(hidden_states)
@@ -387,9 +389,11 @@ class Wav2Vec2EncoderLayer(nn.Module):
self.feed_forward = Wav2Vec2FeedForward(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states, output_attentions=False):
def forward(self, hidden_states, attention_mask=None, output_attentions=False):
attn_residual = hidden_states
hidden_states, attn_weights, _ = self.attention(hidden_states, output_attentions=output_attentions)
hidden_states, attn_weights, _ = self.attention(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
hidden_states = self.dropout(hidden_states)
hidden_states = attn_residual + hidden_states
@@ -414,10 +418,12 @@ class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module):
self.feed_forward = Wav2Vec2FeedForward(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states, output_attentions=False):
def forward(self, hidden_states, attention_mask=None, output_attentions=False):
attn_residual = hidden_states
hidden_states = self.layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.attention(hidden_states, output_attentions=output_attentions)
hidden_states, attn_weights, _ = self.attention(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
hidden_states = self.dropout(hidden_states)
hidden_states = attn_residual + hidden_states
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
@@ -438,6 +444,7 @@ class Wav2Vec2Encoder(nn.Module):
def forward(
self,
hidden_states,
attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
@@ -445,6 +452,16 @@ class Wav2Vec2Encoder(nn.Module):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
if attention_mask is not None:
# make sure padded tokens output 0
hidden_states[~attention_mask] = 0.0
# extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
hidden_states = self.layer_norm(hidden_states)
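For reference, a minimal standalone sketch of the padding-mask handling added above: zero out padded frames, then broadcast the boolean mask into an additive bias of shape (batch, 1, seq_len, seq_len) for the attention layers. The shapes and values below are illustrative, not taken from the diff.

import torch

batch_size, seq_len, hidden_size = 2, 6, 4
hidden_states = torch.randn(batch_size, seq_len, hidden_size)
# True for real frames, False for padding (second sample has two padded frames)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1],
                               [1, 1, 1, 1, 0, 0]]).bool()

# padded positions are zeroed so they cannot leak into the convolutional position embeddings
hidden_states[~attention_mask] = 0.0

# additive bias: 0.0 where attended, -10000.0 where masked
extended_mask = (1.0 - attention_mask[:, None, None, :].to(hidden_states.dtype)) * -10000.0
extended_mask = extended_mask.expand(batch_size, 1, seq_len, seq_len)

assert extended_mask.shape == (batch_size, 1, seq_len, seq_len)
assert (extended_mask[1, 0, :, -1] == -10000.0).all()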
@@ -454,7 +471,9 @@ class Wav2Vec2Encoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
hidden_states, attn_weights = layer(hidden_states, output_attentions=output_attentions)
hidden_states, attn_weights = layer(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
if output_attentions:
all_self_attentions = all_self_attentions + (attn_weights,)
@@ -486,6 +505,7 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
def forward(
self,
hidden_states,
attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
@@ -493,6 +513,16 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
if attention_mask is not None:
# make sure padded tokens are not attended to
hidden_states[~attention_mask] = 0
# extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
hidden_states = self.dropout(hidden_states)
@@ -501,7 +531,9 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
hidden_states, attn_weights = layer(hidden_states, output_attentions=output_attentions)
hidden_states, attn_weights = layer(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
if output_attentions:
all_self_attentions = all_self_attentions + (attn_weights,)
@@ -544,6 +576,21 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
module.bias.data.zero_()
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
"""
Computes the output length of the convolutional layers
"""
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return torch.floor((input_length - kernel_size) / stride + 1)
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
return input_lengths.to(torch.long)
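As a quick illustration of the helper above, a sketch assuming the default base feature extractor (conv_kernel=(10, 3, 3, 3, 3, 2, 2), conv_stride=(5, 2, 2, 2, 2, 2, 2)); with those assumed defaults one second of 16 kHz audio maps to 49 feature frames:

import torch

conv_kernel = (10, 3, 3, 3, 3, 2, 2)   # assumed base-config defaults
conv_stride = (5, 2, 2, 2, 2, 2, 2)

def feat_extract_output_lengths(input_lengths):
    for kernel_size, stride in zip(conv_kernel, conv_stride):
        # same Conv1d output-length formula as above (no padding, no dilation)
        input_lengths = torch.floor((input_lengths - kernel_size) / stride + 1)
    return input_lengths.to(torch.long)

print(feat_extract_output_lengths(torch.tensor([16000, 8000])))  # tensor([49, 24])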
WAV_2_VEC_2_START_DOCSTRING = r"""
Wav2Vec2 was proposed in `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
@@ -572,6 +619,24 @@ WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
soundfile`). To prepare the array into `input_values`, the :class:`~transformers.Wav2Vec2Tokenizer` should
be used for padding and conversion into a tensor of type `torch.FloatTensor`. See
:meth:`transformers.Wav2Vec2Tokenizer.__call__` for details.
attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in ``[0,
1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
.. warning::
:obj:`attention_mask` should only be passed if the corresponding tokenizer has
``config.return_attention_mask == True``. For all models whose tokenizer has
``config.return_attention_mask == False``, such as `wav2vec2-base
<https://huggingface.co/facebook/wav2vec2-base-960h>`__, :obj:`attention_mask` should **not** be passed
to avoid degraded performance when doing batched inference. For such models :obj:`input_values` should
simply be padded with 0 and passed without :obj:`attention_mask`. Be aware that these models also yield
slightly different results depending on whether :obj:`input_values` is padded or not.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
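To make the warning above concrete, a usage sketch of batched inference with both families of checkpoints (model identifiers as referenced in the docstring; the audio arrays are random placeholders rather than real speech):

import numpy as np
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer

speech = [np.random.randn(16000).astype(np.float32), np.random.randn(12000).astype(np.float32)]

# lv60 checkpoints (feat_extract_norm == "layer"): the tokenizer returns attention_mask, pass it on
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
inputs = tokenizer(speech, return_tensors="pt", padding=True)
with torch.no_grad():
    logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits

# base checkpoint (feat_extract_norm == "group"): pad with 0 and do not pass attention_mask
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
input_values = tokenizer(speech, return_tensors="pt", padding=True).input_values
with torch.no_grad():
    logits = model(input_values).logits

predicted_ids = torch.argmax(logits, dim=-1)
print(tokenizer.batch_decode(predicted_ids))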
@@ -606,6 +671,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
def forward(
self,
input_values,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
@@ -641,14 +707,33 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = self.feature_extractor(input_values)
hidden_states = hidden_states.transpose(1, 2)
if attention_mask is not None:
# compute real output lengths according to convolution formula
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
attention_mask = torch.zeros(
hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device
)
# these two operations make sure that all values
# before the output lengths indices are attended to
attention_mask[
(torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)
] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
hidden_states = self.feature_projection(hidden_states)
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = encoder_outputs[0]
if not return_dict:
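The flip/cumsum/flip trick above is easiest to see on a toy tensor; a minimal sketch with made-up feature lengths:

import torch

batch_size, feat_seq_len = 3, 8
output_lengths = torch.tensor([8, 5, 3])  # frame counts after the conv feature extractor

attention_mask = torch.zeros(batch_size, feat_seq_len)
# mark only the last valid frame of every sequence ...
attention_mask[(torch.arange(batch_size), output_lengths - 1)] = 1
# ... then a reversed cumulative sum turns that single 1 into 1s for every
# position up to and including it
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()

print(attention_mask.long())
# tensor([[1, 1, 1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0, 0, 0]])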
@@ -681,6 +766,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
def forward(
self,
input_values,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
@@ -755,6 +841,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
def forward(
self,
input_values,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
@@ -795,6 +882,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
outputs = self.wav2vec2(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
@@ -802,6 +890,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.lm_head(hidden_states)
if not return_dict:

src/transformers/models/wav2vec2/tokenization_wav2vec2.py

@@ -87,6 +87,26 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
The token used for padding, for example when batching sequences of different lengths.
word_delimiter_token (:obj:`str`, `optional`, defaults to :obj:`"|"`):
The token used for defining the end of a word.
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to lowercase the output when decoding.
do_normalize (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
improve the performance for some models, *e.g.*, `wav2vec2-lv60
<https://huggingface.co/models?search=lv60>`__.
return_attention_mask (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not :meth:`~transformers.Wav2Vec2Tokenizer.__call__` should return :obj:`attention_mask`.
.. note::
Wav2Vec2 models that have set ``config.feat_extract_norm == "group"``, such as `wav2vec2-base
<https://huggingface.co/facebook/wav2vec2-base-960h>`__, have **not** been trained using
:obj:`attention_mask`. For such models, :obj:`input_values` should simply be padded with 0 and no
:obj:`attention_mask` should be passed.
For Wav2Vec2 models that have set ``config.feat_extract_norm == "layer"``, such as `wav2vec2-lv60
<https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self>`__, :obj:`attention_mask` should be
passed for batched inference.
**kwargs
Additional keyword arguments passed along to :class:`~transformers.PreTrainedTokenizer`
"""
@@ -100,7 +120,7 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
"facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/tokenizer.json",
},
}
model_input_names = ["input_values"]
model_input_names = ["input_values", "attention_mask"]
def __init__(
self,
@@ -111,6 +131,8 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
pad_token="<pad>",
word_delimiter_token="|",
do_lower_case=False,
do_normalize=False,
return_attention_mask=False,
**kwargs
):
super().__init__(
@@ -119,11 +141,16 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
eos_token=eos_token,
pad_token=pad_token,
do_lower_case=do_lower_case,
do_normalize=do_normalize,
return_attention_mask=return_attention_mask,
word_delimiter_token=word_delimiter_token,
**kwargs,
)
self._word_delimiter_token = word_delimiter_token
self.do_lower_case = do_lower_case
self.return_attention_mask = return_attention_mask
self.do_normalize = do_normalize
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
@@ -193,6 +220,10 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
if not is_batched:
raw_speech = [raw_speech]
# zero-mean and unit-variance normalization
if self.do_normalize:
raw_speech = [(x - np.mean(x)) / np.sqrt(np.var(x) + 1e-5) for x in raw_speech]
# convert into correct format for padding
encoded_inputs = BatchEncoding({"input_values": raw_speech})
@@ -201,7 +232,7 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
padding=padding,
max_length=max_length,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=False,
return_attention_mask=self.return_attention_mask,
return_tensors=return_tensors,
verbose=verbose,
)
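For completeness, the per-utterance normalization applied in __call__ above, checked numerically on a random array (a standalone sketch, not part of the diff):

import numpy as np

x = np.random.randn(1000).astype(np.float32) * 3.5 + 2.0      # arbitrary scale and offset
normalized = (x - np.mean(x)) / np.sqrt(np.var(x) + 1e-5)     # same formula as in __call__ above
print(round(float(np.mean(normalized)), 3), round(float(np.var(normalized)), 3))  # ~0.0, ~1.0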

tests/test_modeling_wav2vec2.py

@@ -18,7 +18,7 @@
import math
import unittest
from tests.test_modeling_common import floats_tensor
from tests.test_modeling_common import floats_tensor, random_attention_mask
from transformers import is_torch_available
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
@@ -93,6 +93,7 @@ class Wav2Vec2ModelTester:
def prepare_config_and_inputs(self):
input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
config = Wav2Vec2Config(
hidden_size=self.hidden_size,
@@ -115,20 +116,48 @@ class Wav2Vec2ModelTester:
vocab_size=self.vocab_size,
)
return config, input_values
return config, input_values, attention_mask
def create_and_check_model(self, config, input_values):
def create_and_check_model(self, config, input_values, attention_mask):
model = Wav2Vec2Model(config=config)
model.to(torch_device)
model.eval()
result = model(input_values)
result = model(input_values, attention_mask=attention_mask)
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
)
def create_and_check_batch_inference(self, config, input_values, *args):
# Not sure how to make this test pass at the moment. Batched input yields the
# same results as the official fairseq implementation, but results differ
# depending on whether the input is batched or not
# check: https://github.com/pytorch/fairseq/issues/3227
model = Wav2Vec2Model(config=config)
model.to(torch_device)
model.eval()
input_values = input_values[:3]
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.bool)
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
attention_mask[i, input_lengths[i] :] = 0.0
batch_outputs = model(input_values, attention_mask=attention_mask).last_hidden_state
for i in range(input_values.shape[0]):
input_slice = input_values[i : i + 1, : input_lengths[i]]
output = model(input_slice).last_hidden_state
batch_output = batch_outputs[i : i + 1, : output.shape[1]]
self.parent.assertTrue(torch.allclose(output, batch_output, atol=1e-3))
def prepare_config_and_inputs_for_common(self):
config, input_values = self.prepare_config_and_inputs()
inputs_dict = {"input_values": input_values}
config, input_values, attention_mask = self.prepare_config_and_inputs()
inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
return config, inputs_dict
@@ -222,6 +251,10 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_batched_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_batch_inference(*config_and_inputs)
# Wav2Vec2 has no inputs_embeds
def test_inputs_embeds(self):
pass
@@ -288,7 +321,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
return ds["speech"][:num_samples]
def test_inference_masked_lm_normal(self):
def test_inference_ctc_normal(self):
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model.to(torch_device)
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
@@ -306,16 +339,16 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
def test_inference_masked_lm_normal_batched(self):
def test_inference_ctc_normal_batched(self):
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model.to(torch_device)
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
input_speech = self._load_datasamples(2)
input_values = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True).input_values.to(
torch_device
)
inputs = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True)
input_values = inputs.input_values.to(torch_device)
with torch.no_grad():
logits = model(input_values).logits
@@ -329,18 +362,19 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
def test_inference_masked_lm_robust_batched(self):
def test_inference_ctc_robust_batched(self):
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)
input_speech = self._load_datasamples(4)
input_values = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True).input_values.to(
torch_device
)
inputs = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True)
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
with torch.no_grad():
logits = model(input_values).logits
logits = model(input_values, attention_mask=attention_mask).logits
predicted_ids = torch.argmax(logits, dim=-1)
predicted_trans = tokenizer.batch_decode(predicted_ids)

tests/test_tokenization_wav2vec2.py

@@ -23,7 +23,10 @@ import unittest
import numpy as np
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES, Wav2Vec2Tokenizer
from transformers import WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST
from transformers.models.wav2vec2 import Wav2Vec2Config, Wav2Vec2Tokenizer
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
from transformers.testing_utils import slow
global_rng = random.Random()
@@ -299,3 +302,46 @@ class Wav2Vec2TokenizerTest(unittest.TestCase):
for parameter_name, parameter in signature.parameters.items():
if parameter.default != inspect.Parameter.empty:
self.assertIn(parameter_name, tokenizer.init_kwargs)
def test_zero_mean_unit_variance_normalization(self):
tokenizer = self.get_tokenizer(do_normalize=True)
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
processed = tokenizer(speech_inputs, padding="longest")
input_values = processed.input_values
def _check_zero_mean_unit_variance(input_vector):
self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
_check_zero_mean_unit_variance(input_values[0, :800])
_check_zero_mean_unit_variance(input_values[1, :1000])
_check_zero_mean_unit_variance(input_values[2])
def test_return_attention_mask(self):
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
# default case -> no attention_mask is returned
tokenizer = self.get_tokenizer()
processed = tokenizer(speech_inputs)
self.assertNotIn("attention_mask", processed)
# wav2vec2-lv60 -> return attention_mask
tokenizer = self.get_tokenizer(return_attention_mask=True)
processed = tokenizer(speech_inputs, padding="longest")
self.assertIn("attention_mask", processed)
self.assertListEqual(list(processed.attention_mask.shape), list(processed.input_values.shape))
self.assertListEqual(processed.attention_mask.sum(-1).tolist(), [800, 1000, 1200])
@slow
def test_pretrained_checkpoints_are_set_correctly(self):
# this test makes sure that models that are using
# group norm don't have their tokenizer return the
# attention_mask
for model_id in WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST:
config = Wav2Vec2Config.from_pretrained(model_id)
tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_id)
# only "layer" feature extraction norm should make use of
# attention_mask
self.assertEqual(tokenizer.return_attention_mask, config.feat_extract_norm == "layer")