Supports already existing special tokens

LysandreJik 2019-09-30 14:11:41 -04:00
parent 2f259b228e
commit cc412edd42
6 changed files with 33 additions and 5 deletions

@@ -321,4 +321,16 @@ class CommonTestCases:
         filtered_sequence = [x for x in filtered_sequence if x is not None]
         assert encoded_sequence == filtered_sequence
+
+        # Testing with already existing special tokens
+        if tokenizer.cls_token_id == tokenizer.unk_token_id and tokenizer.sep_token_id == tokenizer.unk_token_id:
+            tokenizer.add_special_tokens({'cls_token': '</s>', 'sep_token': '<s>'})
+        encoded_sequence_dict = tokenizer.encode_plus(sequence_0, add_special_tokens=True)
+        encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
+        sequence_ids_orig = encoded_sequence_dict["sequence_ids"]
+        sequence_ids = tokenizer.get_sequence_ids(encoded_sequence_w_special, special_tokens_present=True)
+        assert len(sequence_ids) == len(encoded_sequence_w_special)
+        print(sequence_ids_orig, sequence_ids)
+        assert sequence_ids_orig == sequence_ids
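End to end, the new flag lets a caller recover the mask from a sequence that already carries its special tokens, rather than from raw ids. A minimal usage sketch against this branch, assuming a BERT checkpoint (the sentence and variable names are illustrative, not taken from the diff):

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # Encode with special tokens, so [CLS] and [SEP] are already in place.
    ids = tokenizer.encode_plus("Hello world", add_special_tokens=True)["input_ids"]

    # Recover the mask from the finished sequence.
    mask = tokenizer.get_sequence_ids(ids, special_tokens_present=True)
    assert len(mask) == len(ids)
    assert mask[0] == 0 and mask[-1] == 0   # leading [CLS], trailing [SEP]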

@@ -204,7 +204,7 @@ class BertTokenizer(PreTrainedTokenizer):
         return cls + token_ids_0 + sep + token_ids_1 + sep
 
-    def get_sequence_ids(self, token_ids_0, token_ids_1=None):
+    def get_sequence_ids(self, token_ids_0, token_ids_1=None, special_tokens_present=False):
         """
         Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
         special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
@@ -217,6 +217,10 @@ class BertTokenizer(PreTrainedTokenizer):
         Returns:
             A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
         """
+        if special_tokens_present:
+            return list(map(lambda x: 0 if x in [self.sep_token_id, self.cls_token_id] else 1, token_ids_0))
+
         if token_ids_1:
             return [0] + ([1] * len(token_ids_0)) + [0] + ([1] * len(token_ids_1)) + [0]
         else:
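The new `special_tokens_present` branch is a plain membership test over the finished ids; stripped of `self`, it reduces to the comprehension below (ids are made up for illustration, not BERT's real vocabulary):

    cls_token_id, sep_token_id = 2, 0                 # illustrative ids only
    token_ids = [2, 11, 12, 0, 21, 22, 0]             # cls A sep B sep, already built
    mask = [0 if t in (sep_token_id, cls_token_id) else 1 for t in token_ids]
    assert mask == [0, 1, 1, 0, 1, 1, 0]

Note that only `sep` and `cls` are tested, so padding or other added tokens would still come back as sequence tokens.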

@@ -100,7 +100,7 @@ class RobertaTokenizer(GPT2Tokenizer):
         cls = [self.cls_token_id]
         return cls + token_ids_0 + sep + sep + token_ids_1 + sep
 
-    def get_sequence_ids(self, token_ids_0, token_ids_1=None):
+    def get_sequence_ids(self, token_ids_0, token_ids_1=None, special_tokens_present=False):
         """
         Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
         special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
@@ -113,6 +113,10 @@ class RobertaTokenizer(GPT2Tokenizer):
         Returns:
             A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
         """
+        if special_tokens_present:
+            return list(map(lambda x: 0 if x in [self.sep_token_id, self.cls_token_id] else 1, token_ids_0))
+
         if token_ids_1:
             return [0] + ([1] * len(token_ids_0)) + [0, 0] + ([1] * len(token_ids_1)) + [0]
         else:
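RoBERTa joins pairs as `<s> A </s></s> B </s>`, which is where the back-to-back zeros in the middle of the mask come from. The resulting shape, with illustrative lengths:

    a, b = [11, 12], [21, 22, 23]                     # fake content ids
    pair = [0] + [1] * len(a) + [0, 0] + [1] * len(b) + [0]
    assert pair == [0, 1, 1, 0, 0, 1, 1, 1, 0]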

@@ -908,7 +908,7 @@ class PreTrainedTokenizer(object):
             logger.warning("This tokenizer does not make use of special tokens. The two sequences have been concatenated.")
             return token_ids_0 + token_ids_1
 
-    def get_sequence_ids(self, token_ids_0, token_ids_1=None):
+    def get_sequence_ids(self, token_ids_0, token_ids_1=None, special_tokens_present=False):
         return [1] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
 
     def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
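The base-class fallback assumes the tokenizer adds no special tokens at all, so every position is reported as a sequence token:

    token_ids_0, token_ids_1 = [11, 12, 13], [21, 22]  # fake ids
    mask = [1] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
    assert mask == [1, 1, 1, 1, 1]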

@@ -770,7 +770,7 @@ class XLMTokenizer(PreTrainedTokenizer):
         cls = [self.cls_token_id]
         return cls + token_ids_0 + sep + token_ids_1 + sep
 
-    def get_sequence_ids(self, token_ids_0, token_ids_1=None):
+    def get_sequence_ids(self, token_ids_0, token_ids_1=None, special_tokens_present=False):
         """
         Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
         special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
@@ -783,6 +783,10 @@ class XLMTokenizer(PreTrainedTokenizer):
         Returns:
             A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
         """
+        if special_tokens_present:
+            return list(map(lambda x: 0 if x in [self.sep_token_id, self.cls_token_id] else 1, token_ids_0))
+
         if token_ids_1:
             return [0] + ([1] * len(token_ids_0)) + [0] + ([1] * len(token_ids_1)) + [0]
         else:
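XLM wraps pairs at the same positions BERT does (`cls + A + sep + B + sep`, per the context above), so its mask keeps the single-zero separator shape:

    a, b = [11, 12], [21, 22, 23]                     # fake content ids
    pair = [0] + [1] * len(a) + [0] + [1] * len(b) + [0]
    assert pair == [0, 1, 1, 0, 1, 1, 1, 0]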

@@ -200,7 +200,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
         cls = [self.cls_token_id]
         return token_ids_0 + sep + token_ids_1 + sep + cls
 
-    def get_sequence_ids(self, token_ids_0, token_ids_1=None):
+    def get_sequence_ids(self, token_ids_0, token_ids_1=None, special_tokens_present=False):
         """
         Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
         special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
@@ -213,6 +213,10 @@ class XLNetTokenizer(PreTrainedTokenizer):
         Returns:
             A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
         """
+        if special_tokens_present:
+            return list(map(lambda x: 0 if x in [self.sep_token_id, self.cls_token_id] else 1, token_ids_0))
+
         if token_ids_1:
             return ([1] * len(token_ids_0)) + [0] + ([1] * len(token_ids_1)) + [0, 0]
         else:
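XLNet is the odd one out: its special tokens trail the content (`A + sep + B + sep + cls`), so the mask starts with ones and ends with two zeros:

    a, b = [11, 12], [21, 22, 23]                     # fake content ids
    pair = [1] * len(a) + [0] + [1] * len(b) + [0, 0]
    assert pair == [1, 1, 0, 1, 1, 1, 0, 0]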