# coding=utf-8
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pickle
import shutil
import tempfile
import unittest

from transformers import SPIECE_UNDERLINE, XLMRobertaTokenizer, XLMRobertaTokenizerFast
from transformers.file_utils import cached_property
from transformers.testing_utils import require_sentencepiece, require_tokenizers, slow

from .test_tokenization_common import TokenizerTesterMixin


# Small SentencePiece model shipped next to this test module.
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")


@require_sentencepiece
@require_tokenizers
class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    """Tests for the slow and fast XLM-RoBERTa tokenizers.

    Most tests run against the small ``SAMPLE_VOCAB`` SentencePiece fixture;
    the ``@slow`` tests load the pretrained ``xlm-roberta-base`` tokenizer.
    """

    tokenizer_class = XLMRobertaTokenizer
    rust_tokenizer_class = XLMRobertaTokenizerFast
    test_rust_tokenizer = True
    test_sentencepiece = True

    def setUp(self):
        """Save a fixture-backed tokenizer into ``self.tmpdirname``."""
        super().setUp()

        # We have a SentencePiece fixture for testing
        tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB, keep_accents=True)
        tokenizer.save_pretrained(self.tmpdirname)

    def test_convert_token_and_id(self):
        """Test ``_convert_token_to_id`` and ``_convert_id_to_token``."""
        token = "<pad>"
        token_id = 1

        self.assertEqual(self.get_tokenizer()._convert_token_to_id(token), token_id)
        self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token)

    def test_get_vocab(self):
        """Check the first, second and last vocabulary keys and the vocabulary length."""
        vocab_keys = list(self.get_tokenizer().get_vocab().keys())

        self.assertEqual(vocab_keys[0], "<s>")
        self.assertEqual(vocab_keys[1], "<pad>")
        self.assertEqual(vocab_keys[-1], "<mask>")
        self.assertEqual(len(vocab_keys), 1_002)

    def test_vocab_size(self):
        """Check ``vocab_size`` of the fixture-backed tokenizer."""
        self.assertEqual(self.get_tokenizer().vocab_size, 1_002)

    def test_full_tokenizer(self):
        """Round-trip text -> tokens -> ids -> tokens with the fixture vocabulary."""
        tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB, keep_accents=True)

        tokens = tokenizer.tokenize("This is a test")
        self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])

        # Expected ids are the raw SentencePiece ids shifted by ``fairseq_offset``.
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens),
            [value + tokenizer.fairseq_offset for value in [285, 46, 10, 170, 382]],
        )

        tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
        self.assertListEqual(
            tokens,
            [
                SPIECE_UNDERLINE + "I",
                SPIECE_UNDERLINE + "was",
                SPIECE_UNDERLINE + "b",
                "or",
                "n",
                SPIECE_UNDERLINE + "in",
                SPIECE_UNDERLINE + "",
                "9",
                "2",
                "0",
                "0",
                "0",
                ",",
                SPIECE_UNDERLINE + "and",
                SPIECE_UNDERLINE + "this",
                SPIECE_UNDERLINE + "is",
                SPIECE_UNDERLINE + "f",
                "al",
                "s",
                "é",
                ".",
            ],
        )

        ids = tokenizer.convert_tokens_to_ids(tokens)
        self.assertListEqual(
            ids,
            [
                value + tokenizer.fairseq_offset
                for value in [8, 21, 84, 55, 24, 19, 7, 2, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 2, 4]
                #                                    ^ unk: 2 + 1 = 3                   unk: 2 + 1 = 3 ^
            ],
        )

        # "9" and "é" were mapped to the unknown id above, so they come back as "<unk>".
        back_tokens = tokenizer.convert_ids_to_tokens(ids)
        self.assertListEqual(
            back_tokens,
            [
                SPIECE_UNDERLINE + "I",
                SPIECE_UNDERLINE + "was",
                SPIECE_UNDERLINE + "b",
                "or",
                "n",
                SPIECE_UNDERLINE + "in",
                SPIECE_UNDERLINE + "",
                "<unk>",
                "2",
                "0",
                "0",
                "0",
                ",",
                SPIECE_UNDERLINE + "and",
                SPIECE_UNDERLINE + "this",
                SPIECE_UNDERLINE + "is",
                SPIECE_UNDERLINE + "f",
                "al",
                "s",
                "<unk>",
                ".",
            ],
        )

    @cached_property
    def big_tokenizer(self):
        """Pretrained ``xlm-roberta-base`` tokenizer, loaded once per test instance."""
        return XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")

    def test_picklable_without_disk(self):
        """Check that a tokenizer built from a temporary vocab file can be pickled and unpickled."""
        with tempfile.NamedTemporaryFile() as f:
            shutil.copyfile(SAMPLE_VOCAB, f.name)
            tokenizer = XLMRobertaTokenizer(f.name, keep_accents=True)
            pickled_tokenizer = pickle.dumps(tokenizer)

        # NOTE(review): unpickling happens after the temporary file is closed, as the
        # test name suggests -- the original indentation was lost, confirm the placement.
        pickle.loads(pickled_tokenizer)

    def test_rust_and_python_full_tokenizers(self):
        """Check that the slow and fast tokenizers produce identical tokens and ids."""
        if not self.test_rust_tokenizer:
            return

        tokenizer = self.get_tokenizer()
        rust_tokenizer = self.get_rust_tokenizer()

        sequence = "I was born in 92000, and this is falsé."

        tokens = tokenizer.tokenize(sequence)
        rust_tokens = rust_tokenizer.tokenize(sequence)
        self.assertListEqual(tokens, rust_tokens)

        ids = tokenizer.encode(sequence, add_special_tokens=False)
        rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
        self.assertListEqual(ids, rust_ids)

        rust_tokenizer = self.get_rust_tokenizer()
        ids = tokenizer.encode(sequence)
        rust_ids = rust_tokenizer.encode(sequence)
        self.assertListEqual(ids, rust_ids)

    @slow
    def test_tokenization_base_easy_symbols(self):
        """Compare the pretrained tokenizer's encoding of a simple sentence with reference ids."""
        symbols = "Hello World!"
        original_tokenizer_encodings = [0, 35378, 6661, 38, 2]
        # Reference values were produced with:
        # xlmr = torch.hub.load('pytorch/fairseq', 'xlmr.base')  # xlmr.large has same tokenizer
        # xlmr.eval()
        # xlmr.encode(symbols)

        self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))

    @slow
    def test_tokenization_base_hard_symbols(self):
        """Compare the pretrained tokenizer's encoding of punctuation-heavy text with reference ids."""
        symbols = 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to <unk>, such as saoneuhaoesuth'
        original_tokenizer_encodings = [
            0,
            3293,
            83,
            10,
            4552,
            4989,
            7986,
            678,
            10,
            5915,
            111,
            179459,
            124850,
            4,
            6044,
            237,
            12,
            6,
            5,
            6,
            4,
            6780,
            705,
            15,
            1388,
            44,
            378,
            10114,
            711,
            152,
            20,
            6,
            5,
            22376,
            642,
            1221,
            15190,
            34153,
            450,
            5608,
            959,
            1119,
            57702,
            136,
            186,
            47,
            1098,
            29367,
            47,
            # 4426, # What fairseq tokenizes from "<unk>": "_<"
            # 3678, # What fairseq tokenizes from "<unk>": "unk"
            # 2740, # What fairseq tokenizes from "<unk>": ">"
            3,  # What we tokenize from "<unk>": "<unk>"
            6,  # Residue from the tokenization: an extra sentencepiece underline
            4,
            6044,
            237,
            6284,
            50901,
            528,
            31,
            90,
            34,
            927,
            2,
        ]
        # Reference values were produced with:
        # xlmr = torch.hub.load('pytorch/fairseq', 'xlmr.base')  # xlmr.large has same tokenizer
        # xlmr.eval()
        # xlmr.encode(symbols)

        self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))

    @slow
    def test_tokenizer_integration(self):
        """Run the shared integration check against a pinned ``xlm-roberta-base`` revision."""
        # Every row has 106 entries. In rows 2 and 3 the trailing 1s in ``input_ids``
        # line up with the 0s in ``attention_mask``; the runs are written as
        # multiplications instead of long literal sequences.
        expected_encoding = {
            "input_ids": [
                [
                    0, 11062, 82772, 7, 15, 82772, 538, 51529, 237, 17198, 1290, 206, 9, 215175, 1314, 136,
                    17198, 1290, 206, 9, 56359, 42, 122009, 9, 16466, 16, 87344, 4537, 9, 4717, 78381, 6,
                    159958, 7, 15, 24480, 618, 4, 527, 22693, 5428, 4, 2777, 24480, 9874, 4, 43523, 594, 4,
                    803, 18392, 33189, 18, 4, 43523, 24447, 12399, 100, 24955, 83658, 9626, 144057, 15, 839,
                    22335, 16, 136, 24955, 83658, 83479, 15, 39102, 724, 16, 678, 645, 2789, 1328, 4589, 42,
                    122009, 115774, 23, 805, 1328, 46876, 7, 136, 53894, 1940, 42227, 41159, 17721, 823, 425,
                    4, 27512, 98722, 206, 136, 5531, 4970, 919, 17336, 5, 2,
                ],
                [
                    0, 20080, 618, 83, 82775, 47, 479, 9, 1517, 73, 53894, 333, 80581, 110117, 18811, 5256,
                    1295, 51, 152526, 297, 7986, 390, 124416, 538, 35431, 214, 98, 15044, 25737, 136, 7108,
                    43701, 23, 756, 135355, 7, 5, 2,
                ]
                + [1] * 68,
                [0, 581, 63773, 119455, 6, 147797, 88203, 7, 645, 70, 21, 3285, 10269, 5, 2] + [1] * 91,
            ],
            "attention_mask": [
                [1] * 106,
                [1] * 38 + [0] * 68,
                [1] * 15 + [0] * 91,
            ],
        }

        self.tokenizer_integration_test_util(
            expected_encoding=expected_encoding,
            model_name="xlm-roberta-base",
            revision="d9d8a8ea5eb94b1c6654ae9249df7793cd2933d3",
        )