[Speech2Text2] Enable tokenizers (#14390)

* [Speech2Text2] Enable tokenizers

* minor fix

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Patrick von Platen 2021-11-15 16:34:11 +01:00 committed by GitHub
parent 267867e851
commit 4ce74edf51
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 171 additions and 100 deletions

View File

@ -36,7 +36,7 @@ Tips:
- Speech2Text2 achieves state-of-the-art results on the CoVoST Speech Translation dataset. For more information, see
the `official models <https://huggingface.co/models?other=speech2text2>`__ .
- Speech2Text2 is always used within the :doc:`SpeechEncoderDecoder <speechencoderdecoder>` framework.
- Speech2Text2's tokenizer currently only supports inference, but not training.
- Speech2Text2's tokenizer is based on `fastBPE <https://github.com/glample/fastBPE>`.
Inference
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {
"vocab_file": "vocab.json",
"tokenizer_config_file": "tokenizer_config.json",
"merges_file": "merges.txt",
}
PRETRAINED_VOCAB_FILES_MAP = {
@ -37,14 +38,33 @@ PRETRAINED_VOCAB_FILES_MAP = {
"tokenizer_config_file": {
"facebook/s2t-wav2vec2-large-en-de": "https://huggingface.co/facebook/s2t-wav2vec2-large-en-de/resolve/main/tokenizer_config.json",
},
"merges_file": {
"facebook/s2t-wav2vec2-large-en-de": "https://huggingface.co/facebook/s2t-wav2vec2-large-en-de/resolve/main/merges.txt",
},
}
BPE_TOKEN_MERGES = "</w>"
BPE_TOKEN_VOCAB = "@@ "
def get_pairs(word):
"""
Return set of symbol pairs in a word. word is represented as tuple of symbols (symbols being variable-length
strings)
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
# Speech2Text2 has no max input length
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"facebook/s2t-wav2vec2-large-en-de": 1024}
class Speech2Text2Tokenizer(PreTrainedTokenizer):
"""
Constructs a Speech2Text2Tokenizer.
@ -73,19 +93,45 @@ class Speech2Text2Tokenizer(PreTrainedTokenizer):
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask"]
def __init__(self, vocab_file, bos_token="<s>", pad_token="<pad>", eos_token="</s>", unk_token="<unk>", **kwargs):
def __init__(
self,
vocab_file,
bos_token="<s>",
pad_token="<pad>",
eos_token="</s>",
unk_token="<unk>",
do_lower_case=False,
merges_file=None,
**kwargs
):
super().__init__(
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
do_lower_case=do_lower_case,
**kwargs,
)
self.do_lower_case = do_lower_case
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
self.decoder = {v: k for k, v in self.encoder.items()}
if merges_file is None:
logger.info(f"No merges files provided. {self.__class__.__name__} can only be used for decoding.")
self.bpe_ranks = None
self.cache = None
else:
with open(merges_file, encoding="utf-8") as merges_handle:
merges = merges_handle.read().split("\n")[:-1]
merges = [tuple(merge.split()[:2]) for merge in merges]
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {}
@property
def vocab_size(self) -> int:
return len(self.decoder)
@ -93,8 +139,77 @@ class Speech2Text2Tokenizer(PreTrainedTokenizer):
def get_vocab(self) -> Dict:
return dict(self.encoder, **self.added_tokens_encoder)
def _tokenize(self, text, **kwargs):
raise NotImplementedError("Tokenization requires a bpe tokenization file, which is currently not available")
def bpe(self, token):
word = tuple(token[:-1]) + (token[-1] + BPE_TOKEN_MERGES,)
if token in self.cache:
return self.cache[token]
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
except ValueError:
new_word.extend(word[i:])
break
else:
new_word.extend(word[i:j])
i = j
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
if word == "\n " + BPE_TOKEN_MERGES:
word = "\n" + BPE_TOKEN_MERGES
if word.endswith(BPE_TOKEN_MERGES):
word = word.replace(BPE_TOKEN_MERGES, "")
word = word.replace(" ", BPE_TOKEN_VOCAB)
self.cache[token] = word
return word
def _tokenize(self, text):
"""Tokenize a string."""
if self.bpe_ranks is None:
raise ValueError(
"This tokenizer was instantiated without a `merges.txt` file, so"
" that it can only be used for decoding, not for encoding."
"Make sure to provide `merges.txt` file at instantiation to enable "
"encoding."
)
if self.do_lower_case:
text = text.lower()
text = text.split()
split_tokens = []
for token in text:
if token:
split_tokens.extend([t for t in self.bpe(token).split(" ")])
return split_tokens
def _convert_token_to_id(self, token: str) -> int:
"""Converts a token (str) in an index (integer) using the vocab."""
@ -113,7 +228,7 @@ class Speech2Text2Tokenizer(PreTrainedTokenizer):
string = " ".join(tokens)
# make sure @@ tokens are concatenated
string = "".join(string.split("@@ "))
string = "".join(string.split(BPE_TOKEN_VOCAB))
return string
@ -124,8 +239,26 @@ class Speech2Text2Tokenizer(PreTrainedTokenizer):
vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
merges_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
)
with open(vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
return (vocab_file,)
index = 0
if self.bpe_ranks is None:
return (vocab_file,)
with open(merges_file, "w", encoding="utf-8") as writer:
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning(
f"Saving vocabulary to {merges_file}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!"
)
index = token_index
writer.write(" ".join(bpe_tokens) + "\n")
index += 1
return (vocab_file, merges_file)

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
import tempfile
@ -19,7 +20,6 @@ import unittest
from transformers.models.speech_to_text_2 import Speech2Text2Tokenizer
from transformers.models.speech_to_text_2.tokenization_speech_to_text_2 import VOCAB_FILES_NAMES
from transformers.testing_utils import is_pt_tf_cross_test
from .test_tokenization_common import TokenizerTesterMixin
@ -31,26 +31,32 @@ class SpeechToTextTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
def setUp(self):
super().setUp()
vocab = "<s> <pad> </s> <unk> here@@ a couple of@@ words for the vocab".split(" ")
vocab = "<s> <pad> </s> <unk> here@@ a couple of@@ words for the he@@ re@@ vocab".split(" ")
merges = ["he re</w> 123", "here a 1456"]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
self.special_tokens_map = {"pad_token": "<pad>", "unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
self.tmpdirname = tempfile.mkdtemp()
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens) + "\n")
with open(self.merges_file, "w") as fp:
fp.write("\n".join(merges))
def test_get_vocab(self):
vocab_keys = list(self.get_tokenizer().get_vocab().keys())
self.assertEqual(vocab_keys[0], "<s>")
self.assertEqual(vocab_keys[1], "<pad>")
self.assertEqual(vocab_keys[-1], "vocab")
self.assertEqual(len(vocab_keys), 12)
self.assertEqual(len(vocab_keys), 14)
def test_vocab_size(self):
self.assertEqual(self.get_tokenizer().vocab_size, 12)
self.assertEqual(self.get_tokenizer().vocab_size, 14)
def test_tokenizer_decode(self):
tokenizer = Speech2Text2Tokenizer.from_pretrained(self.tmpdirname)
@ -61,99 +67,31 @@ class SpeechToTextTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
self.assertTrue(output_string == "herecouple words ofthe")
# currently tokenizer cannot do encoding, but just decoding
def test_add_special_tokens(self):
pass
def test_load_no_merges_file(self):
tokenizer = Speech2Text2Tokenizer.from_pretrained(self.tmpdirname)
# currently tokenizer cannot do encoding, but just decoding
def test_add_tokens_tokenizer(self):
pass
with tempfile.TemporaryDirectory() as tmp_dirname:
tokenizer.save_pretrained(tmp_dirname)
os.remove(os.path.join(tmp_dirname, "merges.txt"))
# currently tokenizer cannot do encoding, but just decoding
def test_added_tokens_do_lower_case(self):
pass
# load tokenizer without merges file should not throw an error
tokenizer = Speech2Text2Tokenizer.from_pretrained(tmp_dirname)
# currently tokenizer cannot do encoding, but just decoding
def test_batch_encode_plus_batch_sequence_length(self):
pass
with tempfile.TemporaryDirectory() as tmp_dirname:
# save tokenizer and load again
tokenizer.save_pretrained(tmp_dirname)
tokenizer = Speech2Text2Tokenizer.from_pretrained(tmp_dirname)
# currently tokenizer cannot do encoding, but just decoding
def test_batch_encode_plus_overflowing_tokens(self):
pass
self.assertIsNotNone(tokenizer)
# currently tokenizer cannot do encoding, but just decoding
def test_batch_encode_plus_padding(self):
pass
# overwrite since merges_file is optional
def test_tokenizer_slow_store_full_signature(self):
if not self.test_slow_tokenizer:
return
# currently tokenizer cannot do encoding, but just decoding
def test_call(self):
pass
signature = inspect.signature(self.tokenizer_class.__init__)
tokenizer = self.get_tokenizer()
# currently tokenizer cannot do encoding, but just decoding
def test_encode_plus_with_padding(self):
pass
# currently tokenizer cannot do encoding, but just decoding
def test_internal_consistency(self):
pass
# currently tokenizer cannot do encoding, but just decoding
def test_maximum_encoding_length_pair_input(self):
pass
# currently tokenizer cannot do encoding, but just decoding
def test_maximum_encoding_length_single_input(self):
pass
# currently tokenizer cannot do encoding, but just decoding
def test_number_of_added_tokens(self):
pass
# currently tokenizer cannot do encoding, but just decoding
def test_padding_to_max_length(self):
pass
# currently tokenizer cannot do encoding, but just decoding
def test_padding_to_multiple_of(self):
pass
# currently tokenizer cannot do encoding, but just decoding
def test_pickle_tokenizer(self):
pass
# currently tokenizer cannot do encoding, but just decoding
def test_prepare_for_model(self):
pass
# currently tokenizer cannot do encoding, but just decoding
def test_pretokenized_inputs(self):
pass
# currently tokenizer cannot do encoding, but just decoding
def test_right_and_left_padding(self):
pass
# currently tokenizer cannot do encoding, but just decoding
def test_save_and_load_tokenizer(self):
pass
# currently tokenizer cannot do encoding, but just decoding
def test_special_tokens_mask(self):
pass
# currently tokenizer cannot do encoding, but just decoding
def test_special_tokens_mask_input_pairs(self):
pass
# currently tokenizer cannot do encoding, but just decoding
def test_token_type_ids(self):
pass
# currently tokenizer cannot do encoding, but just decoding
def test_added_token_are_matched_longest_first(self):
pass
# currently tokenizer cannot do encoding, but just decoding
@is_pt_tf_cross_test
def test_batch_encode_plus_tensors(self):
pass
for parameter_name, parameter in signature.parameters.items():
if parameter.default != inspect.Parameter.empty and parameter_name != "merges_file":
self.assertIn(parameter_name, tokenizer.init_kwargs)