improve saving strategy of sentencepiece tokenizer (#15328)

* add new test

* add a feature to same the sentencepiece tokenizer model when the init file was deleted

* update marian

* update m2m_100

* fix marian

* update speech to text

* override test for layoutxlm

* fix saving bartpho

* remove harcoded values bartpho

* special token string version

* finish bartpho

* override layoutxml test

* add mbart

* move special tokens list

* format

* Revert "format"

This reverts commit 37a40df379.

* simplify list of string of special tokens

* Re-write `self.fairseq_tokens_to_ids ` initialization logic with special tokens

Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>

Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
This commit is contained in:
SaulLu 2022-01-27 16:24:51 +01:00 committed by GitHub
parent 196cce6e9b
commit ade7371a41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 202 additions and 36 deletions

View File

@ -343,7 +343,11 @@ class AlbertTokenizer(PreTrainedTokenizer):
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file,)

View File

@ -157,12 +157,20 @@ class BartphoTokenizer(PreTrainedTokenizer):
self.sp_model.Load(str(vocab_file))
# Load the reduced vocab
self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
# Keep order of special tokens for backward compatibility
self.fairseq_tokens_to_ids = {}
cnt = 0
for token in [bos_token, pad_token, eos_token, unk_token, sep_token, cls_token]:
if str(token) not in self.fairseq_tokens_to_ids:
self.fairseq_tokens_to_ids[str(token)] = cnt
cnt += 1
with open(monolingual_vocab_file, "r", encoding="utf-8") as f:
for line in f.readlines():
token = line.strip().split()[0]
self.fairseq_tokens_to_ids[token] = len(self.fairseq_tokens_to_ids)
self.fairseq_tokens_to_ids["<mask>"] = len(self.fairseq_tokens_to_ids)
if str(mask_token) not in self.fairseq_tokens_to_ids:
self.fairseq_tokens_to_ids[str(mask_token)] = len(self.fairseq_tokens_to_ids)
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
@ -278,7 +286,7 @@ class BartphoTokenizer(PreTrainedTokenizer):
if token in self.fairseq_tokens_to_ids:
return self.fairseq_tokens_to_ids[token]
else:
return self.fairseq_tokens_to_ids["<unk>"]
return self.unk_token_id
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
@ -301,10 +309,21 @@ class BartphoTokenizer(PreTrainedTokenizer):
(filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["monolingual_vocab_file"],
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
if os.path.abspath(self.monolingual_vocab_file) != os.path.abspath(out_monolingual_vocab_file):
if os.path.abspath(self.monolingual_vocab_file) != os.path.abspath(
out_monolingual_vocab_file
) and os.path.isfile(self.monolingual_vocab_file):
copyfile(self.monolingual_vocab_file, out_monolingual_vocab_file)
elif not os.path.isfile(self.monolingual_vocab_file):
with open(out_monolingual_vocab_file, "w", encoding="utf-8") as fp:
for token in self.fairseq_tokens_to_ids:
if token not in self.all_special_tokens:
fp.write(f"{str(token)} \n")
return out_vocab_file, out_monolingual_vocab_file

View File

@ -160,7 +160,11 @@ class BertGenerationTokenizer(PreTrainedTokenizer):
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file,)

View File

@ -189,8 +189,12 @@ class BigBirdTokenizer(PreTrainedTokenizer):
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file,)

View File

@ -288,7 +288,11 @@ class CamembertTokenizer(PreTrainedTokenizer):
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file,)

View File

@ -305,7 +305,11 @@ class FNetTokenizer(PreTrainedTokenizer):
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file,)

View File

@ -331,8 +331,12 @@ class LayoutXLMTokenizer(PreTrainedTokenizer):
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file,)

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Tokenization classes for M2M100."""
import json
import os
from contextlib import contextmanager
from pathlib import Path
from shutil import copyfile
@ -312,8 +313,12 @@ class M2M100Tokenizer(PreTrainedTokenizer):
save_json(self.encoder, vocab_save_path)
if not spm_save_path.exists():
if os.path.abspath(self.spm_file) != os.path.abspath(spm_save_path) and os.path.isfile(self.spm_file):
copyfile(self.spm_file, spm_save_path)
elif not os.path.isfile(self.spm_file):
with open(spm_save_path, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (str(vocab_save_path), str(spm_save_path))

View File

@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import re
import warnings
from contextlib import contextmanager
@ -23,8 +23,11 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import sentencepiece
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {
"source_spm": "source.spm",
"target_spm": "target.spm",
@ -277,21 +280,35 @@ class MarianTokenizer(PreTrainedTokenizer):
return len(self.encoder)
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
save_dir = Path(save_directory)
assert save_dir.is_dir(), f"{save_directory} should be a directory"
save_json(
self.encoder,
save_dir / ((filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab"]),
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
saved_files = []
out_vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab"]
)
for orig, f in zip(["source.spm", "target.spm"], self.spm_files):
dest_path = save_dir / ((filename_prefix + "-" if filename_prefix else "") + Path(f).name)
if not dest_path.exists():
copyfile(f, save_dir / orig)
save_json(self.encoder, out_vocab_file)
saved_files.append(out_vocab_file)
return tuple(
save_dir / ((filename_prefix + "-" if filename_prefix else "") + f) for f in self.vocab_files_names
)
for spm_save_filename, spm_orig_path, spm_model in zip(
[VOCAB_FILES_NAMES["source_spm"], VOCAB_FILES_NAMES["target_spm"]],
self.spm_files,
[self.spm_source, self.spm_target],
):
spm_save_path = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + spm_save_filename
)
if os.path.abspath(spm_orig_path) != os.path.abspath(spm_save_path) and os.path.isfile(spm_orig_path):
copyfile(spm_orig_path, spm_save_path)
saved_files.append(spm_save_path)
elif not os.path.isfile(spm_orig_path):
with open(spm_save_path, "wb") as fi:
content_spiece_model = spm_model.serialized_model_proto()
fi.write(content_spiece_model)
saved_files.append(spm_save_path)
return tuple(saved_files)
def get_vocab(self) -> Dict:
vocab = self.encoder.copy()

View File

@ -315,8 +315,12 @@ class MBartTokenizer(PreTrainedTokenizer):
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file,)

View File

@ -245,8 +245,12 @@ class MBart50Tokenizer(PreTrainedTokenizer):
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file,)

View File

@ -285,7 +285,11 @@ class PegasusTokenizer(PreTrainedTokenizer):
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file,)

View File

@ -167,7 +167,11 @@ class ReformerTokenizer(PreTrainedTokenizer):
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file,)

View File

@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for Speech2Text."""
import json
import os
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
@ -260,8 +260,12 @@ class Speech2TextTokenizer(PreTrainedTokenizer):
save_json(self.encoder, vocab_save_path)
if not spm_save_path.exists():
if os.path.abspath(self.spm_file) != os.path.abspath(spm_save_path) and os.path.isfile(self.spm_file):
copyfile(self.spm_file, spm_save_path)
elif not os.path.isfile(self.spm_file):
with open(spm_save_path, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (str(vocab_save_path), str(spm_save_path))

View File

@ -303,8 +303,11 @@ class T5Tokenizer(PreTrainedTokenizer):
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
logger.info(f"Copy vocab file to {out_vocab_file}")
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file,)

View File

@ -302,8 +302,12 @@ class XLMProphetNetTokenizer(PreTrainedTokenizer):
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file,)

View File

@ -310,7 +310,11 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file,)

View File

@ -342,7 +342,11 @@ class XLNetTokenizer(PreTrainedTokenizer):
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file,)

View File

@ -394,6 +394,33 @@ class TokenizerTesterMixin:
self.assertEqual(tokenizer_new.sp_model_kwargs, sp_model_kwargs)
self.check_subword_sampling(tokenizer_new)
def test_save_sentencepiece_tokenizer(self) -> None:
if not self.test_sentencepiece or not self.test_slow_tokenizer:
return
# We want to verify that we will be able to save the tokenizer even if the original files that were used to
# build the tokenizer have been deleted in the meantime.
text = "This is text to test the tokenizer."
tokenizer_slow_1 = self.get_tokenizer()
encoding_tokenizer_slow_1 = tokenizer_slow_1(text)
tmpdirname_1 = tempfile.mkdtemp()
tmpdirname_2 = tempfile.mkdtemp()
tokenizer_slow_1.save_pretrained(tmpdirname_1)
tokenizer_slow_2 = self.tokenizer_class.from_pretrained(tmpdirname_1)
encoding_tokenizer_slow_2 = tokenizer_slow_2(text)
shutil.rmtree(tmpdirname_1)
tokenizer_slow_2.save_pretrained(tmpdirname_2)
tokenizer_slow_3 = self.tokenizer_class.from_pretrained(tmpdirname_2)
encoding_tokenizer_slow_3 = tokenizer_slow_3(text)
shutil.rmtree(tmpdirname_2)
self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_2)
self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_3)
def test_model_input_names_signature(self):
accepted_model_main_input_names = [
"input_ids", # nlp models

View File

@ -99,6 +99,44 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
output_text = "unwanted, running"
return input_text, output_text
# override test in `test_tokenization_common.py` because of the required input format of the `__call__`` method of
# this tokenizer
def test_save_sentencepiece_tokenizer(self) -> None:
if not self.test_sentencepiece or not self.test_slow_tokenizer:
return
# We want to verify that we will be able to save the tokenizer even if the original files that were used to
# build the tokenizer have been deleted in the meantime.
words, boxes = self.get_words_and_boxes()
tokenizer_slow_1 = self.get_tokenizer()
encoding_tokenizer_slow_1 = tokenizer_slow_1(
words,
boxes=boxes,
)
tmpdirname_1 = tempfile.mkdtemp()
tmpdirname_2 = tempfile.mkdtemp()
tokenizer_slow_1.save_pretrained(tmpdirname_1)
tokenizer_slow_2 = self.tokenizer_class.from_pretrained(tmpdirname_1)
encoding_tokenizer_slow_2 = tokenizer_slow_2(
words,
boxes=boxes,
)
shutil.rmtree(tmpdirname_1)
tokenizer_slow_2.save_pretrained(tmpdirname_2)
tokenizer_slow_3 = self.tokenizer_class.from_pretrained(tmpdirname_2)
encoding_tokenizer_slow_3 = tokenizer_slow_3(
words,
boxes=boxes,
)
shutil.rmtree(tmpdirname_2)
self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_2)
self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_3)
@slow
def test_sequence_builders(self):
tokenizer = self.tokenizer_class.from_pretrained("microsoft/layoutxlm-base")

View File

@ -39,6 +39,7 @@ class MBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = MBartTokenizer
rust_tokenizer_class = MBartTokenizerFast
test_rust_tokenizer = True
test_sentencepiece = True
def setUp(self):
super().setUp()