Add sentencepiece to BertJapaneseTokenizer (#19769)
* Support SentencePiece in BertJapaneseTokenizer
* Add a test vocab file for the SentencePiece BertJapaneseTokenizer
* Make BasicTokenizer identical to transformers.models.bert.tokenization_bert.BasicTokenizer
* Fix a missing \n in a comment
* Fix a missing init argument in the tests
* Make spm_file optional, exclude spiece.model from tests/fixtures, and add descriptive comments
* Keep comment lines under 119 characters
* Apply the doc style check
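A minimal usage sketch of the new option (the names below are illustrative: "spiece.model" is assumed to be a locally available, trained SentencePiece model; the constructor arguments are the ones introduced in the diff):

    from transformers import BertJapaneseTokenizer

    # Subwords come from a SentencePiece model; vocab_file is not read on the
    # sentencepiece path (see the __init__ branch in the diff), so None is fine here.
    tokenizer = BertJapaneseTokenizer(
        vocab_file=None,
        spm_file="spiece.model",  # assumed local path to a trained SentencePiece model
        word_tokenizer_type="basic",
        subword_tokenizer_type="sentencepiece",
    )
    print(tokenizer.tokenize("国境 の 長い トンネル を 抜ける と 雪国 であった 。"))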
This commit is contained in: parent 2ebf4e6a7b · commit 31565ff0fd
tokenization_bert_japanese.py

@@ -19,7 +19,9 @@ import collections
 import copy
 import os
 import unicodedata
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
+import sentencepiece as spm
+
 from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
 from ...utils import logging
@@ -27,7 +29,9 @@ from ...utils import logging
 
 logger = logging.get_logger(__name__)
 
-VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "spm_file": "spiece.model"}
 
+SPIECE_UNDERLINE = "▁"
+
 PRETRAINED_VOCAB_FILES_MAP = {
     "vocab_file": {
@@ -107,6 +111,9 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
     Args:
         vocab_file (`str`):
             Path to a one-wordpiece-per-line vocabulary file.
+        spm_file (`str`, *optional*):
+            Path to [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .spm or .model
+            extension) that contains the vocabulary.
         do_lower_case (`bool`, *optional*, defaults to `True`):
             Whether to lower case the input. Only has an effect when do_basic_tokenize=True.
         do_word_tokenize (`bool`, *optional*, defaults to `True`):
@@ -116,7 +123,7 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
         word_tokenizer_type (`str`, *optional*, defaults to `"basic"`):
             Type of word tokenizer. Choose from ["basic", "mecab", "sudachi", "jumanpp"].
         subword_tokenizer_type (`str`, *optional*, defaults to `"wordpiece"`):
-            Type of subword tokenizer. Choose from ["wordpiece", "character"].
+            Type of subword tokenizer. Choose from ["wordpiece", "character", "sentencepiece",].
         mecab_kwargs (`dict`, *optional*):
             Dictionary passed to the `MecabTokenizer` constructor.
         sudachi_kwargs (`dict`, *optional*):
@@ -133,6 +140,7 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
     def __init__(
         self,
         vocab_file,
+        spm_file=None,
         do_lower_case=False,
         do_word_tokenize=True,
         do_subword_tokenize=True,
@@ -150,6 +158,7 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
         **kwargs
     ):
         super().__init__(
+            spm_file=spm_file,
             unk_token=unk_token,
             sep_token=sep_token,
             pad_token=pad_token,
@@ -167,13 +176,21 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
             **kwargs,
         )
 
-        if not os.path.isfile(vocab_file):
-            raise ValueError(
-                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
-                " model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
-            )
-        self.vocab = load_vocab(vocab_file)
-        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
+        if subword_tokenizer_type == "sentencepiece":
+            if not os.path.isfile(spm_file):
+                raise ValueError(
+                    f"Can't find a vocabulary file at path '{spm_file}'. To load the vocabulary from a Google"
+                    " pretrained model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+                )
+            self.spm_file = spm_file
+        else:
+            if not os.path.isfile(vocab_file):
+                raise ValueError(
+                    f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google"
+                    " pretrained model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+                )
+            self.vocab = load_vocab(vocab_file)
+            self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
 
         self.do_word_tokenize = do_word_tokenize
         self.word_tokenizer_type = word_tokenizer_type
@@ -209,6 +226,8 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
                 self.subword_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
             elif subword_tokenizer_type == "character":
                 self.subword_tokenizer = CharacterTokenizer(vocab=self.vocab, unk_token=self.unk_token)
+            elif subword_tokenizer_type == "sentencepiece":
+                self.subword_tokenizer = SentencepieceTokenizer(vocab=self.spm_file, unk_token=self.unk_token)
             else:
                 raise ValueError(f"Invalid subword_tokenizer_type '{subword_tokenizer_type}' is specified.")
 
@@ -251,27 +270,34 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
         return split_tokens
 
     @property
-    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.vocab_size
     def vocab_size(self):
+        if self.subword_tokenizer_type == "sentencepiece":
+            return len(self.subword_tokenizer.sp_model)
         return len(self.vocab)
 
-    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_vocab
     def get_vocab(self):
+        if self.subword_tokenizer_type == "sentencepiece":
+            vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+            vocab.update(self.added_tokens_encoder)
+            return vocab
         return dict(self.vocab, **self.added_tokens_encoder)
 
-    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_token_to_id
     def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
+        if self.subword_tokenizer_type == "sentencepiece":
+            return self.subword_tokenizer.sp_model.PieceToId(token)
        return self.vocab.get(token, self.vocab.get(self.unk_token))
 
-    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_id_to_token
     def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
+        if self.subword_tokenizer_type == "sentencepiece":
+            return self.subword_tokenizer.sp_model.IdToPiece(index)
        return self.ids_to_tokens.get(index, self.unk_token)
 
-    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.convert_tokens_to_string
     def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
+        if self.subword_tokenizer_type == "sentencepiece":
+            return self.subword_tokenizer.sp_model.decode(tokens)
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string
 
@@ -360,25 +386,36 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
             return len(cls + token_ids_0 + sep) * [0]
         return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
 
-    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.save_vocabulary
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
-        index = 0
         if os.path.isdir(save_directory):
-            vocab_file = os.path.join(
-                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
-            )
+            if self.subword_tokenizer_type == "sentencepiece":
+                vocab_file = os.path.join(
+                    save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["spm_file"]
+                )
+            else:
+                vocab_file = os.path.join(
+                    save_directory,
+                    (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"],
+                )
         else:
             vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
-        with open(vocab_file, "w", encoding="utf-8") as writer:
-            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
-                if index != token_index:
-                    logger.warning(
-                        f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
-                        " Please check that the vocabulary is not corrupted!"
-                    )
-                    index = token_index
-                writer.write(token + "\n")
-                index += 1
+
+        if self.subword_tokenizer_type == "sentencepiece":
+            with open(vocab_file, "wb") as writer:
+                content_spiece_model = self.subword_tokenizer.sp_model.serialized_model_proto()
+                writer.write(content_spiece_model)
+        else:
+            with open(vocab_file, "w", encoding="utf-8") as writer:
+                index = 0
+                for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+                    if index != token_index:
+                        logger.warning(
+                            f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
+                            " Please check that the vocabulary is not corrupted!"
+                        )
+                        index = token_index
+                    writer.write(token + "\n")
+                    index += 1
         return (vocab_file,)
 
 
@@ -893,3 +930,72 @@ class WordpieceTokenizer(object):
             else:
                 output_tokens.extend(sub_tokens)
         return output_tokens
+
+
+class SentencepieceTokenizer(object):
+    """
+    Runs sentencepiece tokenization. Based on transformers.models.albert.tokenization_albert.AlbertTokenizer.
+    """
+
+    def __init__(
+        self,
+        vocab,
+        unk_token,
+        do_lower_case=False,
+        remove_space=True,
+        keep_accents=True,
+        sp_model_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        self.vocab = vocab
+        self.unk_token = unk_token
+        self.do_lower_case = do_lower_case
+        self.remove_space = remove_space
+        self.keep_accents = keep_accents
+
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(self.vocab)
+
+    def preprocess_text(self, inputs):
+        if self.remove_space:
+            outputs = " ".join(inputs.strip().split())
+        else:
+            outputs = inputs
+        outputs = outputs.replace("``", '"').replace("''", '"')
+
+        if not self.keep_accents:
+            outputs = unicodedata.normalize("NFKD", outputs)
+            outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
+        if self.do_lower_case:
+            outputs = outputs.lower()
+
+        return outputs
+
+    def tokenize(self, text):
+        """
+        Tokenizes text by sentencepiece. Based on [SentencePiece](https://github.com/google/sentencepiece).
+        Tokenization needs the given vocabulary.
+
+        Args:
+            text: A string needs to be tokenized.
+
+        Returns:
+            A list of sentencepiece tokens.
+        """
+        text = self.preprocess_text(text)
+        pieces = self.sp_model.encode(text, out_type=str)
+        new_pieces = []
+        for piece in pieces:
+            if len(piece) > 1 and piece[-1] == str(",") and piece[-2].isdigit():
+                cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
+                if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
+                    if len(cur_pieces[0]) == 1:
+                        cur_pieces = cur_pieces[1:]
+                    else:
+                        cur_pieces[0] = cur_pieces[0][1:]
+                cur_pieces.append(piece[-1])
+                new_pieces.extend(cur_pieces)
+            else:
+                new_pieces.append(piece)
+
+        return new_pieces
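A hedged sketch of what the save_vocabulary change above means in practice (the directory name is made up; reloading relies on save_pretrained also writing the tokenizer config, as it normally does):

    # With subword_tokenizer_type="sentencepiece", save_vocabulary() now serializes the
    # SentencePiece model to <save_dir>/spiece.model instead of writing a vocab.txt list.
    tokenizer.save_pretrained("bert-japanese-spm")
    reloaded = BertJapaneseTokenizer.from_pretrained("bert-japanese-spm")
    assert reloaded.subword_tokenizer_type == "sentencepiece"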
test_tokenization_bert_japanese.py

@@ -334,6 +334,16 @@ class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 
         self.assertListEqual(tokenizer.tokenize("こんばんは こんばんにちは こんにちは"), ["こん", "##ばんは", "[UNK]", "こんにちは"])
 
+    def test_sentencepiece_tokenizer(self):
+        tokenizer = BertJapaneseTokenizer.from_pretrained("nlp-waseda/roberta-base-japanese-with-auto-jumanpp")
+        subword_tokenizer = tokenizer.subword_tokenizer
+
+        tokens = subword_tokenizer.tokenize("国境 の 長い トンネル を 抜ける と 雪国 であった 。")
+        self.assertListEqual(tokens, ["▁国境", "▁の", "▁長い", "▁トンネル", "▁を", "▁抜ける", "▁と", "▁雪", "国", "▁であった", "▁。"])
+
+        tokens = subword_tokenizer.tokenize("こんばんは こんばん にち は こんにちは")
+        self.assertListEqual(tokens, ["▁こん", "ばん", "は", "▁こん", "ばん", "▁に", "ち", "▁は", "▁こんにちは"])
+
     def test_sequence_builders(self):
         tokenizer = self.tokenizer_class.from_pretrained("cl-tohoku/bert-base-japanese")
 
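A note on the expected tokens in the test: the leading "▁" (SPIECE_UNDERLINE above) is SentencePiece's whitespace marker, and convert_tokens_to_string removes it by delegating to sp_model.decode, roughly:

    pieces = ["▁国境", "▁の", "▁長い"]
    text = tokenizer.subword_tokenizer.sp_model.decode(pieces)  # approximately "国境 の 長い"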