improve saving strategy of sentencepiece tokenizer (#15328)

* add new test

* add a feature to save the sentencepiece tokenizer model when the init file was deleted

* update marian

* update m2m_100

* fix marian

* update speech to text

* override test for layoutxlm

* fix saving bartpho

* remove hardcoded values bartpho

* special token string version

* finish bartpho

* override layoutxlm test

* add mbart

* move special tokens list

* format

* Revert "format"

This reverts commit 37a40df379.

* simplify list of string of special tokens

* Re-write `self.fairseq_tokens_to_ids` initialization logic with special tokens

Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>

SaulLu, 2022-01-27 16:24:51 +01:00, committed by GitHub
commit ade7371a41 (parent 196cce6e9b)
21 changed files with 202 additions and 36 deletions
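
At its core, the change adds the same fallback to the `save_vocabulary` method of every SentencePiece-backed tokenizer: copy the original vocab file when it still exists on disk, otherwise re-serialize the in-memory `sp_model`. A minimal sketch of that pattern, with a hypothetical helper name (the real changes are in the per-model diffs below):

    import os
    from shutil import copyfile

    def save_spm_file(sp_model, vocab_file: str, out_vocab_file: str) -> None:
        # Copy the original SentencePiece model file when it is still on disk
        # and is not already the target file.
        if os.path.abspath(vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(vocab_file):
            copyfile(vocab_file, out_vocab_file)
        # Otherwise dump the in-memory model: serialized_model_proto() returns the
        # full model as bytes, so saving works even after the source file was deleted.
        elif not os.path.isfile(vocab_file):
            with open(out_vocab_file, "wb") as f:
                f.write(sp_model.serialized_model_proto())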


@@ -343,7 +343,11 @@ class AlbertTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)


@@ -157,12 +157,20 @@ class BartphoTokenizer(PreTrainedTokenizer):
         self.sp_model.Load(str(vocab_file))

         # Load the reduced vocab
-        self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
+
+        # Keep order of special tokens for backward compatibility
+        self.fairseq_tokens_to_ids = {}
+        cnt = 0
+        for token in [bos_token, pad_token, eos_token, unk_token, sep_token, cls_token]:
+            if str(token) not in self.fairseq_tokens_to_ids:
+                self.fairseq_tokens_to_ids[str(token)] = cnt
+                cnt += 1
         with open(monolingual_vocab_file, "r", encoding="utf-8") as f:
             for line in f.readlines():
                 token = line.strip().split()[0]
                 self.fairseq_tokens_to_ids[token] = len(self.fairseq_tokens_to_ids)
-        self.fairseq_tokens_to_ids["<mask>"] = len(self.fairseq_tokens_to_ids)
+        if str(mask_token) not in self.fairseq_tokens_to_ids:
+            self.fairseq_tokens_to_ids[str(mask_token)] = len(self.fairseq_tokens_to_ids)

         self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}

@@ -278,7 +286,7 @@ class BartphoTokenizer(PreTrainedTokenizer):
         if token in self.fairseq_tokens_to_ids:
             return self.fairseq_tokens_to_ids[token]
         else:
-            return self.fairseq_tokens_to_ids["<unk>"]
+            return self.unk_token_id

     def _convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""

@@ -301,10 +309,21 @@ class BartphoTokenizer(PreTrainedTokenizer):
             (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["monolingual_vocab_file"],
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

-        if os.path.abspath(self.monolingual_vocab_file) != os.path.abspath(out_monolingual_vocab_file):
+        if os.path.abspath(self.monolingual_vocab_file) != os.path.abspath(
+            out_monolingual_vocab_file
+        ) and os.path.isfile(self.monolingual_vocab_file):
             copyfile(self.monolingual_vocab_file, out_monolingual_vocab_file)
+        elif not os.path.isfile(self.monolingual_vocab_file):
+            with open(out_monolingual_vocab_file, "w", encoding="utf-8") as fp:
+                for token in self.fairseq_tokens_to_ids:
+                    if token not in self.all_special_tokens:
+                        fp.write(f"{str(token)} \n")

         return out_vocab_file, out_monolingual_vocab_file
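
For backward compatibility, the rewritten BartphoTokenizer initialization above reproduces the previously hardcoded ids when the default special tokens are used. A quick illustration (the token strings below are the assumed BARTpho defaults, not taken from this diff):

    # Assumed defaults: bos="<s>", pad="<pad>", eos="</s>", unk="<unk>", sep="</s>", cls="<s>"
    tokens = ["<s>", "<pad>", "</s>", "<unk>", "</s>", "<s>"]
    fairseq_tokens_to_ids, cnt = {}, 0
    for token in tokens:
        if token not in fairseq_tokens_to_ids:
            fairseq_tokens_to_ids[token] = cnt
            cnt += 1
    print(fairseq_tokens_to_ids)  # {'<s>': 0, '<pad>': 1, '</s>': 2, '<unk>': 3} -- same as the old literal dict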


@@ -160,7 +160,11 @@ class BertGenerationTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)


@@ -189,8 +189,12 @@ class BigBirdTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)


@@ -288,7 +288,11 @@ class CamembertTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)


@@ -305,7 +305,11 @@ class FNetTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)


@@ -331,8 +331,12 @@ class LayoutXLMTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)


@@ -13,6 +13,7 @@
 # limitations under the License.
 """Tokenization classes for M2M100."""
 import json
+import os
 from contextlib import contextmanager
 from pathlib import Path
 from shutil import copyfile

@@ -312,8 +313,12 @@ class M2M100Tokenizer(PreTrainedTokenizer):
         save_json(self.encoder, vocab_save_path)

-        if not spm_save_path.exists():
+        if os.path.abspath(self.spm_file) != os.path.abspath(spm_save_path) and os.path.isfile(self.spm_file):
             copyfile(self.spm_file, spm_save_path)
+        elif not os.path.isfile(self.spm_file):
+            with open(spm_save_path, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (str(vocab_save_path), str(spm_save_path))


@@ -11,8 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import json
+import os
 import re
 import warnings
 from contextlib import contextmanager

@@ -23,8 +23,11 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import sentencepiece

 from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
+
+logger = logging.get_logger(__name__)

 VOCAB_FILES_NAMES = {
     "source_spm": "source.spm",
     "target_spm": "target.spm",

@@ -277,21 +280,35 @@ class MarianTokenizer(PreTrainedTokenizer):
         return len(self.encoder)

     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
-        save_dir = Path(save_directory)
-        assert save_dir.is_dir(), f"{save_directory} should be a directory"
-        save_json(
-            self.encoder,
-            save_dir / ((filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab"]),
-        )
-        for orig, f in zip(["source.spm", "target.spm"], self.spm_files):
-            dest_path = save_dir / ((filename_prefix + "-" if filename_prefix else "") + Path(f).name)
-            if not dest_path.exists():
-                copyfile(f, save_dir / orig)
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        saved_files = []
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab"]
+        )
+
+        save_json(self.encoder, out_vocab_file)
+        saved_files.append(out_vocab_file)
+
+        for spm_save_filename, spm_orig_path, spm_model in zip(
+            [VOCAB_FILES_NAMES["source_spm"], VOCAB_FILES_NAMES["target_spm"]],
+            self.spm_files,
+            [self.spm_source, self.spm_target],
+        ):
+            spm_save_path = os.path.join(
+                save_directory, (filename_prefix + "-" if filename_prefix else "") + spm_save_filename
+            )
+            if os.path.abspath(spm_orig_path) != os.path.abspath(spm_save_path) and os.path.isfile(spm_orig_path):
+                copyfile(spm_orig_path, spm_save_path)
+                saved_files.append(spm_save_path)
+            elif not os.path.isfile(spm_orig_path):
+                with open(spm_save_path, "wb") as fi:
+                    content_spiece_model = spm_model.serialized_model_proto()
+                    fi.write(content_spiece_model)
+                saved_files.append(spm_save_path)

-        return tuple(
-            save_dir / ((filename_prefix + "-" if filename_prefix else "") + f) for f in self.vocab_files_names
-        )
+        return tuple(saved_files)

     def get_vocab(self) -> Dict:
         vocab = self.encoder.copy()


@@ -315,8 +315,12 @@ class MBartTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)


@@ -245,8 +245,12 @@ class MBart50Tokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)


@@ -285,7 +285,11 @@ class PegasusTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)


@@ -167,7 +167,11 @@ class ReformerTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)


@@ -13,8 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tokenization classes for Speech2Text."""
-
 import json
+import os
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple, Union

@@ -260,8 +260,12 @@ class Speech2TextTokenizer(PreTrainedTokenizer):
         save_json(self.encoder, vocab_save_path)

-        if not spm_save_path.exists():
+        if os.path.abspath(self.spm_file) != os.path.abspath(spm_save_path) and os.path.isfile(self.spm_file):
             copyfile(self.spm_file, spm_save_path)
+        elif not os.path.isfile(self.spm_file):
+            with open(spm_save_path, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (str(vocab_save_path), str(spm_save_path))


@@ -303,8 +303,11 @@ class T5Tokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
-            logger.info(f"Copy vocab file to {out_vocab_file}")
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)


@@ -302,8 +302,12 @@ class XLMProphetNetTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)


@@ -310,7 +310,11 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)


@@ -342,7 +342,11 @@ class XLNetTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)

         return (out_vocab_file,)


@@ -394,6 +394,33 @@ class TokenizerTesterMixin:
             self.assertEqual(tokenizer_new.sp_model_kwargs, sp_model_kwargs)
             self.check_subword_sampling(tokenizer_new)

+    def test_save_sentencepiece_tokenizer(self) -> None:
+        if not self.test_sentencepiece or not self.test_slow_tokenizer:
+            return
+        # We want to verify that we will be able to save the tokenizer even if the original files that were used to
+        # build the tokenizer have been deleted in the meantime.
+        text = "This is text to test the tokenizer."
+
+        tokenizer_slow_1 = self.get_tokenizer()
+        encoding_tokenizer_slow_1 = tokenizer_slow_1(text)
+
+        tmpdirname_1 = tempfile.mkdtemp()
+        tmpdirname_2 = tempfile.mkdtemp()
+
+        tokenizer_slow_1.save_pretrained(tmpdirname_1)
+        tokenizer_slow_2 = self.tokenizer_class.from_pretrained(tmpdirname_1)
+        encoding_tokenizer_slow_2 = tokenizer_slow_2(text)
+
+        shutil.rmtree(tmpdirname_1)
+        tokenizer_slow_2.save_pretrained(tmpdirname_2)
+
+        tokenizer_slow_3 = self.tokenizer_class.from_pretrained(tmpdirname_2)
+        encoding_tokenizer_slow_3 = tokenizer_slow_3(text)
+        shutil.rmtree(tmpdirname_2)
+
+        self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_2)
+        self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_3)
+
     def test_model_input_names_signature(self):
         accepted_model_main_input_names = [
             "input_ids",  # nlp models


@@ -99,6 +99,44 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         output_text = "unwanted, running"
         return input_text, output_text

+    # override test in `test_tokenization_common.py` because of the required input format of the `__call__` method of
+    # this tokenizer
+    def test_save_sentencepiece_tokenizer(self) -> None:
+        if not self.test_sentencepiece or not self.test_slow_tokenizer:
+            return
+        # We want to verify that we will be able to save the tokenizer even if the original files that were used to
+        # build the tokenizer have been deleted in the meantime.
+        words, boxes = self.get_words_and_boxes()
+
+        tokenizer_slow_1 = self.get_tokenizer()
+        encoding_tokenizer_slow_1 = tokenizer_slow_1(
+            words,
+            boxes=boxes,
+        )
+
+        tmpdirname_1 = tempfile.mkdtemp()
+        tmpdirname_2 = tempfile.mkdtemp()
+
+        tokenizer_slow_1.save_pretrained(tmpdirname_1)
+        tokenizer_slow_2 = self.tokenizer_class.from_pretrained(tmpdirname_1)
+        encoding_tokenizer_slow_2 = tokenizer_slow_2(
+            words,
+            boxes=boxes,
+        )
+
+        shutil.rmtree(tmpdirname_1)
+        tokenizer_slow_2.save_pretrained(tmpdirname_2)
+
+        tokenizer_slow_3 = self.tokenizer_class.from_pretrained(tmpdirname_2)
+        encoding_tokenizer_slow_3 = tokenizer_slow_3(
+            words,
+            boxes=boxes,
+        )
+        shutil.rmtree(tmpdirname_2)
+
+        self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_2)
+        self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_3)
+
     @slow
     def test_sequence_builders(self):
         tokenizer = self.tokenizer_class.from_pretrained("microsoft/layoutxlm-base")


@@ -39,6 +39,7 @@ class MBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     tokenizer_class = MBartTokenizer
     rust_tokenizer_class = MBartTokenizerFast
     test_rust_tokenizer = True
+    test_sentencepiece = True

     def setUp(self):
         super().setUp()