Data collator for token classification (#8274)

* Add DataCollatorForTokenClassification and clean tests

* Make quality
This commit is contained in:
Sylvain Gugger 2020-11-03 16:33:27 -05:00 committed by GitHub
parent 6a064447f2
commit 7f556d2e39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 212 additions and 108 deletions

View File

@ -284,6 +284,7 @@ if is_torch_available():
DataCollatorForNextSentencePrediction,
DataCollatorForPermutationLanguageModeling,
DataCollatorForSOP,
DataCollatorForTokenClassification,
DataCollatorForWholeWordMask,
DataCollatorWithPadding,
default_data_collator,

View File

@ -114,6 +114,67 @@ class DataCollatorWithPadding:
return batch
@dataclass
class DataCollatorForTokenClassification:
"""
Data collator that will dynamically pad the inputs received, as well as the labels.
Args:
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
The tokenizer used for encoding the data.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence if provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the returned list and optionally padding length (see above).
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
"""
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
label_pad_token_id: int = -100
def __call__(self, features):
label_name = "label" if "label" in features[0].keys() else "labels"
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
batch = self.tokenizer.pad(
features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
# Conversion to tensors will fail if we have labels as they are not of the same length yet.
return_tensors="pt" if labels is None else None,
)
if labels is None:
return batch
sequence_length = torch.tensor(batch["input_ids"]).shape[1]
padding_side = self.tokenizer.padding_side
if padding_side == "right":
batch["labels"] = [label + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels]
else:
batch["labels"] = [[self.label_pad_token_id] * (sequence_length - len(label)) + label for label in labels]
batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
return batch
@dataclass
class DataCollatorForLanguageModeling:
"""

View File

@ -45,6 +45,15 @@ class DataCollatorForSOP:
requires_pytorch(self)
class DataCollatorForTokenClassification:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class DataCollatorForWholeWordMask:
def __init__(self, *args, **kwargs):
requires_pytorch(self)

View File

@ -1,7 +1,10 @@
import os
import shutil
import tempfile
import unittest
from transformers import AutoTokenizer, is_torch_available
from transformers.testing_utils import require_torch, slow
from transformers import BertTokenizer, is_torch_available, set_seed
from transformers.testing_utils import require_torch
if is_torch_available():
@ -12,22 +15,25 @@ if is_torch_available():
DataCollatorForNextSentencePrediction,
DataCollatorForPermutationLanguageModeling,
DataCollatorForSOP,
GlueDataset,
GlueDataTrainingArguments,
LineByLineTextDataset,
LineByLineWithSOPTextDataset,
TextDataset,
TextDatasetForNextSentencePrediction,
DataCollatorForTokenClassification,
DataCollatorWithPadding,
default_data_collator,
)
PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt"
PATH_SAMPLE_TEXT_DIR = "./tests/fixtures/tests_samples/wiki_text"
@require_torch
class DataCollatorIntegrationTest(unittest.TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
self.vocab_file = os.path.join(self.tmpdirname, "vocab.txt")
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def test_default_with_dict(self):
features = [{"label": i, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
batch = default_data_collator(features)
@ -57,6 +63,17 @@ class DataCollatorIntegrationTest(unittest.TestCase):
self.assertEqual(batch["labels"].dtype, torch.long)
self.assertEqual(batch["inputs"].shape, torch.Size([8, 10]))
def test_default_classification_and_regression(self):
data_collator = default_data_collator
features = [{"input_ids": [0, 1, 2, 3, 4], "label": i} for i in range(4)]
batch = data_collator(features)
self.assertEqual(batch["labels"].dtype, torch.long)
features = [{"input_ids": [0, 1, 2, 3, 4], "label": float(i)} for i in range(4)]
batch = data_collator(features)
self.assertEqual(batch["labels"].dtype, torch.float)
def test_default_with_no_labels(self):
features = [{"label": None, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
batch = default_data_collator(features)
@ -69,128 +86,144 @@ class DataCollatorIntegrationTest(unittest.TestCase):
self.assertTrue("labels" not in batch)
self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))
@slow
def test_default_classification(self):
MODEL_ID = "bert-base-cased-finetuned-mrpc"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
data_args = GlueDataTrainingArguments(
task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
)
dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
data_collator = default_data_collator
batch = data_collator(dataset.features)
self.assertEqual(batch["labels"].dtype, torch.long)
def test_data_collator_with_padding(self):
tokenizer = BertTokenizer(self.vocab_file)
features = [{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}]
@slow
def test_default_regression(self):
MODEL_ID = "distilroberta-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
data_args = GlueDataTrainingArguments(
task_name="sts-b", data_dir="./tests/fixtures/tests_samples/STS-B", overwrite_cache=True
)
dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
data_collator = default_data_collator
batch = data_collator(dataset.features)
self.assertEqual(batch["labels"].dtype, torch.float)
data_collator = DataCollatorWithPadding(tokenizer)
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=10)
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 10]))
data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 8]))
def test_data_collator_for_token_classification(self):
tokenizer = BertTokenizer(self.vocab_file)
features = [
{"input_ids": [0, 1, 2], "labels": [0, 1, 2]},
{"input_ids": [0, 1, 2, 3, 4, 5], "labels": [0, 1, 2, 3, 4, 5]},
]
data_collator = DataCollatorForTokenClassification(tokenizer)
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
self.assertEqual(batch["labels"].shape, torch.Size([2, 6]))
self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-100] * 3)
data_collator = DataCollatorForTokenClassification(tokenizer, padding="max_length", max_length=10)
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 10]))
self.assertEqual(batch["labels"].shape, torch.Size([2, 10]))
data_collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8)
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 8]))
self.assertEqual(batch["labels"].shape, torch.Size([2, 8]))
data_collator = DataCollatorForTokenClassification(tokenizer, label_pad_token_id=-1)
batch = data_collator(features)
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
self.assertEqual(batch["labels"].shape, torch.Size([2, 6]))
self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-1] * 3)
def test_data_collator_for_language_modeling(self):
tokenizer = BertTokenizer(self.vocab_file)
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
pad_features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
@slow
def test_lm_tokenizer_without_padding(self):
tokenizer = AutoTokenizer.from_pretrained("gpt2")
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
# ^ causal lm
batch = data_collator(no_pad_features)
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
examples = [dataset[i] for i in range(len(dataset))]
batch = data_collator(pad_features)
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
tokenizer._pad_token = None
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
with self.assertRaises(ValueError):
# Expect error due to padding token missing on gpt2:
data_collator(examples)
# Expect error due to padding token missing
data_collator(pad_features)
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
examples = [dataset[i] for i in range(len(dataset))]
batch = data_collator(examples)
self.assertIsInstance(batch, dict)
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
@slow
def test_lm_tokenizer_with_padding(self):
tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
set_seed(42) # For reproducibility
tokenizer = BertTokenizer(self.vocab_file)
data_collator = DataCollatorForLanguageModeling(tokenizer)
# ^ masked lm
batch = data_collator(no_pad_features)
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
examples = [dataset[i] for i in range(len(dataset))]
batch = data_collator(examples)
self.assertIsInstance(batch, dict)
self.assertEqual(batch["input_ids"].shape, torch.Size((31, 107)))
self.assertEqual(batch["labels"].shape, torch.Size((31, 107)))
masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
self.assertTrue(torch.any(masked_tokens))
self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
examples = [dataset[i] for i in range(len(dataset))]
batch = data_collator(examples)
self.assertIsInstance(batch, dict)
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
batch = data_collator(pad_features)
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
self.assertTrue(torch.any(masked_tokens))
self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))
@slow
def test_plm(self):
tokenizer = AutoTokenizer.from_pretrained("xlnet-base-cased")
tokenizer = BertTokenizer(self.vocab_file)
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
pad_features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
data_collator = DataCollatorForPermutationLanguageModeling(tokenizer)
# ^ permutation lm
dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
examples = [dataset[i] for i in range(len(dataset))]
batch = data_collator(examples)
batch = data_collator(pad_features)
self.assertIsInstance(batch, dict)
self.assertEqual(batch["input_ids"].shape, torch.Size((31, 112)))
self.assertEqual(batch["perm_mask"].shape, torch.Size((31, 112, 112)))
self.assertEqual(batch["target_mapping"].shape, torch.Size((31, 112, 112)))
self.assertEqual(batch["labels"].shape, torch.Size((31, 112)))
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
self.assertEqual(batch["perm_mask"].shape, torch.Size((2, 10, 10)))
self.assertEqual(batch["target_mapping"].shape, torch.Size((2, 10, 10)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
examples = [dataset[i] for i in range(len(dataset))]
batch = data_collator(examples)
batch = data_collator(no_pad_features)
self.assertIsInstance(batch, dict)
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
self.assertEqual(batch["perm_mask"].shape, torch.Size((2, 512, 512)))
self.assertEqual(batch["target_mapping"].shape, torch.Size((2, 512, 512)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
self.assertEqual(batch["perm_mask"].shape, torch.Size((2, 10, 10)))
self.assertEqual(batch["target_mapping"].shape, torch.Size((2, 10, 10)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
example = [torch.randint(5, [5])]
with self.assertRaises(ValueError):
# Expect error due to odd sequence length
data_collator(example)
@slow
def test_nsp(self):
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
tokenizer = BertTokenizer(self.vocab_file)
features = [{"tokens_a": [0, 1, 2, 3, 4], "tokens_b": [0, 1, 2, 3, 4], "is_random_next": i} for i in range(2)]
data_collator = DataCollatorForNextSentencePrediction(tokenizer)
batch = data_collator(features)
dataset = TextDatasetForNextSentencePrediction(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
examples = [dataset[i] for i in range(len(dataset))]
batch = data_collator(examples)
self.assertIsInstance(batch, dict)
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 512)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
self.assertEqual(batch["next_sentence_label"].shape, torch.Size((2,)))
# Since there are randomly generated false samples, the total number of samples is not fixed.
total_samples = batch["input_ids"].shape[0]
self.assertEqual(batch["input_ids"].shape, torch.Size((total_samples, 512)))
self.assertEqual(batch["token_type_ids"].shape, torch.Size((total_samples, 512)))
self.assertEqual(batch["labels"].shape, torch.Size((total_samples, 512)))
self.assertEqual(batch["next_sentence_label"].shape, torch.Size((total_samples,)))
@slow
def test_sop(self):
tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
tokenizer = BertTokenizer(self.vocab_file)
features = [
{
"input_ids": torch.tensor([0, 1, 2, 3, 4]),
"token_type_ids": torch.tensor([0, 1, 2, 3, 4]),
"sentence_order_label": torch.tensor(i),
}
for i in range(2)
]
data_collator = DataCollatorForSOP(tokenizer)
batch = data_collator(features)
dataset = LineByLineWithSOPTextDataset(tokenizer, file_dir=PATH_SAMPLE_TEXT_DIR, block_size=512)
examples = [dataset[i] for i in range(len(dataset))]
batch = data_collator(examples)
self.assertIsInstance(batch, dict)
# Since there are randomly generated false samples, the total number of samples is not fixed.
total_samples = batch["input_ids"].shape[0]
self.assertEqual(batch["input_ids"].shape, torch.Size((total_samples, 512)))
self.assertEqual(batch["token_type_ids"].shape, torch.Size((total_samples, 512)))
self.assertEqual(batch["labels"].shape, torch.Size((total_samples, 512)))
self.assertEqual(batch["sentence_order_label"].shape, torch.Size((total_samples,)))
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 5)))
self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 5)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 5)))
self.assertEqual(batch["sentence_order_label"].shape, torch.Size((2,)))