update internal glue processors

thomwolf 2019-09-25 12:00:50 +02:00
parent 0f091062d4
commit f71758f7a4
3 changed files with 108 additions and 36 deletions

@@ -278,7 +278,11 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
# HACK(label indices are swapped in RoBERTa pretrained model)
label_list[1], label_list[2] = label_list[2], label_list[1]
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
features = convert_examples_to_features(examples,
label_list,
args.max_seq_length,
tokenizer,
output_mode,
pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
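As a rough illustration of the call site above (not part of this commit), the converter can also be invoked directly with the XLNet-style padding options; the checkpoint name, data directory, and max length below are placeholder values, and the import path is assumed from the module layout shown further down:

from transformers import XLNetTokenizer
from transformers.data.processors.glue import glue_convert_examples_to_features, glue_processors

tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
processor = glue_processors['mrpc']()
examples = processor.get_dev_examples('/path/to/MRPC')   # placeholder data dir

features = glue_convert_examples_to_features(
    examples, tokenizer,
    max_length=128,
    task='mrpc',
    pad_on_left=True,                                                    # XLNet pads on the left
    pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
    pad_token_segment_id=4)                                              # XLNet pad segment id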
@@ -292,14 +296,14 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
# Convert to Tensors and build dataset
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
if output_mode == "classification":
all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
elif output_mode == "regression":
all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.float)
all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
return dataset
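For reference, a minimal sketch (not in this diff) of how a training loop could consume the reordered TensorDataset; the batch size and the commented-out model call are illustrative only:

from torch.utils.data import DataLoader, RandomSampler

# dataset = load_and_cache_examples(args, task, tokenizer, evaluate=True)
dataloader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=32)
for batch in dataloader:
    inputs = {'input_ids':      batch[0],
              'attention_mask': batch[1],
              'token_type_ids': batch[2],
              'labels':         batch[3]}
    # outputs = model(**inputs)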

@@ -19,11 +19,18 @@ import logging
import os
from .utils import DataProcessor, InputExample, InputFeatures
from ...file_utils import is_tf_available
if is_tf_available():
import tensorflow as tf
logger = logging.getLogger(__name__)
def glue_convert_examples_to_features(examples, label_list, max_seq_length,
tokenizer, output_mode,
def glue_convert_examples_to_features(examples, tokenizer,
max_length=512,
task=None,
label_list=None,
output_mode=None,
pad_on_left=False,
pad_token=0,
pad_token_segment_id=0,
@@ -31,46 +38,63 @@ def glue_convert_examples_to_features(examples, label_list, max_seq_length,
"""
Loads a data file into a list of `InputFeatures`
"""
is_tf_dataset = False
if is_tf_available() and isinstance(examples, tf.data.Dataset):
is_tf_dataset = True
if task is not None:
processor = glue_processors[task]()
if label_list is None:
label_list = processor.get_labels()
logger.info("Using label list %s for task %s" % (label_list, task))
if output_mode is None:
output_mode = glue_output_modes[task]
logger.info("Using output mode %s for task %s" % (output_mode, task))
label_map = {label: i for i, label in enumerate(label_list)}
features = []
for (ex_index, example) in enumerate(examples):
if ex_index % 10000 == 0:
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
logger.info("Writing example %d" % (ex_index))
if is_tf_dataset:
example = InputExample(example['idx'].numpy(),
example['sentence1'].numpy().decode('utf-8'),
example['sentence2'].numpy().decode('utf-8'),
str(example['label'].numpy()))
inputs = tokenizer.encode_plus(
example.text_a,
example.text_b,
add_special_tokens=True,
max_length=max_seq_length,
truncate_first_sequence=True # We're truncating the first sequence as a priority
max_length=max_length,
truncate_first_sequence=True  # We truncate the first sequence first when the pair is too long
)
input_ids, segment_ids = inputs["input_ids"], inputs["token_type_ids"]
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
# Zero-pad up to the sequence length.
padding_length = max_seq_length - len(input_ids)
padding_length = max_length - len(input_ids)
if pad_on_left:
input_ids = ([pad_token] * padding_length) + input_ids
input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids
else:
input_ids = input_ids + ([pad_token] * padding_length)
input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length)
assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask), max_length)
assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(len(token_type_ids), max_length)
if output_mode == "classification":
label_id = label_map[example.label]
label = label_map[example.label]
elif output_mode == "regression":
label_id = float(example.label)
label = float(example.label)
else:
raise KeyError(output_mode)
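A self-contained sketch (assuming a BERT checkpoint and toy sentences, neither part of this diff) of the per-example flow implemented above: encode a sentence pair, build the attention mask, then right-pad everything to max_length:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
max_length = 16
inputs = tokenizer.encode_plus("The cat sat.", "It was on the mat.",
                               add_special_tokens=True, max_length=max_length)
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
attention_mask = [1] * len(input_ids)                 # 1 = real token, 0 = padding

pad_token = tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0]
padding_length = max_length - len(input_ids)
input_ids      = input_ids + [pad_token] * padding_length
attention_mask = attention_mask + [0] * padding_length
token_type_ids = token_type_ids + [0] * padding_length
assert len(input_ids) == len(attention_mask) == len(token_type_ids) == max_length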
@@ -78,15 +102,34 @@ def glue_convert_examples_to_features(examples, label_list, max_seq_length,
logger.info("*** Example ***")
logger.info("guid: %s" % (example.guid))
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
logger.info("label: %s (id = %d)" % (example.label, label_id))
logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
logger.info("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids]))
logger.info("label: %s (id = %d)" % (example.label, label))
features.append(
InputFeatures(input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
label_id=label_id))
attention_mask=attention_mask,
token_type_ids=token_type_ids,
label=label))
if is_tf_available() and is_tf_dataset:
def gen():
for ex in features:
yield ({'input_ids': ex.input_ids,
'attention_mask': ex.attention_mask,
'token_type_ids': ex.token_type_ids},
ex.label)
return tf.data.Dataset.from_generator(gen,
({'input_ids': tf.int32,
'attention_mask': tf.int32,
'token_type_ids': tf.int32},
tf.int64),
({'input_ids': tf.TensorShape([None]),
'attention_mask': tf.TensorShape([None]),
'token_type_ids': tf.TensorShape([None])},
tf.TensorShape([])))
return features
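A hedged sketch of the TensorFlow path added above; `tensorflow_datasets`, the 'glue/mrpc' split, and the import path are assumptions for illustration and are not part of this diff:

import tensorflow_datasets as tfds
from transformers import BertTokenizer
from transformers.data.processors.glue import glue_convert_examples_to_features

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tf_examples = tfds.load('glue/mrpc')['train']   # dicts with 'idx', 'sentence1', 'sentence2', 'label'
dataset = glue_convert_examples_to_features(tf_examples, tokenizer,
                                            max_length=128, task='mrpc')
dataset = dataset.shuffle(128).batch(32)
for inputs, labels in dataset.take(1):
    print(inputs['input_ids'].shape, labels.shape)   # (32, 128), (32,)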

@@ -16,6 +16,7 @@
import csv
import sys
import copy
import json
class InputExample(object):
"""A single training/test example for simple sequence classification."""
@@ -36,15 +37,39 @@ class InputExample(object):
self.text_b = text_b
self.label = label
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self, input_ids, input_mask, segment_ids, label_id):
def __init__(self, input_ids, attention_mask, token_type_ids, label):
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.label_id = label_id
self.attention_mask = attention_mask
self.token_type_ids = token_type_ids
self.label = label
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
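A small usage sketch (not part of this diff) of the renamed fields and the new serialization helpers; the ids and labels are toy values:

example = InputExample(guid='train-1', text_a='Sentence one.', text_b='Sentence two.', label='1')
feature = InputFeatures(input_ids=[101, 2023, 102],
                        attention_mask=[1, 1, 1],
                        token_type_ids=[0, 0, 0],
                        label=1)
print(example)                    # __repr__ delegates to to_json_string(), printing the fields as JSON
print(feature.to_json_string())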
class DataProcessor(object):