Add sentencepiece to BertJapaneseTokenizer (#19769)

* support sentencepiece for bertjapanesetokenizer

* add test vocab file for sentencepiece, bertjapanesetokenizer

* make BasicTokenizer be identical to transformers.models.bert.tokenization_bert.BasicTokenizer

* fix missing of \n in comment

* fix init argument missing in tests

* make spm_file be optional, exclude spiece.model from tests/fixtures, and add description comments

* make comment length less than 119

* apply doc style check
This commit is contained in:
Hao Wang 2022-10-21 23:04:49 +09:00 committed by GitHub
parent 2ebf4e6a7b
commit 31565ff0fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 146 additions and 30 deletions

View File

@ -19,7 +19,9 @@ import collections
import copy
import os
import unicodedata
from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import sentencepiece as spm
from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
from ...utils import logging
@ -27,7 +29,9 @@ from ...utils import logging
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "spm_file": "spiece.model"}
SPIECE_UNDERLINE = ""
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
@ -107,6 +111,9 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
Args:
vocab_file (`str`):
Path to a one-wordpiece-per-line vocabulary file.
spm_file (`str`, *optional*):
Path to [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .spm or .model
extension) that contains the vocabulary.
do_lower_case (`bool`, *optional*, defaults to `True`):
Whether to lower case the input. Only has an effect when do_basic_tokenize=True.
do_word_tokenize (`bool`, *optional*, defaults to `True`):
@ -116,7 +123,7 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
word_tokenizer_type (`str`, *optional*, defaults to `"basic"`):
Type of word tokenizer. Choose from ["basic", "mecab", "sudachi", "jumanpp"].
subword_tokenizer_type (`str`, *optional*, defaults to `"wordpiece"`):
Type of subword tokenizer. Choose from ["wordpiece", "character"].
Type of subword tokenizer. Choose from ["wordpiece", "character", "sentencepiece",].
mecab_kwargs (`dict`, *optional*):
Dictionary passed to the `MecabTokenizer` constructor.
sudachi_kwargs (`dict`, *optional*):
@ -133,6 +140,7 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
def __init__(
self,
vocab_file,
spm_file=None,
do_lower_case=False,
do_word_tokenize=True,
do_subword_tokenize=True,
@ -150,6 +158,7 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
**kwargs
):
super().__init__(
spm_file=spm_file,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
@ -167,13 +176,21 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
**kwargs,
)
if not os.path.isfile(vocab_file):
raise ValueError(
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
" model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
if subword_tokenizer_type == "sentencepiece":
if not os.path.isfile(spm_file):
raise ValueError(
f"Can't find a vocabulary file at path '{spm_file}'. To load the vocabulary from a Google"
" pretrained model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.spm_file = spm_file
else:
if not os.path.isfile(vocab_file):
raise ValueError(
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google"
" pretrained model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
self.do_word_tokenize = do_word_tokenize
self.word_tokenizer_type = word_tokenizer_type
@ -209,6 +226,8 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
self.subword_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
elif subword_tokenizer_type == "character":
self.subword_tokenizer = CharacterTokenizer(vocab=self.vocab, unk_token=self.unk_token)
elif subword_tokenizer_type == "sentencepiece":
self.subword_tokenizer = SentencepieceTokenizer(vocab=self.spm_file, unk_token=self.unk_token)
else:
raise ValueError(f"Invalid subword_tokenizer_type '{subword_tokenizer_type}' is specified.")
@ -251,27 +270,34 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
return split_tokens
@property
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer.vocab_size
def vocab_size(self):
if self.subword_tokenizer_type == "sentencepiece":
return len(self.subword_tokenizer.sp_model)
return len(self.vocab)
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_vocab
def get_vocab(self):
if self.subword_tokenizer_type == "sentencepiece":
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
return dict(self.vocab, **self.added_tokens_encoder)
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_token_to_id
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
if self.subword_tokenizer_type == "sentencepiece":
return self.subword_tokenizer.sp_model.PieceToId(token)
return self.vocab.get(token, self.vocab.get(self.unk_token))
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_id_to_token
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
if self.subword_tokenizer_type == "sentencepiece":
return self.subword_tokenizer.sp_model.IdToPiece(index)
return self.ids_to_tokens.get(index, self.unk_token)
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer.convert_tokens_to_string
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
if self.subword_tokenizer_type == "sentencepiece":
return self.subword_tokenizer.sp_model.decode(tokens)
out_string = " ".join(tokens).replace(" ##", "").strip()
return out_string
@ -360,25 +386,36 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer.save_vocabulary
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
index = 0
if os.path.isdir(save_directory):
vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if self.subword_tokenizer_type == "sentencepiece":
vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["spm_file"]
)
else:
vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"],
)
else:
vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
with open(vocab_file, "w", encoding="utf-8") as writer:
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning(
f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
" Please check that the vocabulary is not corrupted!"
)
index = token_index
writer.write(token + "\n")
index += 1
if self.subword_tokenizer_type == "sentencepiece":
with open(vocab_file, "wb") as writer:
content_spiece_model = self.subword_tokenizer.sp_model.serialized_model_proto()
writer.write(content_spiece_model)
else:
with open(vocab_file, "w", encoding="utf-8") as writer:
index = 0
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning(
f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
" Please check that the vocabulary is not corrupted!"
)
index = token_index
writer.write(token + "\n")
index += 1
return (vocab_file,)
@ -893,3 +930,72 @@ class WordpieceTokenizer(object):
else:
output_tokens.extend(sub_tokens)
return output_tokens
class SentencepieceTokenizer(object):
"""
Runs sentencepiece tokenization. Based on transformers.models.albert.tokenization_albert.AlbertTokenizer.
"""
def __init__(
self,
vocab,
unk_token,
do_lower_case=False,
remove_space=True,
keep_accents=True,
sp_model_kwargs: Optional[Dict[str, Any]] = None,
):
self.vocab = vocab
self.unk_token = unk_token
self.do_lower_case = do_lower_case
self.remove_space = remove_space
self.keep_accents = keep_accents
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(self.vocab)
def preprocess_text(self, inputs):
if self.remove_space:
outputs = " ".join(inputs.strip().split())
else:
outputs = inputs
outputs = outputs.replace("``", '"').replace("''", '"')
if not self.keep_accents:
outputs = unicodedata.normalize("NFKD", outputs)
outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
if self.do_lower_case:
outputs = outputs.lower()
return outputs
def tokenize(self, text):
"""
Tokenizes text by sentencepiece. Based on [SentencePiece](https://github.com/google/sentencepiece).
Tokenization needs the given vocabulary.
Args:
text: A string needs to be tokenized.
Returns:
A list of sentencepiece tokens.
"""
text = self.preprocess_text(text)
pieces = self.sp_model.encode(text, out_type=str)
new_pieces = []
for piece in pieces:
if len(piece) > 1 and piece[-1] == str(",") and piece[-2].isdigit():
cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
if len(cur_pieces[0]) == 1:
cur_pieces = cur_pieces[1:]
else:
cur_pieces[0] = cur_pieces[0][1:]
cur_pieces.append(piece[-1])
new_pieces.extend(cur_pieces)
else:
new_pieces.append(piece)
return new_pieces

View File

@ -334,6 +334,16 @@ class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertListEqual(tokenizer.tokenize("こんばんは こんばんにちは こんにちは"), ["こん", "##ばんは", "[UNK]", "こんにちは"])
def test_sentencepiece_tokenizer(self):
tokenizer = BertJapaneseTokenizer.from_pretrained("nlp-waseda/roberta-base-japanese-with-auto-jumanpp")
subword_tokenizer = tokenizer.subword_tokenizer
tokens = subword_tokenizer.tokenize("国境 の 長い トンネル を 抜ける と 雪国 であった 。")
self.assertListEqual(tokens, ["▁国境", "▁の", "▁長い", "▁トンネル", "▁を", "▁抜ける", "▁と", "▁雪", "", "▁であった", "▁。"])
tokens = subword_tokenizer.tokenize("こんばんは こんばん にち は こんにちは")
self.assertListEqual(tokens, ["▁こん", "ばん", "", "▁こん", "ばん", "▁に", "", "▁は", "▁こんにちは"])
def test_sequence_builders(self):
tokenizer = self.tokenizer_class.from_pretrained("cl-tohoku/bert-base-japanese")