Add sentencepiece to BertJapaneseTokenizer (#19769)

* support sentencepiece for bertjapanesetokenizer

* add test vocab file for sentencepiece, bertjapanesetokenizer

* make BasicTokenizer identical to transformers.models.bert.tokenization_bert.BasicTokenizer

* fix missing \n in comment

* fix init argument missing in tests

* make spm_file optional, exclude spiece.model from tests/fixtures, and add descriptive comments

* make comment length less than 119

* apply doc style check
Hao Wang · 2022-10-21 23:04:49 +09:00 · committed by GitHub
commit 31565ff0fd (parent 2ebf4e6a7b)
2 changed files with 146 additions and 30 deletions
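For context, a minimal sketch of what this change enables. The checkpoint name and the expected tokens are taken directly from the new test added in this PR; the rest (the `print` call) is illustrative only and not part of the diff:

```python
from transformers import BertJapaneseTokenizer

# Checkpoint exercised by the new test below; it ships a SentencePiece model,
# so the subword tokenizer resolves to the new "sentencepiece" type.
tokenizer = BertJapaneseTokenizer.from_pretrained("nlp-waseda/roberta-base-japanese-with-auto-jumanpp")

# Subword splitting on already word-segmented text, exactly as in the new test:
print(tokenizer.subword_tokenizer.tokenize("国境 の 長い トンネル を 抜ける と 雪国 であった 。"))
# -> ['▁国境', '▁の', '▁長い', '▁トンネル', '▁を', '▁抜ける', '▁と', '▁雪', '国', '▁であった', '▁。']
```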

File: src/transformers/models/bert_japanese/tokenization_bert_japanese.py

@@ -19,7 +19,9 @@ import collections
 import copy
 import os
 import unicodedata
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
+
+import sentencepiece as spm
 
 from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
 from ...utils import logging
@@ -27,7 +29,9 @@ from ...utils import logging
 
 logger = logging.get_logger(__name__)
 
-VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "spm_file": "spiece.model"}
+
+SPIECE_UNDERLINE = "▁"
 
 PRETRAINED_VOCAB_FILES_MAP = {
     "vocab_file": {
@@ -107,6 +111,9 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
     Args:
         vocab_file (`str`):
             Path to a one-wordpiece-per-line vocabulary file.
+        spm_file (`str`, *optional*):
+            Path to [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .spm or .model
+            extension) that contains the vocabulary.
         do_lower_case (`bool`, *optional*, defaults to `True`):
             Whether to lower case the input. Only has an effect when do_basic_tokenize=True.
         do_word_tokenize (`bool`, *optional*, defaults to `True`):
@@ -116,7 +123,7 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
         word_tokenizer_type (`str`, *optional*, defaults to `"basic"`):
             Type of word tokenizer. Choose from ["basic", "mecab", "sudachi", "jumanpp"].
         subword_tokenizer_type (`str`, *optional*, defaults to `"wordpiece"`):
-            Type of subword tokenizer. Choose from ["wordpiece", "character"].
+            Type of subword tokenizer. Choose from ["wordpiece", "character", "sentencepiece",].
         mecab_kwargs (`dict`, *optional*):
             Dictionary passed to the `MecabTokenizer` constructor.
         sudachi_kwargs (`dict`, *optional*):
@@ -133,6 +140,7 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
     def __init__(
         self,
         vocab_file,
+        spm_file=None,
         do_lower_case=False,
         do_word_tokenize=True,
         do_subword_tokenize=True,
@@ -150,6 +158,7 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
         **kwargs
     ):
         super().__init__(
+            spm_file=spm_file,
             unk_token=unk_token,
             sep_token=sep_token,
             pad_token=pad_token,
@@ -167,13 +176,21 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
             **kwargs,
         )
 
-        if not os.path.isfile(vocab_file):
-            raise ValueError(
-                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
-                " model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
-            )
-        self.vocab = load_vocab(vocab_file)
-        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
+        if subword_tokenizer_type == "sentencepiece":
+            if not os.path.isfile(spm_file):
+                raise ValueError(
+                    f"Can't find a vocabulary file at path '{spm_file}'. To load the vocabulary from a Google"
+                    " pretrained model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+                )
+            self.spm_file = spm_file
+        else:
+            if not os.path.isfile(vocab_file):
+                raise ValueError(
+                    f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google"
+                    " pretrained model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+                )
+            self.vocab = load_vocab(vocab_file)
+            self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
 
         self.do_word_tokenize = do_word_tokenize
         self.word_tokenizer_type = word_tokenizer_type
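The new branch above only validates `spm_file` when the sentencepiece subword tokenizer is requested. A rough sketch of the failure mode it introduces (the file name here is hypothetical, not from the diff):

```python
from transformers import BertJapaneseTokenizer

try:
    # spm_file points at a file that does not exist, so the new check raises.
    BertJapaneseTokenizer(
        vocab_file=None,
        spm_file="does_not_exist.model",
        subword_tokenizer_type="sentencepiece",
    )
except ValueError as err:
    print(err)  # "Can't find a vocabulary file at path 'does_not_exist.model'. ..."
```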
@@ -209,6 +226,8 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
                 self.subword_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
             elif subword_tokenizer_type == "character":
                 self.subword_tokenizer = CharacterTokenizer(vocab=self.vocab, unk_token=self.unk_token)
+            elif subword_tokenizer_type == "sentencepiece":
+                self.subword_tokenizer = SentencepieceTokenizer(vocab=self.spm_file, unk_token=self.unk_token)
             else:
                 raise ValueError(f"Invalid subword_tokenizer_type '{subword_tokenizer_type}' is specified.")
@@ -251,27 +270,34 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
         return split_tokens
 
     @property
-    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.vocab_size
     def vocab_size(self):
+        if self.subword_tokenizer_type == "sentencepiece":
+            return len(self.subword_tokenizer.sp_model)
         return len(self.vocab)
 
-    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_vocab
     def get_vocab(self):
+        if self.subword_tokenizer_type == "sentencepiece":
+            vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+            vocab.update(self.added_tokens_encoder)
+            return vocab
         return dict(self.vocab, **self.added_tokens_encoder)
 
-    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_token_to_id
     def _convert_token_to_id(self, token):
         """Converts a token (str) in an id using the vocab."""
+        if self.subword_tokenizer_type == "sentencepiece":
+            return self.subword_tokenizer.sp_model.PieceToId(token)
         return self.vocab.get(token, self.vocab.get(self.unk_token))
 
-    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_id_to_token
     def _convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
+        if self.subword_tokenizer_type == "sentencepiece":
+            return self.subword_tokenizer.sp_model.IdToPiece(index)
         return self.ids_to_tokens.get(index, self.unk_token)
 
-    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.convert_tokens_to_string
     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
+        if self.subword_tokenizer_type == "sentencepiece":
+            return self.subword_tokenizer.sp_model.decode(tokens)
         out_string = " ".join(tokens).replace(" ##", "").strip()
         return out_string
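The branches above defer to the underlying `sentencepiece` processor. A standalone sketch of those calls, assuming any local `spiece.model` (not a specific checkpoint):

```python
import sentencepiece as spm

sp_model = spm.SentencePieceProcessor()
sp_model.Load("spiece.model")

token_id = sp_model.PieceToId("▁こん")           # what _convert_token_to_id now does
token = sp_model.IdToPiece(token_id)             # what _convert_id_to_token now does
text = sp_model.decode(["▁こん", "ばん", "は"])  # what convert_tokens_to_string now does
print(token_id, token, text)
```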
@@ -360,25 +386,36 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
             return len(cls + token_ids_0 + sep) * [0]
         return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
 
-    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.save_vocabulary
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
-        index = 0
         if os.path.isdir(save_directory):
-            vocab_file = os.path.join(
-                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
-            )
+            if self.subword_tokenizer_type == "sentencepiece":
+                vocab_file = os.path.join(
+                    save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["spm_file"]
+                )
+            else:
+                vocab_file = os.path.join(
+                    save_directory,
+                    (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"],
+                )
         else:
             vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
-        with open(vocab_file, "w", encoding="utf-8") as writer:
-            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
-                if index != token_index:
-                    logger.warning(
-                        f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
-                        " Please check that the vocabulary is not corrupted!"
-                    )
-                    index = token_index
-                writer.write(token + "\n")
-                index += 1
+
+        if self.subword_tokenizer_type == "sentencepiece":
+            with open(vocab_file, "wb") as writer:
+                content_spiece_model = self.subword_tokenizer.sp_model.serialized_model_proto()
+                writer.write(content_spiece_model)
+        else:
+            with open(vocab_file, "w", encoding="utf-8") as writer:
+                index = 0
+                for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+                    if index != token_index:
+                        logger.warning(
+                            f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
+                            " Please check that the vocabulary is not corrupted!"
+                        )
+                        index = token_index
+                    writer.write(token + "\n")
+                    index += 1
         return (vocab_file,)
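A sketch of the new save path: for the sentencepiece case the serialized model proto is written back out as `spiece.model`. The directory name is hypothetical and must already exist, since `save_vocabulary` does not create it; `tokenizer` is assumed to be the sentencepiece-backed instance loaded in the earlier example:

```python
import os

os.makedirs("./bjt-sentencepiece", exist_ok=True)
saved = tokenizer.save_vocabulary("./bjt-sentencepiece")
print(saved)  # ('./bjt-sentencepiece/spiece.model',) when subword_tokenizer_type == "sentencepiece"
```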
@@ -893,3 +930,72 @@ class WordpieceTokenizer(object):
             else:
                 output_tokens.extend(sub_tokens)
         return output_tokens
+
+
+class SentencepieceTokenizer(object):
+    """
+    Runs sentencepiece tokenization. Based on transformers.models.albert.tokenization_albert.AlbertTokenizer.
+    """
+
+    def __init__(
+        self,
+        vocab,
+        unk_token,
+        do_lower_case=False,
+        remove_space=True,
+        keep_accents=True,
+        sp_model_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        self.vocab = vocab
+        self.unk_token = unk_token
+        self.do_lower_case = do_lower_case
+        self.remove_space = remove_space
+        self.keep_accents = keep_accents
+
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(self.vocab)
+
+    def preprocess_text(self, inputs):
+        if self.remove_space:
+            outputs = " ".join(inputs.strip().split())
+        else:
+            outputs = inputs
+        outputs = outputs.replace("``", '"').replace("''", '"')
+
+        if not self.keep_accents:
+            outputs = unicodedata.normalize("NFKD", outputs)
+            outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
+        if self.do_lower_case:
+            outputs = outputs.lower()
+
+        return outputs
+
+    def tokenize(self, text):
+        """
+        Tokenizes text by sentencepiece. Based on [SentencePiece](https://github.com/google/sentencepiece).
+        Tokenization needs the given vocabulary.
+
+        Args:
+            text: A string needs to be tokenized.
+
+        Returns:
+            A list of sentencepiece tokens.
+        """
+        text = self.preprocess_text(text)
+        pieces = self.sp_model.encode(text, out_type=str)
+        new_pieces = []
+        for piece in pieces:
+            if len(piece) > 1 and piece[-1] == str(",") and piece[-2].isdigit():
+                cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
+                if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
+                    if len(cur_pieces[0]) == 1:
+                        cur_pieces = cur_pieces[1:]
+                    else:
+                        cur_pieces[0] = cur_pieces[0][1:]
+                cur_pieces.append(piece[-1])
+                new_pieces.extend(cur_pieces)
+            else:
+                new_pieces.append(piece)
+
+        return new_pieces

File: tests/models/bert_japanese/test_tokenization_bert_japanese.py

@@ -334,6 +334,16 @@ class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 
         self.assertListEqual(tokenizer.tokenize("こんばんは こんばんにちは こんにちは"), ["こん", "##ばんは", "[UNK]", "こんにちは"])
 
+    def test_sentencepiece_tokenizer(self):
+        tokenizer = BertJapaneseTokenizer.from_pretrained("nlp-waseda/roberta-base-japanese-with-auto-jumanpp")
+        subword_tokenizer = tokenizer.subword_tokenizer
+
+        tokens = subword_tokenizer.tokenize("国境 の 長い トンネル を 抜ける と 雪国 であった 。")
+        self.assertListEqual(tokens, ["▁国境", "▁の", "▁長い", "▁トンネル", "▁を", "▁抜ける", "▁と", "▁雪", "国", "▁であった", "▁。"])
+
+        tokens = subword_tokenizer.tokenize("こんばんは こんばん にち は こんにちは")
+        self.assertListEqual(tokens, ["▁こん", "ばん", "は", "▁こん", "ばん", "▁に", "ち", "▁は", "▁こんにちは"])
+
     def test_sequence_builders(self):
         tokenizer = self.tokenizer_class.from_pretrained("cl-tohoku/bert-base-japanese")