Supports already existing special tokens
This commit is contained in:
parent
2f259b228e
commit
cc412edd42
|
@ -321,4 +321,16 @@ class CommonTestCases:
|
||||||
filtered_sequence = [x for x in filtered_sequence if x is not None]
|
filtered_sequence = [x for x in filtered_sequence if x is not None]
|
||||||
assert encoded_sequence == filtered_sequence
|
assert encoded_sequence == filtered_sequence
|
||||||
|
|
||||||
|
# Testing with already existing special tokens
|
||||||
|
if tokenizer.cls_token_id == tokenizer.unk_token_id and tokenizer.cls_token_id == tokenizer.unk_token_id:
|
||||||
|
tokenizer.add_special_tokens({'cls_token': '</s>', 'sep_token': '<s>'})
|
||||||
|
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, add_special_tokens=True)
|
||||||
|
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
|
||||||
|
sequence_ids_orig = encoded_sequence_dict["sequence_ids"]
|
||||||
|
sequence_ids = tokenizer.get_sequence_ids(encoded_sequence_w_special, special_tokens_present=True)
|
||||||
|
assert len(sequence_ids) == len(encoded_sequence_w_special)
|
||||||
|
print(sequence_ids_orig, sequence_ids)
|
||||||
|
assert sequence_ids_orig == sequence_ids
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -204,7 +204,7 @@ class BertTokenizer(PreTrainedTokenizer):
|
||||||
|
|
||||||
return cls + token_ids_0 + sep + token_ids_1 + sep
|
return cls + token_ids_0 + sep + token_ids_1 + sep
|
||||||
|
|
||||||
def get_sequence_ids(self, token_ids_0, token_ids_1=None):
|
def get_sequence_ids(self, token_ids_0, token_ids_1=None, special_tokens_present=False):
|
||||||
"""
|
"""
|
||||||
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||||
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
|
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
|
||||||
|
@ -217,6 +217,10 @@ class BertTokenizer(PreTrainedTokenizer):
|
||||||
Returns:
|
Returns:
|
||||||
A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
|
A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if special_tokens_present:
|
||||||
|
return list(map(lambda x: 0 if x in [self.sep_token_id, self.cls_token_id] else 1, token_ids_0))
|
||||||
|
|
||||||
if token_ids_1:
|
if token_ids_1:
|
||||||
return [0] + ([1] * len(token_ids_0)) + [0] + ([1] * len(token_ids_1)) + [0]
|
return [0] + ([1] * len(token_ids_0)) + [0] + ([1] * len(token_ids_1)) + [0]
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -100,7 +100,7 @@ class RobertaTokenizer(GPT2Tokenizer):
|
||||||
cls = [self.cls_token_id]
|
cls = [self.cls_token_id]
|
||||||
return cls + token_ids_0 + sep + sep + token_ids_1 + sep
|
return cls + token_ids_0 + sep + sep + token_ids_1 + sep
|
||||||
|
|
||||||
def get_sequence_ids(self, token_ids_0, token_ids_1=None):
|
def get_sequence_ids(self, token_ids_0, token_ids_1=None, special_tokens_present=False):
|
||||||
"""
|
"""
|
||||||
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||||
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
|
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
|
||||||
|
@ -113,6 +113,10 @@ class RobertaTokenizer(GPT2Tokenizer):
|
||||||
Returns:
|
Returns:
|
||||||
A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
|
A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if special_tokens_present:
|
||||||
|
return list(map(lambda x: 0 if x in [self.sep_token_id, self.cls_token_id] else 1, token_ids_0))
|
||||||
|
|
||||||
if token_ids_1:
|
if token_ids_1:
|
||||||
return [0] + ([1] * len(token_ids_0)) + [0, 0] + ([1] * len(token_ids_1)) + [0]
|
return [0] + ([1] * len(token_ids_0)) + [0, 0] + ([1] * len(token_ids_1)) + [0]
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -908,7 +908,7 @@ class PreTrainedTokenizer(object):
|
||||||
logger.warning("This tokenizer does not make use of special tokens. The two sequences have been concatenated.")
|
logger.warning("This tokenizer does not make use of special tokens. The two sequences have been concatenated.")
|
||||||
return token_ids_0 + token_ids_1
|
return token_ids_0 + token_ids_1
|
||||||
|
|
||||||
def get_sequence_ids(self, token_ids_0, token_ids_1=None):
|
def get_sequence_ids(self, token_ids_0, token_ids_1=None, special_tokens_present=False):
|
||||||
return [1] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
|
return [1] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
|
||||||
|
|
||||||
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
||||||
|
|
|
@ -770,7 +770,7 @@ class XLMTokenizer(PreTrainedTokenizer):
|
||||||
cls = [self.cls_token_id]
|
cls = [self.cls_token_id]
|
||||||
return cls + token_ids_0 + sep + token_ids_1 + sep
|
return cls + token_ids_0 + sep + token_ids_1 + sep
|
||||||
|
|
||||||
def get_sequence_ids(self, token_ids_0, token_ids_1=None):
|
def get_sequence_ids(self, token_ids_0, token_ids_1=None, special_tokens_present=False):
|
||||||
"""
|
"""
|
||||||
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||||
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
|
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
|
||||||
|
@ -783,6 +783,10 @@ class XLMTokenizer(PreTrainedTokenizer):
|
||||||
Returns:
|
Returns:
|
||||||
A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
|
A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if special_tokens_present:
|
||||||
|
return list(map(lambda x: 0 if x in [self.sep_token_id, self.cls_token_id] else 1, token_ids_0))
|
||||||
|
|
||||||
if token_ids_1:
|
if token_ids_1:
|
||||||
return [0] + ([1] * len(token_ids_0)) + [0] + ([1] * len(token_ids_1)) + [0]
|
return [0] + ([1] * len(token_ids_0)) + [0] + ([1] * len(token_ids_1)) + [0]
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -200,7 +200,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
|
||||||
cls = [self.cls_token_id]
|
cls = [self.cls_token_id]
|
||||||
return token_ids_0 + sep + token_ids_1 + sep + cls
|
return token_ids_0 + sep + token_ids_1 + sep + cls
|
||||||
|
|
||||||
def get_sequence_ids(self, token_ids_0, token_ids_1=None):
|
def get_sequence_ids(self, token_ids_0, token_ids_1=None, special_tokens_present=False):
|
||||||
"""
|
"""
|
||||||
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||||
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
|
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
|
||||||
|
@ -213,6 +213,10 @@ class XLNetTokenizer(PreTrainedTokenizer):
|
||||||
Returns:
|
Returns:
|
||||||
A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
|
A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if special_tokens_present:
|
||||||
|
return list(map(lambda x: 0 if x in [self.sep_token_id, self.cls_token_id] else 1, token_ids_0))
|
||||||
|
|
||||||
if token_ids_1:
|
if token_ids_1:
|
||||||
return ([1] * len(token_ids_0)) + [0] + ([1] * len(token_ids_1)) + [0, 0]
|
return ([1] * len(token_ids_0)) + [0] + ([1] * len(token_ids_1)) + [0, 0]
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue