From 6c1d0bc0665ef01710db301fb1a0a3c23778714a Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 4 Oct 2019 17:38:38 -0400 Subject: [PATCH] update encode_plus - add truncation strategies --- examples/run_lm_finetuning.py | 8 +- transformers/tests/tokenization_bert_test.py | 4 +- .../tests/tokenization_distilbert_test.py | 4 +- .../tests/tokenization_roberta_test.py | 4 +- .../tests/tokenization_tests_commons.py | 51 +++--- transformers/tests/tokenization_xlm_test.py | 4 +- transformers/tests/tokenization_xlnet_test.py | 4 +- transformers/tokenization_bert.py | 48 +++--- transformers/tokenization_roberta.py | 49 +++--- transformers/tokenization_utils.py | 147 ++++++++++-------- transformers/tokenization_xlm.py | 45 +++--- transformers/tokenization_xlnet.py | 47 +++--- 12 files changed, 222 insertions(+), 193 deletions(-) diff --git a/examples/run_lm_finetuning.py b/examples/run_lm_finetuning.py index 585bbc8d75..6c5e749868 100644 --- a/examples/run_lm_finetuning.py +++ b/examples/run_lm_finetuning.py @@ -75,7 +75,7 @@ class TextDataset(Dataset): tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)) for i in range(0, len(tokenized_text)-block_size+1, block_size): # Truncate in block of block_size - self.examples.append(tokenizer.add_special_tokens_single_sequence(tokenized_text[i:i+block_size])) + self.examples.append(tokenizer.build_inputs_with_special_tokens(tokenized_text[i:i+block_size])) # Note that we are loosing the last truncated example here for the sake of simplicity (no padding) # If your dataset is small, first you should loook for a bigger one :-) and second you # can change this behavior by adding (model specific) padding. @@ -109,10 +109,8 @@ def mask_tokens(inputs, tokenizer, args): labels = inputs.clone() # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa) probability_matrix = torch.full(labels.shape, args.mlm_probability) - probability_matrix *= torch.tensor( - [tokenizer.get_special_tokens_mask(val, special_tokens_present=True) for val in labels.tolist()], - dtype=torch.float - ) + special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()] + probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0) masked_indices = torch.bernoulli(probability_matrix).bool() labels[~masked_indices] = -1 # We only compute loss on masked tokens diff --git a/transformers/tests/tokenization_bert_test.py b/transformers/tests/tokenization_bert_test.py index b70941f884..5e49e2915b 100644 --- a/transformers/tests/tokenization_bert_test.py +++ b/transformers/tests/tokenization_bert_test.py @@ -131,8 +131,8 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): text = tokenizer.encode("sequence builders") text_2 = tokenizer.encode("multi-sequence build") - encoded_sentence = tokenizer.add_special_tokens_single_sequence(text) - encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2) + encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) + encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) assert encoded_sentence == [101] + text + [102] assert encoded_pair == [101] + text + [102] + text_2 + [102] diff --git a/transformers/tests/tokenization_distilbert_test.py b/transformers/tests/tokenization_distilbert_test.py index 64a88df99f..a18d644fe8 100644 --- a/transformers/tests/tokenization_distilbert_test.py +++ 
b/transformers/tests/tokenization_distilbert_test.py @@ -36,8 +36,8 @@ class DistilBertTokenizationTest(BertTokenizationTest): text = tokenizer.encode("sequence builders") text_2 = tokenizer.encode("multi-sequence build") - encoded_sentence = tokenizer.add_special_tokens_single_sequence(text) - encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2) + encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) + encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) assert encoded_sentence == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] assert encoded_pair == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] + \ diff --git a/transformers/tests/tokenization_roberta_test.py b/transformers/tests/tokenization_roberta_test.py index f14b26a2e4..a731ac26c9 100644 --- a/transformers/tests/tokenization_roberta_test.py +++ b/transformers/tests/tokenization_roberta_test.py @@ -87,8 +87,8 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True) encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True) - encoded_sentence = tokenizer.add_special_tokens_single_sequence(text) - encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2) + encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) + encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) assert encoded_sentence == encoded_text_from_decode assert encoded_pair == encoded_pair_from_decode diff --git a/transformers/tests/tokenization_tests_commons.py b/transformers/tests/tokenization_tests_commons.py index c902347fe5..b8f9295633 100644 --- a/transformers/tests/tokenization_tests_commons.py +++ b/transformers/tests/tokenization_tests_commons.py @@ -193,12 +193,12 @@ class CommonTestCases: tokenizer = self.get_tokenizer() - if tokenizer.add_special_tokens_sequence_pair.__qualname__.split('.')[0] != "PreTrainedTokenizer": + if tokenizer.build_inputs_with_special_tokens.__qualname__.split('.')[0] != "PreTrainedTokenizer": seq_0 = "Test this method." seq_1 = "With these inputs." information = tokenizer.encode_plus(seq_0, seq_1, add_special_tokens=True) sequences, mask = information["input_ids"], information["token_type_ids"] - assert len(sequences) == len(mask) + self.assertEqual(len(sequences), len(mask)) def test_number_of_added_tokens(self): tokenizer = self.get_tokenizer() @@ -211,7 +211,7 @@ class CommonTestCases: # Method is implemented (e.g. 
not GPT-2) if len(attached_sequences) != 2: - assert tokenizer.num_added_tokens(pair=True) == len(attached_sequences) - len(sequences) + self.assertEqual(tokenizer.num_added_tokens(pair=True), len(attached_sequences) - len(sequences)) def test_maximum_encoding_length_single_input(self): tokenizer = self.get_tokenizer() @@ -227,10 +227,10 @@ class CommonTestCases: truncated_sequence = information["input_ids"] overflowing_tokens = information["overflowing_tokens"] - assert len(overflowing_tokens) == 2 + stride - assert overflowing_tokens == sequence[-(2 + stride):] - assert len(truncated_sequence) == total_length - 2 - assert truncated_sequence == tokenizer.add_special_tokens_single_sequence(sequence[:-2]) + self.assertEqual(len(overflowing_tokens), 2 + stride) + self.assertEqual(overflowing_tokens, sequence[-(2 + stride):]) + self.assertEqual(len(truncated_sequence), total_length - 2) + self.assertEqual(truncated_sequence, tokenizer.build_inputs_with_special_tokens(sequence[:-2])) def test_maximum_encoding_length_pair_input(self): tokenizer = self.get_tokenizer() @@ -243,7 +243,7 @@ class CommonTestCases: sequence_1_no_special_tokens = tokenizer.encode(seq_1) sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True) - truncated_second_sequence = tokenizer.add_special_tokens_sequence_pair( + truncated_second_sequence = tokenizer.build_inputs_with_special_tokens( tokenizer.encode(seq_0), tokenizer.encode(seq_1)[:-2] ) @@ -258,11 +258,11 @@ class CommonTestCases: overflowing_tokens = information["overflowing_tokens"] overflowing_tokens_first_truncated = information_first_truncated["overflowing_tokens"] - assert len(overflowing_tokens) == 2 + stride - assert overflowing_tokens == sequence_1_no_special_tokens[-(2 + stride):] - assert overflowing_tokens_first_truncated == sequence_0_no_special_tokens[-(2 + stride):] - assert len(truncated_sequence) == len(sequence) - 2 - assert truncated_sequence == truncated_second_sequence + self.assertEqual(len(overflowing_tokens), 2 + stride) + self.assertEqual(overflowing_tokens, sequence_1_no_special_tokens[-(2 + stride):]) + self.assertEqual(overflowing_tokens_first_truncated, sequence_0_no_special_tokens[-(2 + stride):]) + self.assertEqual(len(truncated_sequence), len(sequence) - 2) + self.assertEqual(truncated_sequence, truncated_second_sequence) def test_encode_input_type(self): tokenizer = self.get_tokenizer() @@ -273,8 +273,8 @@ class CommonTestCases: input_ids = tokenizer.convert_tokens_to_ids(tokens) formatted_input = tokenizer.encode(sequence, add_special_tokens=True) - assert tokenizer.encode(tokens, add_special_tokens=True) == formatted_input - assert tokenizer.encode(input_ids, add_special_tokens=True) == formatted_input + self.assertEqual(tokenizer.encode(tokens, add_special_tokens=True), formatted_input) + self.assertEqual(tokenizer.encode(input_ids, add_special_tokens=True), formatted_input) def test_special_tokens_mask(self): tokenizer = self.get_tokenizer() @@ -287,22 +287,22 @@ class CommonTestCases: encoded_sequence_dict = tokenizer.encode_plus(sequence_0, add_special_tokens=True) encoded_sequence_w_special = encoded_sequence_dict["input_ids"] special_tokens_mask = encoded_sequence_dict["special_tokens_mask"] - assert len(special_tokens_mask) == len(encoded_sequence_w_special) + self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special)) - filtered_sequence = [(x if special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)] + filtered_sequence = [(x if not special_tokens_mask[i] else 
None) for i, x in enumerate(encoded_sequence_w_special)] filtered_sequence = [x for x in filtered_sequence if x is not None] - assert encoded_sequence == filtered_sequence + self.assertEqual(encoded_sequence, filtered_sequence) # Testing inputs pairs encoded_sequence = tokenizer.encode(sequence_0) + tokenizer.encode(sequence_1) encoded_sequence_dict = tokenizer.encode_plus(sequence_0, sequence_1, add_special_tokens=True) encoded_sequence_w_special = encoded_sequence_dict["input_ids"] special_tokens_mask = encoded_sequence_dict["special_tokens_mask"] - assert len(special_tokens_mask) == len(encoded_sequence_w_special) + self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special)) - filtered_sequence = [(x if special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)] + filtered_sequence = [(x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)] filtered_sequence = [x for x in filtered_sequence if x is not None] - assert encoded_sequence == filtered_sequence + self.assertEqual(encoded_sequence, filtered_sequence) # Testing with already existing special tokens if tokenizer.cls_token_id == tokenizer.unk_token_id and tokenizer.cls_token_id == tokenizer.unk_token_id: @@ -310,9 +310,6 @@ class CommonTestCases: encoded_sequence_dict = tokenizer.encode_plus(sequence_0, add_special_tokens=True) encoded_sequence_w_special = encoded_sequence_dict["input_ids"] special_tokens_mask_orig = encoded_sequence_dict["special_tokens_mask"] - special_tokens_mask = tokenizer.get_special_tokens_mask(encoded_sequence_w_special, special_tokens_present=True) - assert len(special_tokens_mask) == len(encoded_sequence_w_special) - assert special_tokens_mask_orig == special_tokens_mask - - - + special_tokens_mask = tokenizer.get_special_tokens_mask(encoded_sequence_w_special, already_has_special_tokens=True) + self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special)) + self.assertEqual(special_tokens_mask_orig, special_tokens_mask) diff --git a/transformers/tests/tokenization_xlm_test.py b/transformers/tests/tokenization_xlm_test.py index b1a71ede59..0949b0cce4 100644 --- a/transformers/tests/tokenization_xlm_test.py +++ b/transformers/tests/tokenization_xlm_test.py @@ -72,8 +72,8 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester): text = tokenizer.encode("sequence builders") text_2 = tokenizer.encode("multi-sequence build") - encoded_sentence = tokenizer.add_special_tokens_single_sequence(text) - encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2) + encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) + encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) assert encoded_sentence == [1] + text + [1] assert encoded_pair == [1] + text + [1] + text_2 + [1] diff --git a/transformers/tests/tokenization_xlnet_test.py b/transformers/tests/tokenization_xlnet_test.py index f4418c7fe5..1a5dbcf6df 100644 --- a/transformers/tests/tokenization_xlnet_test.py +++ b/transformers/tests/tokenization_xlnet_test.py @@ -95,8 +95,8 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester): text = tokenizer.encode("sequence builders") text_2 = tokenizer.encode("multi-sequence build") - encoded_sentence = tokenizer.add_special_tokens_single_sequence(text) - encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2) + encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) + encoded_pair = tokenizer.build_inputs_with_special_tokens(text, 
text_2) assert encoded_sentence == text + [4, 3] assert encoded_pair == text + [4] + text_2 + [4, 3] diff --git a/transformers/tokenization_bert.py b/transformers/tokenization_bert.py index cf6ab51d02..d256f27a58 100644 --- a/transformers/tokenization_bert.py +++ b/transformers/tokenization_bert.py @@ -187,24 +187,21 @@ class BertTokenizer(PreTrainedTokenizer): out_string = ' '.join(tokens).replace(' ##', '').strip() return out_string - def add_special_tokens_single_sequence(self, token_ids): + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): """ - Adds special tokens to the a sequence for sequence classification tasks. - A BERT sequence has the following format: [CLS] X [SEP] + Build model inputs from a sequence or a pair of sequences for sequence classification tasks + by concatenating and adding special tokens. + A BERT sequence has the following format: + single sequence: [CLS] X [SEP] + pair of sequences: [CLS] A [SEP] B [SEP] """ - return [self.cls_token_id] + token_ids + [self.sep_token_id] - - def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1): - """ - Adds special tokens to a sequence pair for sequence classification tasks. - A BERT sequence pair has the following format: [CLS] A [SEP] B [SEP] - """ - sep = [self.sep_token_id] + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] cls = [self.cls_token_id] - + sep = [self.sep_token_id] return cls + token_ids_0 + sep + token_ids_1 + sep - def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, special_tokens_present=False): + def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False): """ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. @@ -212,30 +209,37 @@ class BertTokenizer(PreTrainedTokenizer): Args: token_ids_0: list of ids (must not contain special tokens) token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids - for sequence pairs + for sequence pairs + already_has_special_tokens: (default False) Set to True if the token list is already formatted with + special tokens for the model Returns: A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token. """ - if special_tokens_present: - return list(map(lambda x: 0 if x in [self.sep_token_id, self.cls_token_id] else 1, token_ids_0)) + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError("You should not supply a second sequence if the provided sequence of " + "ids is already formatted with special tokens for the model.") + return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) - if token_ids_1: - return [0] + ([1] * len(token_ids_0)) + [0] + ([1] * len(token_ids_1)) + [0] - else: - return [0] + ([1] * len(token_ids_0)) + [0] + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] - def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1): + def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): """ Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
A BERT sequence pair mask has the following format: 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 | first sequence | second sequence + + if token_ids_1 is None, only returns the first portion of the mask (0's). """ sep = [self.sep_token_id] cls = [self.cls_token_id] - + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] def save_vocabulary(self, vocab_path): diff --git a/transformers/tokenization_roberta.py b/transformers/tokenization_roberta.py index ceb1245169..9cc8a9af6e 100644 --- a/transformers/tokenization_roberta.py +++ b/transformers/tokenization_roberta.py @@ -84,23 +84,21 @@ class RobertaTokenizer(GPT2Tokenizer): self.max_len_single_sentence = self.max_len - 2 # take into account special tokens self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens - def add_special_tokens_single_sequence(self, token_ids): + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): """ - Adds special tokens to a sequence for sequence classification tasks. - A RoBERTa sequence has the following format: <s> X </s> + Build model inputs from a sequence or a pair of sequences for sequence classification tasks + by concatenating and adding special tokens. + A RoBERTa sequence has the following format: + single sequence: <s> X </s> + pair of sequences: <s> A </s></s> B </s> """ - return [self.cls_token_id] + token_ids + [self.sep_token_id] - - def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1): - """ - Adds special tokens to a sequence pair for sequence classification tasks. - A RoBERTa sequence pair has the following format: <s> A </s></s> B </s> - """ - sep = [self.sep_token_id] + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] cls = [self.cls_token_id] + sep = [self.sep_token_id] return cls + token_ids_0 + sep + sep + token_ids_1 + sep - def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, special_tokens_present=False): + def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False): """ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. @@ -108,28 +106,35 @@ class RobertaTokenizer(GPT2Tokenizer): Args: token_ids_0: list of ids (must not contain special tokens) token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids - for sequence pairs + for sequence pairs + already_has_special_tokens: (default False) Set to True if the token list is already formatted with + special tokens for the model Returns: A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
""" + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError("You should not supply a second sequence if the provided sequence of " + "ids is already formated with special tokens for the model.") + return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) - if special_tokens_present: - return list(map(lambda x: 0 if x in [self.sep_token_id, self.cls_token_id] else 1, token_ids_0)) + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] - if token_ids_1: - return [0] + ([1] * len(token_ids_0)) + [0, 0] + ([1] * len(token_ids_1)) + [0] - else: - return [0] + ([1] * len(token_ids_0)) + [0] - - def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1): + def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): """ Creates a mask from the two sequences passed to be used in a sequence-pair classification task. A RoBERTa sequence pair mask has the following format: 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 | first sequence | second sequence + + if token_ids_1 is None, only returns the first portion of the mask (0's). """ sep = [self.sep_token_id] cls = [self.cls_token_id] - return len(cls + token_ids_0 + sep + sep) * [0] + len(token_ids_1 + sep) * [1] \ No newline at end of file + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep) * [0] + len(token_ids_1 + sep) * [1] diff --git a/transformers/tokenization_utils.py b/transformers/tokenization_utils.py index 2f25e95f4b..ce5811b96f 100644 --- a/transformers/tokenization_utils.py +++ b/transformers/tokenization_utils.py @@ -538,15 +538,9 @@ class PreTrainedTokenizer(object): Returns: Number of tokens added to sequences """ - - if pair: - initial_tokens_len = len(self.encode("This is a sequence") + self.encode("This is another")) - final_tokens_len = len(self.encode("This is a sequence", "This is another", add_special_tokens=True)) - else: - initial_tokens_len = len(self.encode("This is a sequence")) - final_tokens_len = len(self.encode("This is a sequence", add_special_tokens=True)) - - return final_tokens_len - initial_tokens_len + token_ids_0 = [] + token_ids_1 = [] + return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None)) def add_special_tokens(self, special_tokens_dict): """ @@ -795,7 +789,7 @@ class PreTrainedTokenizer(object): return_tensors=return_tensors) def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0, - truncate_first_sequence=True, truncate_both_sequences=True, return_tensors=None): + truncation_strategy='longest_first', return_tensors=None): """ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It adds special tokens, truncates @@ -812,6 +806,12 @@ class PreTrainedTokenizer(object): to their model. stride: window stride for overflowing tokens. Can be useful for edge effect removal when using sequential list of inputs. 
+ truncation_strategy: string selected in the following options: + - 'longest_first' (default): iteratively reduce the inputs until the total length is under max_length, + removing one token from the longest sequence at each step (when there is a pair of input sequences) + - 'only_first': Only truncate the first sequence + - 'only_second': Only truncate the second sequence + - 'do_not_truncate': Does not truncate (raises an error if the input sequence is longer than max_length) truncate_first_sequence: if set to `True` and an optional second list of input ids is provided, alongside a specified `max_length`, will truncate the first sequence if the total size is superior than the specified `max_length`. If set to `False`, will truncate the second sequence instead. @@ -840,37 +840,17 @@ class PreTrainedTokenizer(object): len_pair_ids = len(pair_ids) if pair else 0 encoded_inputs = {} - if max_length: - n_added_tokens = self.num_added_tokens(pair=pair) if add_special_tokens else 0 - - if n_added_tokens + len_ids + len_pair_ids > max_length: - if truncate_both_sequences: - tokens_a, tokens_b = self._truncate_seq_pair( - copy.deepcopy(ids), - copy.deepcopy(pair_ids), - max_length=max_length - n_added_tokens - ) - truncated_tokens = ids[- (len_ids - len(tokens_a)):] + pair_ids[- (len_pair_ids - len(tokens_b)):] - encoded_inputs["num_truncated_tokens"] = len(truncated_tokens) - ids = tokens_a - pair_ids = tokens_b - elif pair and n_added_tokens + (len_pair_ids if truncate_first_sequence else len_ids) >= max_length: - logger.warning( - "You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length." - "This pair of sequences will not be truncated.") - elif truncate_first_sequence or not pair: - encoded_inputs["overflowing_tokens"] = ids[max_length - len_pair_ids - n_added_tokens - stride:] - ids = ids[:max_length - len_pair_ids - n_added_tokens] - elif not truncate_first_sequence and pair: - encoded_inputs["overflowing_tokens"] = pair_ids[max_length - len_ids - n_added_tokens - stride:] - pair_ids = pair_ids[:max_length - len_ids - n_added_tokens] - else: - logger.warning( - "Cannot truncate second sequence as it is not provided. No truncation.")
No truncation.") + total_len = len_ids + len_pair_ids + (self.num_added_tokens(pair=pair) if add_special_tokens else 0) + if max_length and total_len > max_length: + ids, pair_ids, overflowing_tokens = self.truncate_sequences(ids, pair_ids=pair_ids, + num_tokens_to_remove=total_len-max_length, + truncation_strategy=truncation_strategy) + encoded_inputs["overflowing_tokens"] = overflowing_tokens + encoded_inputs["num_truncated_tokens"] = total_len - max_length if add_special_tokens: - sequence = self.add_special_tokens_sequence_pair(ids, pair_ids) if pair else self.add_special_tokens_single_sequence(ids) - token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) if pair else [0] * len(sequence) + sequence = self.build_inputs_with_special_tokens(ids, pair_ids) + token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids) else: sequence = ids + pair_ids if pair else ids @@ -895,39 +875,76 @@ class PreTrainedTokenizer(object): return encoded_inputs - def _truncate_seq_pair(self, tokens_a, tokens_b, max_length): - """Truncates a sequence pair in place to the maximum length.""" + def truncate_sequences(self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy='longest_first'): + """Truncates a sequence pair in place to the maximum length. + truncation_strategy: string selected in the following options: + - 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length + starting from the longest one at each token (when there is a pair of input sequences). + Overflowing tokens only contains overflow from the first sequence. + - 'only_first': Only truncate the first sequence. raise an error if the first sequence is shorter or equal to than num_tokens_to_remove. + - 'only_second': Only truncate the second sequence + - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length) + """ + if num_tokens_to_remove <= 0: + return ids, pair_ids, [] - # This is a simple heuristic which will always truncate the longer sequence - # one token at a time. This makes more sense than truncating an equal percent - # of tokens from each, since if one sequence is very short then each token - # that's truncated likely contains more information than a longer sequence. + if truncation_strategy == 'longest_first': + overflowing_tokens = [] + for _ in range(num_tokens_to_remove): + if pair_ids is None or len(ids) > len(pair_ids): + overflowing_tokens.append(ids[-1]) + ids = ids[:-1] + else: + pair_ids = pair_ids[:-1] + elif truncation_strategy == 'only_first': + assert len(ids) > num_tokens_to_remove + overflowing_tokens = ids[-num_tokens_to_remove:] + ids = ids[:-num_tokens_to_remove] + elif truncation_strategy == 'only_second': + assert pair_ids is not None and len(pair_ids) > num_tokens_to_remove + overflowing_tokens = pair_ids[-num_tokens_to_remove:] + pair_ids = pair_ids[:-num_tokens_to_remove] + elif truncation_strategy == 'do_not_truncate': + raise ValueError("Input sequence are too long for max_length. 
Please select a truncation strategy.") + else: + raise ValueError("Truncation_strategy should be selected in ['longest_first', 'only_first', 'only_second', 'do_not_truncate']") + return (ids, pair_ids, overflowing_tokens) - # However, since we'd better not to remove tokens of options and questions, you can choose to use a bigger - # length or only pop from context - while True: - total_length = len(tokens_a) + len(tokens_b) - if total_length <= max_length: - return (tokens_a, tokens_b) - if len(tokens_a) > len(tokens_b): - tokens_a.pop() - else: - tokens_b.pop() - - def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1): + def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): logger.warning("This tokenizer does not make use of special tokens.") + if token_ids_1 is None: + return len(token_ids_0) * [0] return [0] * len(token_ids_0) + [1] * len(token_ids_1) - def add_special_tokens_single_sequence(self, token_ids): - logger.warning("This tokenizer does not make use of special tokens. The sequence has been returned with no modification.") - return token_ids - - def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1): - logger.warning("This tokenizer does not make use of special tokens. The two sequences have been concatenated.") + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks + by concatenating and adding special tokens. + A RoBERTa sequence has the following format: + single sequence: X + pair of sequences: A B + """ + logger.warning("This tokenizer does not make use of special tokens. Input is returned with no modification.") + if token_ids_1 is None: + return token_ids_0 return token_ids_0 + token_ids_1 - def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, special_tokens_present=False): - return [1] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0)) + def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False): + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. + + Args: + token_ids_0: list of ids (must not contain special tokens) + token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids + for sequence pairs + already_has_special_tokens: (default False) Set to True if the token list is already formated with + special tokens for the model + + Returns: + A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token. + """ + return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0)) def convert_ids_to_tokens(self, ids, skip_special_tokens=False): """ Converts a single index or a sequence of indices (integers) in a token " diff --git a/transformers/tokenization_xlm.py b/transformers/tokenization_xlm.py index 817c6d8c44..d09ce6b9dc 100644 --- a/transformers/tokenization_xlm.py +++ b/transformers/tokenization_xlm.py @@ -754,23 +754,21 @@ class XLMTokenizer(PreTrainedTokenizer): out_string = ''.join(tokens).replace('', ' ').strip() return out_string - def add_special_tokens_single_sequence(self, token_ids): + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): """ - Adds special tokens to a sequence for sequence classification tasks. 
- An XLM sequence has the following format: [CLS] X [SEP] - """ - return [self.cls_token_id] + token_ids + [self.sep_token_id] - - def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1): - """ - Adds special tokens to a sequence pair for sequence classification tasks. - An XLM sequence pair has the following format: [CLS] A [SEP] B [SEP] + Build model inputs from a sequence or a pair of sequences for sequence classification tasks + by concatenating and adding special tokens. + An XLM sequence has the following format: + single sequence: [CLS] X [SEP] + pair of sequences: [CLS] A [SEP] B [SEP] """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] sep = [self.sep_token_id] cls = [self.cls_token_id] return cls + token_ids_0 + sep + token_ids_1 + sep - def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, special_tokens_present=False): + def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False): """ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. @@ -778,30 +776,37 @@ class XLMTokenizer(PreTrainedTokenizer): Args: token_ids_0: list of ids (must not contain special tokens) token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids - for sequence pairs + for sequence pairs + already_has_special_tokens: (default False) Set to True if the token list is already formatted with + special tokens for the model Returns: A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token. """ - if special_tokens_present: - return list(map(lambda x: 0 if x in [self.sep_token_id, self.cls_token_id] else 1, token_ids_0)) + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError("You should not supply a second sequence if the provided sequence of " + "ids is already formatted with special tokens for the model.") + return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) - if token_ids_1: - return [0] + ([1] * len(token_ids_0)) + [0] + ([1] * len(token_ids_1)) + [0] - else: - return [0] + ([1] * len(token_ids_0)) + [0] + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] - def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1): + def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): """ Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An XLM sequence pair mask has the following format: 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 | first sequence | second sequence + + if token_ids_1 is None, only returns the first portion of the mask (0's).
""" sep = [self.sep_token_id] cls = [self.cls_token_id] - + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] def save_vocabulary(self, save_directory): diff --git a/transformers/tokenization_xlnet.py b/transformers/tokenization_xlnet.py index 37f975abef..deae8de336 100644 --- a/transformers/tokenization_xlnet.py +++ b/transformers/tokenization_xlnet.py @@ -181,26 +181,21 @@ class XLNetTokenizer(PreTrainedTokenizer): out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip() return out_string - def add_special_tokens_single_sequence(self, token_ids): + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): """ - Adds special tokens to a sequence for sequence classification tasks. - An XLNet sequence has the following format: X [SEP][CLS] + Build model inputs from a sequence or a pair of sequence for sequence classification tasks + by concatenating and adding special tokens. + A RoBERTa sequence has the following format: + single sequence: X + pair of sequences: A B """ sep = [self.sep_token_id] cls = [self.cls_token_id] - return token_ids + sep + cls - - def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1): - """ - Adds special tokens to a sequence pair for sequence classification tasks. - An XLNet sequence pair has the following format: A [SEP] B [SEP][CLS] - """ - - sep = [self.sep_token_id] - cls = [self.cls_token_id] + if token_ids_1 is None: + return token_ids_0 + sep + cls return token_ids_0 + sep + token_ids_1 + sep + cls - def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, special_tokens_present=False): + def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False): """ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. @@ -208,31 +203,39 @@ class XLNetTokenizer(PreTrainedTokenizer): Args: token_ids_0: list of ids (must not contain special tokens) token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids - for sequence pairs + for sequence pairs + already_has_special_tokens: (default False) Set to True if the token list is already formated with + special tokens for the model Returns: A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token. """ - if special_tokens_present: - return list(map(lambda x: 0 if x in [self.sep_token_id, self.cls_token_id] else 1, token_ids_0)) + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError("You should not supply a second sequence if the provided sequence of " + "ids is already formated with special tokens for the model.") + return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) - if token_ids_1: - return ([1] * len(token_ids_0)) + [0] + ([1] * len(token_ids_1)) + [0, 0] - else: - return ([1] * len(token_ids_0)) + [0, 0] + if token_ids_1 is not None: + return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1, 1] + return ([0] * len(token_ids_0)) + [1, 1] - def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1): + def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): """ Creates a mask from the two sequences passed to be used in a sequence-pair classification task. 
A BERT sequence pair mask has the following format: 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 2 | first sequence | second sequence | CLS segment ID + + if token_ids_1 is None, only returns the first portion of the mask (0's). """ sep = [self.sep_token_id] cls = [self.cls_token_id] cls_segment_id = [2] + if token_ids_1 is None: + return len(token_ids_0 + sep + cls) * [0] return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + cls_segment_id def save_vocabulary(self, save_directory):
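
Taken together, the patch replaces add_special_tokens_single_sequence / add_special_tokens_sequence_pair with a single build_inputs_with_special_tokens method, inverts the special-tokens mask convention (1 now marks a special token), and routes all truncation through the new truncate_sequences method and the truncation_strategy argument of prepare_for_model. The sketch below is a minimal, illustrative usage example and is not part of the diff: it assumes the transformers package at this revision, uses the "bert-base-uncased" vocabulary (fetched by from_pretrained) purely as an example checkpoint, and the sentences, max_length and num_tokens_to_remove values are arbitrary; exact token counts depend on the tokenizer.

# Illustrative sketch of the updated tokenizer API (assumptions noted above).
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# encode() without add_special_tokens returns plain ids, as in the tests above.
ids = tokenizer.encode("The quick brown fox jumps over the lazy dog")
pair_ids = tokenizer.encode("A much shorter second sequence")

# build_inputs_with_special_tokens takes one or two id lists and adds the model's special tokens.
single = tokenizer.build_inputs_with_special_tokens(ids)           # [CLS] X [SEP]
pair = tokenizer.build_inputs_with_special_tokens(ids, pair_ids)   # [CLS] A [SEP] B [SEP]

# get_special_tokens_mask returns 1 for special tokens and 0 for sequence tokens;
# already_has_special_tokens=True handles inputs that were already built as above.
mask = tokenizer.get_special_tokens_mask(pair, already_has_special_tokens=True)

# truncate_sequences exposes the truncation strategies directly; 'longest_first'
# removes one token at a time from the longer of the two lists.
ids_t, pair_ids_t, overflow = tokenizer.truncate_sequences(
    ids, pair_ids=pair_ids, num_tokens_to_remove=3, truncation_strategy="longest_first")

# prepare_for_model ties it together: truncation, special tokens, token type ids and masks.
encoded = tokenizer.prepare_for_model(
    ids, pair_ids=pair_ids, max_length=16, add_special_tokens=True,
    truncation_strategy="longest_first")
print(encoded["input_ids"])
print(encoded["token_type_ids"])
print(encoded["special_tokens_mask"])
# Only present when truncation actually happened:
print(encoded.get("overflowing_tokens"), encoded.get("num_truncated_tokens"))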