improve saving strategy of sentencepiece tokenizer (#15328)
* add new test
* add a feature to save the sentencepiece tokenizer model when the init file was deleted
* update marian
* update m2m_100
* fix marian
* update speech to text
* override test for layoutxlm
* fix saving bartpho
* remove hardcoded values in bartpho
* special token string version
* finish bartpho
* override layoutxlm test
* add mbart
* move special tokens list
* format
* Revert "format"
This reverts commit 37a40df379.
* simplify list of string of special tokens
* Re-write `self.fairseq_tokens_to_ids` initialization logic with special tokens
Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
parent 196cce6e9b
commit ade7371a41
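The per-tokenizer hunks below all apply the same saving fallback. A minimal sketch of that pattern (the function name and argument names here are made up for illustration; `serialized_model_proto()` is the standard sentencepiece API the diff relies on):

import os
from shutil import copyfile

import sentencepiece as spm


def save_spm_model(sp_model: spm.SentencePieceProcessor, orig_path: str, out_path: str) -> None:
    """Copy the original .model file if it still exists; otherwise re-serialize the loaded model."""
    if os.path.abspath(orig_path) != os.path.abspath(out_path) and os.path.isfile(orig_path):
        # The file the tokenizer was built from is still on disk: copy it, as before.
        copyfile(orig_path, out_path)
    elif not os.path.isfile(orig_path):
        # The original file has been deleted in the meantime: write the serialized
        # proto held in memory so saving still produces a valid .model file.
        with open(out_path, "wb") as fi:
            fi.write(sp_model.serialized_model_proto())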
@@ -343,7 +343,11 @@ class AlbertTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)
@@ -157,12 +157,20 @@ class BartphoTokenizer(PreTrainedTokenizer):
         self.sp_model.Load(str(vocab_file))
 
         # Load the reduced vocab
-        self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
+
+        # Keep order of special tokens for backward compatibility
+        self.fairseq_tokens_to_ids = {}
+        cnt = 0
+        for token in [bos_token, pad_token, eos_token, unk_token, sep_token, cls_token]:
+            if str(token) not in self.fairseq_tokens_to_ids:
+                self.fairseq_tokens_to_ids[str(token)] = cnt
+                cnt += 1
         with open(monolingual_vocab_file, "r", encoding="utf-8") as f:
             for line in f.readlines():
                 token = line.strip().split()[0]
                 self.fairseq_tokens_to_ids[token] = len(self.fairseq_tokens_to_ids)
-        self.fairseq_tokens_to_ids["<mask>"] = len(self.fairseq_tokens_to_ids)
+        if str(mask_token) not in self.fairseq_tokens_to_ids:
+            self.fairseq_tokens_to_ids[str(mask_token)] = len(self.fairseq_tokens_to_ids)
 
         self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
 
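For reference, a minimal sketch of what the new initialization yields, assuming BARTpho's usual defaults (bos `<s>`, pad `<pad>`, eos `</s>`, unk `<unk>`, sep `</s>`, cls `<s>`) and a made-up three-entry monolingual vocab; because duplicates among the special tokens are skipped, the first four IDs match the previously hard-coded mapping:

fairseq_tokens_to_ids = {}
cnt = 0
for token in ["<s>", "<pad>", "</s>", "<unk>", "</s>", "<s>"]:  # bos, pad, eos, unk, sep, cls
    if token not in fairseq_tokens_to_ids:
        fairseq_tokens_to_ids[token] = cnt
        cnt += 1
for token in ["▁xin", "▁chào", "bạn"]:  # stand-in lines from monolingual_vocab_file
    fairseq_tokens_to_ids[token] = len(fairseq_tokens_to_ids)
if "<mask>" not in fairseq_tokens_to_ids:
    fairseq_tokens_to_ids["<mask>"] = len(fairseq_tokens_to_ids)

print(fairseq_tokens_to_ids)
# {'<s>': 0, '<pad>': 1, '</s>': 2, '<unk>': 3, '▁xin': 4, '▁chào': 5, 'bạn': 6, '<mask>': 7}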
@@ -278,7 +286,7 @@ class BartphoTokenizer(PreTrainedTokenizer):
         if token in self.fairseq_tokens_to_ids:
             return self.fairseq_tokens_to_ids[token]
         else:
-            return self.fairseq_tokens_to_ids["<unk>"]
+            return self.unk_token_id
 
     def _convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
@@ -301,10 +309,21 @@ class BartphoTokenizer(PreTrainedTokenizer):
             (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["monolingual_vocab_file"],
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
-        if os.path.abspath(self.monolingual_vocab_file) != os.path.abspath(out_monolingual_vocab_file):
+        if os.path.abspath(self.monolingual_vocab_file) != os.path.abspath(
+            out_monolingual_vocab_file
+        ) and os.path.isfile(self.monolingual_vocab_file):
             copyfile(self.monolingual_vocab_file, out_monolingual_vocab_file)
+        elif not os.path.isfile(self.monolingual_vocab_file):
+            with open(out_monolingual_vocab_file, "w", encoding="utf-8") as fp:
+                for token in self.fairseq_tokens_to_ids:
+                    if token not in self.all_special_tokens:
+                        fp.write(f"{str(token)} \n")
 
         return out_vocab_file, out_monolingual_vocab_file
@@ -160,7 +160,11 @@ class BertGenerationTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)
@@ -189,8 +189,12 @@ class BigBirdTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)
 
@@ -288,7 +288,11 @@ class CamembertTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)
@@ -305,7 +305,11 @@ class FNetTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)
@@ -331,8 +331,12 @@ class LayoutXLMTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)
 
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Tokenization classes for M2M100."""
 import json
+import os
 from contextlib import contextmanager
 from pathlib import Path
 from shutil import copyfile
@@ -312,8 +313,12 @@ class M2M100Tokenizer(PreTrainedTokenizer):
 
         save_json(self.encoder, vocab_save_path)
 
-        if not spm_save_path.exists():
+        if os.path.abspath(self.spm_file) != os.path.abspath(spm_save_path) and os.path.isfile(self.spm_file):
             copyfile(self.spm_file, spm_save_path)
+        elif not os.path.isfile(self.spm_file):
+            with open(spm_save_path, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (str(vocab_save_path), str(spm_save_path))
 
@@ -11,8 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import json
+import os
 import re
 import warnings
 from contextlib import contextmanager
@@ -23,8 +23,11 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import sentencepiece
 
 from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
 
 
+logger = logging.get_logger(__name__)
+
 VOCAB_FILES_NAMES = {
     "source_spm": "source.spm",
     "target_spm": "target.spm",
@@ -277,21 +280,35 @@ class MarianTokenizer(PreTrainedTokenizer):
         return len(self.encoder)
 
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
-        save_dir = Path(save_directory)
-        assert save_dir.is_dir(), f"{save_directory} should be a directory"
-        save_json(
-            self.encoder,
-            save_dir / ((filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab"]),
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        saved_files = []
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab"]
         )
 
-        for orig, f in zip(["source.spm", "target.spm"], self.spm_files):
-            dest_path = save_dir / ((filename_prefix + "-" if filename_prefix else "") + Path(f).name)
-            if not dest_path.exists():
-                copyfile(f, save_dir / orig)
+        save_json(self.encoder, out_vocab_file)
+        saved_files.append(out_vocab_file)
 
-        return tuple(
-            save_dir / ((filename_prefix + "-" if filename_prefix else "") + f) for f in self.vocab_files_names
-        )
+        for spm_save_filename, spm_orig_path, spm_model in zip(
+            [VOCAB_FILES_NAMES["source_spm"], VOCAB_FILES_NAMES["target_spm"]],
+            self.spm_files,
+            [self.spm_source, self.spm_target],
+        ):
+            spm_save_path = os.path.join(
+                save_directory, (filename_prefix + "-" if filename_prefix else "") + spm_save_filename
+            )
+            if os.path.abspath(spm_orig_path) != os.path.abspath(spm_save_path) and os.path.isfile(spm_orig_path):
+                copyfile(spm_orig_path, spm_save_path)
+                saved_files.append(spm_save_path)
+            elif not os.path.isfile(spm_orig_path):
+                with open(spm_save_path, "wb") as fi:
+                    content_spiece_model = spm_model.serialized_model_proto()
+                    fi.write(content_spiece_model)
+                saved_files.append(spm_save_path)
+
+        return tuple(saved_files)
 
     def get_vocab(self) -> Dict:
         vocab = self.encoder.copy()
@@ -315,8 +315,12 @@ class MBartTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)
 
@@ -245,8 +245,12 @@ class MBart50Tokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)
 
@@ -285,7 +285,11 @@ class PegasusTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)
@@ -167,7 +167,11 @@ class ReformerTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)
@@ -13,8 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tokenization classes for Speech2Text."""
-
 import json
+import os
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -260,8 +260,12 @@ class Speech2TextTokenizer(PreTrainedTokenizer):
 
         save_json(self.encoder, vocab_save_path)
 
-        if not spm_save_path.exists():
+        if os.path.abspath(self.spm_file) != os.path.abspath(spm_save_path) and os.path.isfile(self.spm_file):
             copyfile(self.spm_file, spm_save_path)
+        elif not os.path.isfile(self.spm_file):
+            with open(spm_save_path, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (str(vocab_save_path), str(spm_save_path))
 
@@ -303,8 +303,11 @@ class T5Tokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
-            logger.info(f"Copy vocab file to {out_vocab_file}")
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)
@@ -302,8 +302,12 @@ class XLMProphetNetTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)
 
@@ -310,7 +310,11 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)
@@ -342,7 +342,11 @@ class XLNetTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)
@@ -394,6 +394,33 @@ class TokenizerTesterMixin:
         self.assertEqual(tokenizer_new.sp_model_kwargs, sp_model_kwargs)
         self.check_subword_sampling(tokenizer_new)
 
+    def test_save_sentencepiece_tokenizer(self) -> None:
+        if not self.test_sentencepiece or not self.test_slow_tokenizer:
+            return
+        # We want to verify that we will be able to save the tokenizer even if the original files that were used to
+        # build the tokenizer have been deleted in the meantime.
+        text = "This is text to test the tokenizer."
+
+        tokenizer_slow_1 = self.get_tokenizer()
+        encoding_tokenizer_slow_1 = tokenizer_slow_1(text)
+
+        tmpdirname_1 = tempfile.mkdtemp()
+        tmpdirname_2 = tempfile.mkdtemp()
+
+        tokenizer_slow_1.save_pretrained(tmpdirname_1)
+        tokenizer_slow_2 = self.tokenizer_class.from_pretrained(tmpdirname_1)
+        encoding_tokenizer_slow_2 = tokenizer_slow_2(text)
+
+        shutil.rmtree(tmpdirname_1)
+        tokenizer_slow_2.save_pretrained(tmpdirname_2)
+
+        tokenizer_slow_3 = self.tokenizer_class.from_pretrained(tmpdirname_2)
+        encoding_tokenizer_slow_3 = tokenizer_slow_3(text)
+        shutil.rmtree(tmpdirname_2)
+
+        self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_2)
+        self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_3)
+
     def test_model_input_names_signature(self):
         accepted_model_main_input_names = [
             "input_ids",  # nlp models
@@ -99,6 +99,44 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         output_text = "unwanted, running"
         return input_text, output_text
 
+    # override test in `test_tokenization_common.py` because of the required input format of the `__call__`` method of
+    # this tokenizer
+    def test_save_sentencepiece_tokenizer(self) -> None:
+        if not self.test_sentencepiece or not self.test_slow_tokenizer:
+            return
+        # We want to verify that we will be able to save the tokenizer even if the original files that were used to
+        # build the tokenizer have been deleted in the meantime.
+        words, boxes = self.get_words_and_boxes()
+
+        tokenizer_slow_1 = self.get_tokenizer()
+        encoding_tokenizer_slow_1 = tokenizer_slow_1(
+            words,
+            boxes=boxes,
+        )
+
+        tmpdirname_1 = tempfile.mkdtemp()
+        tmpdirname_2 = tempfile.mkdtemp()
+
+        tokenizer_slow_1.save_pretrained(tmpdirname_1)
+        tokenizer_slow_2 = self.tokenizer_class.from_pretrained(tmpdirname_1)
+        encoding_tokenizer_slow_2 = tokenizer_slow_2(
+            words,
+            boxes=boxes,
+        )
+
+        shutil.rmtree(tmpdirname_1)
+        tokenizer_slow_2.save_pretrained(tmpdirname_2)
+
+        tokenizer_slow_3 = self.tokenizer_class.from_pretrained(tmpdirname_2)
+        encoding_tokenizer_slow_3 = tokenizer_slow_3(
+            words,
+            boxes=boxes,
+        )
+        shutil.rmtree(tmpdirname_2)
+
+        self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_2)
+        self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_3)
+
     @slow
     def test_sequence_builders(self):
         tokenizer = self.tokenizer_class.from_pretrained("microsoft/layoutxlm-base")
@@ -39,6 +39,7 @@ class MBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     tokenizer_class = MBartTokenizer
     rust_tokenizer_class = MBartTokenizerFast
     test_rust_tokenizer = True
+    test_sentencepiece = True
 
     def setUp(self):
         super().setUp()