diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 7285a41eb2..ee5da43999 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -284,6 +284,7 @@ if is_torch_available():
         DataCollatorForNextSentencePrediction,
         DataCollatorForPermutationLanguageModeling,
         DataCollatorForSOP,
+        DataCollatorForTokenClassification,
         DataCollatorForWholeWordMask,
         DataCollatorWithPadding,
         default_data_collator,
diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py
index 080aadca2b..7eccad31f9 100644
--- a/src/transformers/data/data_collator.py
+++ b/src/transformers/data/data_collator.py
@@ -114,6 +114,67 @@ class DataCollatorWithPadding:
         return batch


+@dataclass
+class DataCollatorForTokenClassification:
+    """
+    Data collator that will dynamically pad the inputs received, as well as the labels.
+
+    Args:
+        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
+            The tokenizer used for encoding the data.
+        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
+            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
+            among:
+
+            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+              sequence is provided).
+            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
+              maximum acceptable input length for the model if that argument is not provided.
+            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
+              different lengths).
+        max_length (:obj:`int`, `optional`):
+            Maximum length of the returned list and optionally padding length (see above).
+        pad_to_multiple_of (:obj:`int`, `optional`):
+            If set, will pad the sequence to a multiple of the provided value.
+
+            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
+            7.5 (Volta).
+        label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
+            The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
+    """
+
+    tokenizer: PreTrainedTokenizerBase
+    padding: Union[bool, str, PaddingStrategy] = True
+    max_length: Optional[int] = None
+    pad_to_multiple_of: Optional[int] = None
+    label_pad_token_id: int = -100
+
+    def __call__(self, features):
+        label_name = "label" if "label" in features[0].keys() else "labels"
+        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
+        batch = self.tokenizer.pad(
+            features,
+            padding=self.padding,
+            max_length=self.max_length,
+            pad_to_multiple_of=self.pad_to_multiple_of,
+            # Conversion to tensors will fail if we have labels as they are not of the same length yet.
+            return_tensors="pt" if labels is None else None,
+        )
+
+        if labels is None:
+            return batch
+
+        sequence_length = torch.tensor(batch["input_ids"]).shape[1]
+        padding_side = self.tokenizer.padding_side
+        if padding_side == "right":
+            batch["labels"] = [label + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels]
+        else:
+            batch["labels"] = [[self.label_pad_token_id] * (sequence_length - len(label)) + label for label in labels]
+
+        batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
+        return batch
+
+
 @dataclass
 class DataCollatorForLanguageModeling:
     """
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index 9109c1a25d..c6d70a5361 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -45,6 +45,15 @@ class DataCollatorForSOP:
         requires_pytorch(self)


+class DataCollatorForTokenClassification:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
 class DataCollatorForWholeWordMask:
     def __init__(self, *args, **kwargs):
         requires_pytorch(self)
diff --git a/tests/test_data_collator.py b/tests/test_data_collator.py
index ff45e87d6c..d46a96589c 100644
--- a/tests/test_data_collator.py
+++ b/tests/test_data_collator.py
@@ -1,7 +1,10 @@
+import os
+import shutil
+import tempfile
 import unittest

-from transformers import AutoTokenizer, is_torch_available
-from transformers.testing_utils import require_torch, slow
+from transformers import BertTokenizer, is_torch_available, set_seed
+from transformers.testing_utils import require_torch


 if is_torch_available():
@@ -12,22 +15,25 @@ if is_torch_available():
         DataCollatorForNextSentencePrediction,
         DataCollatorForPermutationLanguageModeling,
         DataCollatorForSOP,
-        GlueDataset,
-        GlueDataTrainingArguments,
-        LineByLineTextDataset,
-        LineByLineWithSOPTextDataset,
-        TextDataset,
-        TextDatasetForNextSentencePrediction,
+        DataCollatorForTokenClassification,
+        DataCollatorWithPadding,
         default_data_collator,
     )


-PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt"
-PATH_SAMPLE_TEXT_DIR = "./tests/fixtures/tests_samples/wiki_text"
-
-
 @require_torch
 class DataCollatorIntegrationTest(unittest.TestCase):
+    def setUp(self):
+        self.tmpdirname = tempfile.mkdtemp()
+
+        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
+        self.vocab_file = os.path.join(self.tmpdirname, "vocab.txt")
+        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
+            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
+
+    def tearDown(self):
+        shutil.rmtree(self.tmpdirname)
+
     def test_default_with_dict(self):
         features = [{"label": i, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
         batch = default_data_collator(features)
@@ -57,6 +63,17 @@ class DataCollatorIntegrationTest(unittest.TestCase):
         self.assertEqual(batch["labels"].dtype, torch.long)
         self.assertEqual(batch["inputs"].shape, torch.Size([8, 10]))

+    def test_default_classification_and_regression(self):
+        data_collator = default_data_collator
+
+        features = [{"input_ids": [0, 1, 2, 3, 4], "label": i} for i in range(4)]
+        batch = data_collator(features)
+        self.assertEqual(batch["labels"].dtype, torch.long)
+
+        features = [{"input_ids": [0, 1, 2, 3, 4], "label": float(i)} for i in range(4)]
+        batch = data_collator(features)
+        self.assertEqual(batch["labels"].dtype, torch.float)
+
     def test_default_with_no_labels(self):
         features = [{"label": None, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
         batch = default_data_collator(features)
@@ -69,128 +86,144 @@ class DataCollatorIntegrationTest(unittest.TestCase):
         self.assertTrue("labels" not in batch)
         self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))

-    @slow
-    def test_default_classification(self):
-        MODEL_ID = "bert-base-cased-finetuned-mrpc"
-        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-        data_args = GlueDataTrainingArguments(
-            task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
-        )
-        dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
-        data_collator = default_data_collator
-        batch = data_collator(dataset.features)
-        self.assertEqual(batch["labels"].dtype, torch.long)
+    def test_data_collator_with_padding(self):
+        tokenizer = BertTokenizer(self.vocab_file)
+        features = [{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}]

-    @slow
-    def test_default_regression(self):
-        MODEL_ID = "distilroberta-base"
-        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-        data_args = GlueDataTrainingArguments(
-            task_name="sts-b", data_dir="./tests/fixtures/tests_samples/STS-B", overwrite_cache=True
-        )
-        dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
-        data_collator = default_data_collator
-        batch = data_collator(dataset.features)
-        self.assertEqual(batch["labels"].dtype, torch.float)
+        data_collator = DataCollatorWithPadding(tokenizer)
+        batch = data_collator(features)
+        self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
+        self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
+
+        data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=10)
+        batch = data_collator(features)
+        self.assertEqual(batch["input_ids"].shape, torch.Size([2, 10]))
+
+        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
+        batch = data_collator(features)
+        self.assertEqual(batch["input_ids"].shape, torch.Size([2, 8]))
+
+    def test_data_collator_for_token_classification(self):
+        tokenizer = BertTokenizer(self.vocab_file)
+        features = [
+            {"input_ids": [0, 1, 2], "labels": [0, 1, 2]},
+            {"input_ids": [0, 1, 2, 3, 4, 5], "labels": [0, 1, 2, 3, 4, 5]},
+        ]
+
+        data_collator = DataCollatorForTokenClassification(tokenizer)
+        batch = data_collator(features)
+        self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
+        self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
+        self.assertEqual(batch["labels"].shape, torch.Size([2, 6]))
+        self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-100] * 3)
+
+        data_collator = DataCollatorForTokenClassification(tokenizer, padding="max_length", max_length=10)
+        batch = data_collator(features)
+        self.assertEqual(batch["input_ids"].shape, torch.Size([2, 10]))
+        self.assertEqual(batch["labels"].shape, torch.Size([2, 10]))
+
+        data_collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8)
+        batch = data_collator(features)
+        self.assertEqual(batch["input_ids"].shape, torch.Size([2, 8]))
+        self.assertEqual(batch["labels"].shape, torch.Size([2, 8]))
+
+        data_collator = DataCollatorForTokenClassification(tokenizer, label_pad_token_id=-1)
+        batch = data_collator(features)
+        self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
+        self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
+        self.assertEqual(batch["labels"].shape, torch.Size([2, 6]))
+        self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-1] * 3)
+
+    def test_data_collator_for_language_modeling(self):
+        tokenizer = BertTokenizer(self.vocab_file)
+        no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
+        pad_features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]

-    @slow
-    def test_lm_tokenizer_without_padding(self):
-        tokenizer = AutoTokenizer.from_pretrained("gpt2")
         data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
-        # ^ causal lm
+        batch = data_collator(no_pad_features)
+        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
+        self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))

-        dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
-        examples = [dataset[i] for i in range(len(dataset))]
+        batch = data_collator(pad_features)
+        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
+        self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
+
+        tokenizer._pad_token = None
+        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
         with self.assertRaises(ValueError):
-            # Expect error due to padding token missing on gpt2:
-            data_collator(examples)
+            # Expect error due to padding token missing
+            data_collator(pad_features)

-        dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
-        examples = [dataset[i] for i in range(len(dataset))]
-        batch = data_collator(examples)
-        self.assertIsInstance(batch, dict)
-        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
-        self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
-
-    @slow
-    def test_lm_tokenizer_with_padding(self):
-        tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
+        set_seed(42)  # For reproducibility
+        tokenizer = BertTokenizer(self.vocab_file)
         data_collator = DataCollatorForLanguageModeling(tokenizer)
-        # ^ masked lm
+        batch = data_collator(no_pad_features)
+        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
+        self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))

-        dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
-        examples = [dataset[i] for i in range(len(dataset))]
-        batch = data_collator(examples)
-        self.assertIsInstance(batch, dict)
-        self.assertEqual(batch["input_ids"].shape, torch.Size((31, 107)))
-        self.assertEqual(batch["labels"].shape, torch.Size((31, 107)))
+        masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
+        self.assertTrue(torch.any(masked_tokens))
+        self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))

-        dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
-        examples = [dataset[i] for i in range(len(dataset))]
-        batch = data_collator(examples)
-        self.assertIsInstance(batch, dict)
-        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
-        self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
+        batch = data_collator(pad_features)
+        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
+        self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
+
+        masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
+        self.assertTrue(torch.any(masked_tokens))
+        self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))

-    @slow
     def test_plm(self):
-        tokenizer = AutoTokenizer.from_pretrained("xlnet-base-cased")
+        tokenizer = BertTokenizer(self.vocab_file)
+        no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
+        pad_features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
+
         data_collator = DataCollatorForPermutationLanguageModeling(tokenizer)
-        # ^ permutation lm

-        dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
-        examples = [dataset[i] for i in range(len(dataset))]
-        batch = data_collator(examples)
+        batch = data_collator(pad_features)
         self.assertIsInstance(batch, dict)
-        self.assertEqual(batch["input_ids"].shape, torch.Size((31, 112)))
-        self.assertEqual(batch["perm_mask"].shape, torch.Size((31, 112, 112)))
-        self.assertEqual(batch["target_mapping"].shape, torch.Size((31, 112, 112)))
-        self.assertEqual(batch["labels"].shape, torch.Size((31, 112)))
+        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
+        self.assertEqual(batch["perm_mask"].shape, torch.Size((2, 10, 10)))
+        self.assertEqual(batch["target_mapping"].shape, torch.Size((2, 10, 10)))
+        self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))

-        dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
-        examples = [dataset[i] for i in range(len(dataset))]
-        batch = data_collator(examples)
+        batch = data_collator(no_pad_features)
         self.assertIsInstance(batch, dict)
-        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
-        self.assertEqual(batch["perm_mask"].shape, torch.Size((2, 512, 512)))
-        self.assertEqual(batch["target_mapping"].shape, torch.Size((2, 512, 512)))
-        self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
+        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
+        self.assertEqual(batch["perm_mask"].shape, torch.Size((2, 10, 10)))
+        self.assertEqual(batch["target_mapping"].shape, torch.Size((2, 10, 10)))
+        self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))

         example = [torch.randint(5, [5])]
         with self.assertRaises(ValueError):
             # Expect error due to odd sequence length
             data_collator(example)

-    @slow
     def test_nsp(self):
-        tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
+        tokenizer = BertTokenizer(self.vocab_file)
+        features = [{"tokens_a": [0, 1, 2, 3, 4], "tokens_b": [0, 1, 2, 3, 4], "is_random_next": i} for i in range(2)]
         data_collator = DataCollatorForNextSentencePrediction(tokenizer)
+        batch = data_collator(features)

-        dataset = TextDatasetForNextSentencePrediction(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
-        examples = [dataset[i] for i in range(len(dataset))]
-        batch = data_collator(examples)
-        self.assertIsInstance(batch, dict)
+        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
+        self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 512)))
+        self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
+        self.assertEqual(batch["next_sentence_label"].shape, torch.Size((2,)))

-        # Since there are randomly generated false samples, the total number of samples is not fixed.
-        total_samples = batch["input_ids"].shape[0]
-        self.assertEqual(batch["input_ids"].shape, torch.Size((total_samples, 512)))
-        self.assertEqual(batch["token_type_ids"].shape, torch.Size((total_samples, 512)))
-        self.assertEqual(batch["labels"].shape, torch.Size((total_samples, 512)))
-        self.assertEqual(batch["next_sentence_label"].shape, torch.Size((total_samples,)))
-
-    @slow
     def test_sop(self):
-        tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
+        tokenizer = BertTokenizer(self.vocab_file)
+        features = [
+            {
+                "input_ids": torch.tensor([0, 1, 2, 3, 4]),
+                "token_type_ids": torch.tensor([0, 1, 2, 3, 4]),
+                "sentence_order_label": torch.tensor(i),
+            }
+            for i in range(2)
+        ]
         data_collator = DataCollatorForSOP(tokenizer)
+        batch = data_collator(features)

-        dataset = LineByLineWithSOPTextDataset(tokenizer, file_dir=PATH_SAMPLE_TEXT_DIR, block_size=512)
-        examples = [dataset[i] for i in range(len(dataset))]
-        batch = data_collator(examples)
-        self.assertIsInstance(batch, dict)
-
-        # Since there are randomly generated false samples, the total number of samples is not fixed.
-        total_samples = batch["input_ids"].shape[0]
-        self.assertEqual(batch["input_ids"].shape, torch.Size((total_samples, 512)))
-        self.assertEqual(batch["token_type_ids"].shape, torch.Size((total_samples, 512)))
-        self.assertEqual(batch["labels"].shape, torch.Size((total_samples, 512)))
-        self.assertEqual(batch["sentence_order_label"].shape, torch.Size((total_samples,)))
+        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 5)))
+        self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 5)))
+        self.assertEqual(batch["labels"].shape, torch.Size((2, 5)))
+        self.assertEqual(batch["sentence_order_label"].shape, torch.Size((2,)))
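
For reviewers who want to try the new collator outside the test suite, here is a minimal usage sketch (not part of the patch). The checkpoint name and the toy token ids/labels below are illustrative assumptions; any tokenizer with a pad token behaves the same way.

```python
from transformers import AutoTokenizer, DataCollatorForTokenClassification

# Illustrative checkpoint; the ids and labels below are made up for the sketch.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8)

# Pre-tokenized features whose labels have different lengths per example.
features = [
    {"input_ids": [101, 7592, 102], "labels": [0, 1, 0]},
    {"input_ids": [101, 7592, 2088, 999, 102], "labels": [0, 1, 2, 1, 0]},
]

batch = collator(features)
# input_ids are padded with tokenizer.pad_token_id, labels with label_pad_token_id (-100),
# so the padded positions are ignored by PyTorch cross-entropy losses.
print(batch["input_ids"].shape)     # torch.Size([2, 8])
print(batch["labels"][0].tolist())  # [0, 1, 0, -100, -100, -100, -100, -100]
```

In a training script the same object would typically be passed as `collate_fn` to a `torch.utils.data.DataLoader`, or as `data_collator` to `Trainer`, so each batch is padded to its own longest sequence rather than to a fixed maximum length.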