fix BertTokenizerFast `tokenize_chinese_chars` arg (#15158)
* add new test
* fix in init
* more relevant test
This commit is contained in:
parent 4aa16fce6c
commit 51d7ebf260
@@ -188,15 +188,17 @@ class BertTokenizerFast(PreTrainedTokenizerFast):
             **kwargs,
         )
 
-        pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
+        normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
         if (
-            pre_tok_state.get("lowercase", do_lower_case) != do_lower_case
-            or pre_tok_state.get("strip_accents", strip_accents) != strip_accents
+            normalizer_state.get("lowercase", do_lower_case) != do_lower_case
+            or normalizer_state.get("strip_accents", strip_accents) != strip_accents
+            or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars
         ):
-            pre_tok_class = getattr(normalizers, pre_tok_state.pop("type"))
-            pre_tok_state["lowercase"] = do_lower_case
-            pre_tok_state["strip_accents"] = strip_accents
-            self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state)
+            normalizer_class = getattr(normalizers, normalizer_state.pop("type"))
+            normalizer_state["lowercase"] = do_lower_case
+            normalizer_state["strip_accents"] = strip_accents
+            normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars
+            self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state)
 
         self.do_lower_case = do_lower_case
 
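For context, `__getstate__()` above returns the backend normalizer serialized as JSON, and the fix extends the "did the caller override anything?" check to cover `handle_chinese_chars`. Below is a minimal standalone sketch of that mechanism, assuming only the tokenizers library; the one-entry WordPiece vocab is purely illustrative.

import json

from tokenizers import Tokenizer, normalizers
from tokenizers.models import WordPiece

# Stand-in for the backend tokenizer that BertTokenizerFast wraps.
backend = Tokenizer(WordPiece({"[UNK]": 0}, unk_token="[UNK]"))
backend.normalizer = normalizers.BertNormalizer(handle_chinese_chars=True)

# The serialized state looks like:
# {"type": "BertNormalizer", "clean_text": true, "handle_chinese_chars": true, ...}
normalizer_state = json.loads(backend.normalizer.__getstate__())

# The fix: if the requested flag disagrees with the serialized one, rebuild
# the normalizer with the requested value instead of silently keeping the old
# setting (previously only lowercase/strip_accents triggered a rebuild).
tokenize_chinese_chars = False
if normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars:
    normalizer_class = getattr(normalizers, normalizer_state.pop("type"))
    normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars
    backend.normalizer = normalizer_class(**normalizer_state)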
@@ -299,3 +299,40 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                     [e[1] for e in expected_results], tokenizer_r.convert_ids_to_tokens(tokens["input_ids"])
                 )
                 self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"])
+
+    def test_change_tokenize_chinese_chars(self):
+        list_of_commun_chinese_char = ["的", "人", "有"]
+        text_with_chinese_char = "".join(list_of_commun_chinese_char)
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+
+                kwargs["tokenize_chinese_chars"] = True
+                tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+                tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+                ids_without_spe_char_p = tokenizer_p.encode(text_with_chinese_char, add_special_tokens=False)
+                ids_without_spe_char_r = tokenizer_r.encode(text_with_chinese_char, add_special_tokens=False)
+
+                tokens_without_spe_char_r = tokenizer_r.convert_ids_to_tokens(ids_without_spe_char_r)
+                tokens_without_spe_char_p = tokenizer_p.convert_ids_to_tokens(ids_without_spe_char_p)
+
+                # it is expected that each Chinese character is not preceded by "##"
+                self.assertListEqual(tokens_without_spe_char_p, list_of_commun_chinese_char)
+                self.assertListEqual(tokens_without_spe_char_r, list_of_commun_chinese_char)
+
+                kwargs["tokenize_chinese_chars"] = False
+                tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+                tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+                ids_without_spe_char_r = tokenizer_r.encode(text_with_chinese_char, add_special_tokens=False)
+                ids_without_spe_char_p = tokenizer_p.encode(text_with_chinese_char, add_special_tokens=False)
+
+                tokens_without_spe_char_r = tokenizer_r.convert_ids_to_tokens(ids_without_spe_char_r)
+                tokens_without_spe_char_p = tokenizer_p.convert_ids_to_tokens(ids_without_spe_char_p)
+
+                # it is expected that only the first Chinese character is not preceded by "##".
+                expected_tokens = [
+                    f"##{token}" if idx != 0 else token for idx, token in enumerate(list_of_commun_chinese_char)
+                ]
+                self.assertListEqual(tokens_without_spe_char_p, expected_tokens)
+                self.assertListEqual(tokens_without_spe_char_r, expected_tokens)
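The new test asserts the user-visible effect. A quick manual check along the same lines, assuming a checkpoint whose vocab contains these characters (bert-base-uncased is used here only as an example):

from transformers import BertTokenizerFast

# With the fix, the flag passed at load time reaches the backend normalizer.
tok = BertTokenizerFast.from_pretrained("bert-base-uncased", tokenize_chinese_chars=False)
ids = tok.encode("的人有", add_special_tokens=False)
# Chinese characters are no longer space-padded, so WordPiece sees one word:
# expected along the lines of ['的', '##人', '##有'] rather than ['的', '人', '有'].
print(tok.convert_ids_to_tokens(ids))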