Add sentencepiece to BertJapaneseTokenizer (#19769)
* support sentencepiece for bertjapanesetokenizer * add test vocab file for sentencepiece, bertjapanesetokenizer * make BasicTokenizer be identical to transformers.models.bert.tokenization_bert.BasicTokenizer * fix missing of \n in comment * fix init argument missing in tests * make spm_file be optional, exclude spiece.model from tests/fixtures, and add description comments * make comment length less than 119 * apply doc style check
This commit is contained in:
parent
2ebf4e6a7b
commit
31565ff0fd
|
@ -19,7 +19,9 @@ import collections
|
|||
import copy
|
||||
import os
|
||||
import unicodedata
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import sentencepiece as spm
|
||||
|
||||
from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
|
||||
from ...utils import logging
|
||||
|
@ -27,7 +29,9 @@ from ...utils import logging
|
|||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "spm_file": "spiece.model"}
|
||||
|
||||
SPIECE_UNDERLINE = "▁"
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
"vocab_file": {
|
||||
|
@ -107,6 +111,9 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
|
|||
Args:
|
||||
vocab_file (`str`):
|
||||
Path to a one-wordpiece-per-line vocabulary file.
|
||||
spm_file (`str`, *optional*):
|
||||
Path to [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .spm or .model
|
||||
extension) that contains the vocabulary.
|
||||
do_lower_case (`bool`, *optional*, defaults to `True`):
|
||||
Whether to lower case the input. Only has an effect when do_basic_tokenize=True.
|
||||
do_word_tokenize (`bool`, *optional*, defaults to `True`):
|
||||
|
@ -116,7 +123,7 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
|
|||
word_tokenizer_type (`str`, *optional*, defaults to `"basic"`):
|
||||
Type of word tokenizer. Choose from ["basic", "mecab", "sudachi", "jumanpp"].
|
||||
subword_tokenizer_type (`str`, *optional*, defaults to `"wordpiece"`):
|
||||
Type of subword tokenizer. Choose from ["wordpiece", "character"].
|
||||
Type of subword tokenizer. Choose from ["wordpiece", "character", "sentencepiece",].
|
||||
mecab_kwargs (`dict`, *optional*):
|
||||
Dictionary passed to the `MecabTokenizer` constructor.
|
||||
sudachi_kwargs (`dict`, *optional*):
|
||||
|
@ -133,6 +140,7 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
|
|||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
spm_file=None,
|
||||
do_lower_case=False,
|
||||
do_word_tokenize=True,
|
||||
do_subword_tokenize=True,
|
||||
|
@ -150,6 +158,7 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
|
|||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
spm_file=spm_file,
|
||||
unk_token=unk_token,
|
||||
sep_token=sep_token,
|
||||
pad_token=pad_token,
|
||||
|
@ -167,13 +176,21 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
if not os.path.isfile(vocab_file):
|
||||
raise ValueError(
|
||||
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
|
||||
" model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
||||
)
|
||||
self.vocab = load_vocab(vocab_file)
|
||||
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
|
||||
if subword_tokenizer_type == "sentencepiece":
|
||||
if not os.path.isfile(spm_file):
|
||||
raise ValueError(
|
||||
f"Can't find a vocabulary file at path '{spm_file}'. To load the vocabulary from a Google"
|
||||
" pretrained model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
||||
)
|
||||
self.spm_file = spm_file
|
||||
else:
|
||||
if not os.path.isfile(vocab_file):
|
||||
raise ValueError(
|
||||
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google"
|
||||
" pretrained model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
||||
)
|
||||
self.vocab = load_vocab(vocab_file)
|
||||
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
|
||||
|
||||
self.do_word_tokenize = do_word_tokenize
|
||||
self.word_tokenizer_type = word_tokenizer_type
|
||||
|
@ -209,6 +226,8 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
|
|||
self.subword_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
|
||||
elif subword_tokenizer_type == "character":
|
||||
self.subword_tokenizer = CharacterTokenizer(vocab=self.vocab, unk_token=self.unk_token)
|
||||
elif subword_tokenizer_type == "sentencepiece":
|
||||
self.subword_tokenizer = SentencepieceTokenizer(vocab=self.spm_file, unk_token=self.unk_token)
|
||||
else:
|
||||
raise ValueError(f"Invalid subword_tokenizer_type '{subword_tokenizer_type}' is specified.")
|
||||
|
||||
|
@ -251,27 +270,34 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
|
|||
return split_tokens
|
||||
|
||||
@property
|
||||
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer.vocab_size
|
||||
def vocab_size(self):
|
||||
if self.subword_tokenizer_type == "sentencepiece":
|
||||
return len(self.subword_tokenizer.sp_model)
|
||||
return len(self.vocab)
|
||||
|
||||
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_vocab
|
||||
def get_vocab(self):
|
||||
if self.subword_tokenizer_type == "sentencepiece":
|
||||
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
||||
vocab.update(self.added_tokens_encoder)
|
||||
return vocab
|
||||
return dict(self.vocab, **self.added_tokens_encoder)
|
||||
|
||||
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_token_to_id
|
||||
def _convert_token_to_id(self, token):
|
||||
"""Converts a token (str) in an id using the vocab."""
|
||||
if self.subword_tokenizer_type == "sentencepiece":
|
||||
return self.subword_tokenizer.sp_model.PieceToId(token)
|
||||
return self.vocab.get(token, self.vocab.get(self.unk_token))
|
||||
|
||||
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_id_to_token
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
if self.subword_tokenizer_type == "sentencepiece":
|
||||
return self.subword_tokenizer.sp_model.IdToPiece(index)
|
||||
return self.ids_to_tokens.get(index, self.unk_token)
|
||||
|
||||
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer.convert_tokens_to_string
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
"""Converts a sequence of tokens (string) in a single string."""
|
||||
if self.subword_tokenizer_type == "sentencepiece":
|
||||
return self.subword_tokenizer.sp_model.decode(tokens)
|
||||
out_string = " ".join(tokens).replace(" ##", "").strip()
|
||||
return out_string
|
||||
|
||||
|
@ -360,25 +386,36 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
|
|||
return len(cls + token_ids_0 + sep) * [0]
|
||||
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
||||
|
||||
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer.save_vocabulary
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
index = 0
|
||||
if os.path.isdir(save_directory):
|
||||
vocab_file = os.path.join(
|
||||
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
||||
)
|
||||
if self.subword_tokenizer_type == "sentencepiece":
|
||||
vocab_file = os.path.join(
|
||||
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["spm_file"]
|
||||
)
|
||||
else:
|
||||
vocab_file = os.path.join(
|
||||
save_directory,
|
||||
(filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"],
|
||||
)
|
||||
else:
|
||||
vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
|
||||
with open(vocab_file, "w", encoding="utf-8") as writer:
|
||||
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
||||
if index != token_index:
|
||||
logger.warning(
|
||||
f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
|
||||
" Please check that the vocabulary is not corrupted!"
|
||||
)
|
||||
index = token_index
|
||||
writer.write(token + "\n")
|
||||
index += 1
|
||||
|
||||
if self.subword_tokenizer_type == "sentencepiece":
|
||||
with open(vocab_file, "wb") as writer:
|
||||
content_spiece_model = self.subword_tokenizer.sp_model.serialized_model_proto()
|
||||
writer.write(content_spiece_model)
|
||||
else:
|
||||
with open(vocab_file, "w", encoding="utf-8") as writer:
|
||||
index = 0
|
||||
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
||||
if index != token_index:
|
||||
logger.warning(
|
||||
f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
|
||||
" Please check that the vocabulary is not corrupted!"
|
||||
)
|
||||
index = token_index
|
||||
writer.write(token + "\n")
|
||||
index += 1
|
||||
return (vocab_file,)
|
||||
|
||||
|
||||
|
@ -893,3 +930,72 @@ class WordpieceTokenizer(object):
|
|||
else:
|
||||
output_tokens.extend(sub_tokens)
|
||||
return output_tokens
|
||||
|
||||
|
||||
class SentencepieceTokenizer(object):
|
||||
"""
|
||||
Runs sentencepiece tokenization. Based on transformers.models.albert.tokenization_albert.AlbertTokenizer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab,
|
||||
unk_token,
|
||||
do_lower_case=False,
|
||||
remove_space=True,
|
||||
keep_accents=True,
|
||||
sp_model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
self.vocab = vocab
|
||||
self.unk_token = unk_token
|
||||
self.do_lower_case = do_lower_case
|
||||
self.remove_space = remove_space
|
||||
self.keep_accents = keep_accents
|
||||
|
||||
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
|
||||
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
||||
self.sp_model.Load(self.vocab)
|
||||
|
||||
def preprocess_text(self, inputs):
|
||||
if self.remove_space:
|
||||
outputs = " ".join(inputs.strip().split())
|
||||
else:
|
||||
outputs = inputs
|
||||
outputs = outputs.replace("``", '"').replace("''", '"')
|
||||
|
||||
if not self.keep_accents:
|
||||
outputs = unicodedata.normalize("NFKD", outputs)
|
||||
outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
|
||||
if self.do_lower_case:
|
||||
outputs = outputs.lower()
|
||||
|
||||
return outputs
|
||||
|
||||
def tokenize(self, text):
|
||||
"""
|
||||
Tokenizes text by sentencepiece. Based on [SentencePiece](https://github.com/google/sentencepiece).
|
||||
Tokenization needs the given vocabulary.
|
||||
|
||||
Args:
|
||||
text: A string needs to be tokenized.
|
||||
|
||||
Returns:
|
||||
A list of sentencepiece tokens.
|
||||
"""
|
||||
text = self.preprocess_text(text)
|
||||
pieces = self.sp_model.encode(text, out_type=str)
|
||||
new_pieces = []
|
||||
for piece in pieces:
|
||||
if len(piece) > 1 and piece[-1] == str(",") and piece[-2].isdigit():
|
||||
cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
|
||||
if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
|
||||
if len(cur_pieces[0]) == 1:
|
||||
cur_pieces = cur_pieces[1:]
|
||||
else:
|
||||
cur_pieces[0] = cur_pieces[0][1:]
|
||||
cur_pieces.append(piece[-1])
|
||||
new_pieces.extend(cur_pieces)
|
||||
else:
|
||||
new_pieces.append(piece)
|
||||
|
||||
return new_pieces
|
||||
|
|
|
@ -334,6 +334,16 @@ class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
|
||||
self.assertListEqual(tokenizer.tokenize("こんばんは こんばんにちは こんにちは"), ["こん", "##ばんは", "[UNK]", "こんにちは"])
|
||||
|
||||
def test_sentencepiece_tokenizer(self):
|
||||
tokenizer = BertJapaneseTokenizer.from_pretrained("nlp-waseda/roberta-base-japanese-with-auto-jumanpp")
|
||||
subword_tokenizer = tokenizer.subword_tokenizer
|
||||
|
||||
tokens = subword_tokenizer.tokenize("国境 の 長い トンネル を 抜ける と 雪国 であった 。")
|
||||
self.assertListEqual(tokens, ["▁国境", "▁の", "▁長い", "▁トンネル", "▁を", "▁抜ける", "▁と", "▁雪", "国", "▁であった", "▁。"])
|
||||
|
||||
tokens = subword_tokenizer.tokenize("こんばんは こんばん にち は こんにちは")
|
||||
self.assertListEqual(tokens, ["▁こん", "ばん", "は", "▁こん", "ばん", "▁に", "ち", "▁は", "▁こんにちは"])
|
||||
|
||||
def test_sequence_builders(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("cl-tohoku/bert-base-japanese")
|
||||
|
||||
|
|
Loading…
Reference in New Issue