Merge pull request #1355 from agrinh/master
Fix tensorflow_dataset glue support
This commit is contained in:
commit
6a17b3c51b
|
@ -79,10 +79,7 @@ def glue_convert_examples_to_features(examples, tokenizer,
|
|||
if ex_index % 10000 == 0:
|
||||
logger.info("Writing example %d" % (ex_index))
|
||||
if is_tf_dataset:
|
||||
example = InputExample(example['idx'].numpy(),
|
||||
example['sentence1'].numpy().decode('utf-8'),
|
||||
example['sentence2'].numpy().decode('utf-8'),
|
||||
str(example['label'].numpy()))
|
||||
example = processor.get_example_from_tensor_dict(example)
|
||||
|
||||
inputs = tokenizer.encode_plus(
|
||||
example.text_a,
|
||||
|
@ -157,6 +154,13 @@ def glue_convert_examples_to_features(examples, tokenizer,
|
|||
class MrpcProcessor(DataProcessor):
|
||||
"""Processor for the MRPC data set (GLUE version)."""
|
||||
|
||||
def get_example_from_tensor_dict(self, tensor_dict):
|
||||
"""See base class."""
|
||||
return InputExample(tensor_dict['idx'].numpy(),
|
||||
tensor_dict['sentence1'].numpy().decode('utf-8'),
|
||||
tensor_dict['sentence2'].numpy().decode('utf-8'),
|
||||
str(tensor_dict['label'].numpy()))
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
|
||||
|
@ -190,6 +194,13 @@ class MrpcProcessor(DataProcessor):
|
|||
class MnliProcessor(DataProcessor):
|
||||
"""Processor for the MultiNLI data set (GLUE version)."""
|
||||
|
||||
def get_example_from_tensor_dict(self, tensor_dict):
|
||||
"""See base class."""
|
||||
return InputExample(tensor_dict['idx'].numpy(),
|
||||
tensor_dict['premise'].numpy().decode('utf-8'),
|
||||
tensor_dict['hypothesis'].numpy().decode('utf-8'),
|
||||
str(tensor_dict['label'].numpy()))
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
|
@ -233,6 +244,13 @@ class MnliMismatchedProcessor(MnliProcessor):
|
|||
class ColaProcessor(DataProcessor):
|
||||
"""Processor for the CoLA data set (GLUE version)."""
|
||||
|
||||
def get_example_from_tensor_dict(self, tensor_dict):
|
||||
"""See base class."""
|
||||
return InputExample(tensor_dict['idx'].numpy(),
|
||||
tensor_dict['sentence'].numpy().decode('utf-8'),
|
||||
None,
|
||||
str(tensor_dict['label'].numpy()))
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
|
@ -262,6 +280,13 @@ class ColaProcessor(DataProcessor):
|
|||
class Sst2Processor(DataProcessor):
|
||||
"""Processor for the SST-2 data set (GLUE version)."""
|
||||
|
||||
def get_example_from_tensor_dict(self, tensor_dict):
|
||||
"""See base class."""
|
||||
return InputExample(tensor_dict['idx'].numpy(),
|
||||
tensor_dict['sentence'].numpy().decode('utf-8'),
|
||||
None,
|
||||
str(tensor_dict['label'].numpy()))
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
|
@ -293,6 +318,13 @@ class Sst2Processor(DataProcessor):
|
|||
class StsbProcessor(DataProcessor):
|
||||
"""Processor for the STS-B data set (GLUE version)."""
|
||||
|
||||
def get_example_from_tensor_dict(self, tensor_dict):
|
||||
"""See base class."""
|
||||
return InputExample(tensor_dict['idx'].numpy(),
|
||||
tensor_dict['sentence1'].numpy().decode('utf-8'),
|
||||
tensor_dict['sentence2'].numpy().decode('utf-8'),
|
||||
str(tensor_dict['label'].numpy()))
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
|
@ -325,6 +357,13 @@ class StsbProcessor(DataProcessor):
|
|||
class QqpProcessor(DataProcessor):
|
||||
"""Processor for the QQP data set (GLUE version)."""
|
||||
|
||||
def get_example_from_tensor_dict(self, tensor_dict):
|
||||
"""See base class."""
|
||||
return InputExample(tensor_dict['idx'].numpy(),
|
||||
tensor_dict['question1'].numpy().decode('utf-8'),
|
||||
tensor_dict['question2'].numpy().decode('utf-8'),
|
||||
str(tensor_dict['label'].numpy()))
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
|
@ -360,6 +399,13 @@ class QqpProcessor(DataProcessor):
|
|||
class QnliProcessor(DataProcessor):
|
||||
"""Processor for the QNLI data set (GLUE version)."""
|
||||
|
||||
def get_example_from_tensor_dict(self, tensor_dict):
|
||||
"""See base class."""
|
||||
return InputExample(tensor_dict['idx'].numpy(),
|
||||
tensor_dict['question'].numpy().decode('utf-8'),
|
||||
tensor_dict['sentence'].numpy().decode('utf-8'),
|
||||
str(tensor_dict['label'].numpy()))
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
|
@ -393,6 +439,13 @@ class QnliProcessor(DataProcessor):
|
|||
class RteProcessor(DataProcessor):
|
||||
"""Processor for the RTE data set (GLUE version)."""
|
||||
|
||||
def get_example_from_tensor_dict(self, tensor_dict):
|
||||
"""See base class."""
|
||||
return InputExample(tensor_dict['idx'].numpy(),
|
||||
tensor_dict['sentence1'].numpy().decode('utf-8'),
|
||||
tensor_dict['sentence2'].numpy().decode('utf-8'),
|
||||
str(tensor_dict['label'].numpy()))
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
|
@ -425,6 +478,13 @@ class RteProcessor(DataProcessor):
|
|||
class WnliProcessor(DataProcessor):
|
||||
"""Processor for the WNLI data set (GLUE version)."""
|
||||
|
||||
def get_example_from_tensor_dict(self, tensor_dict):
|
||||
"""See base class."""
|
||||
return InputExample(tensor_dict['idx'].numpy(),
|
||||
tensor_dict['sentence1'].numpy().decode('utf-8'),
|
||||
tensor_dict['sentence2'].numpy().decode('utf-8'),
|
||||
str(tensor_dict['label'].numpy()))
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
|
|
|
@ -86,6 +86,15 @@ class InputFeatures(object):
|
|||
class DataProcessor(object):
|
||||
"""Base class for data converters for sequence classification data sets."""
|
||||
|
||||
def get_example_from_tensor_dict(self, tensor_dict):
|
||||
"""Gets an example from a dict with tensorflow tensors
|
||||
|
||||
Args:
|
||||
tensor_dict: Keys and values should match the corresponding Glue
|
||||
tensorflow_dataset examples.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""Gets a collection of `InputExample`s for the train set."""
|
||||
raise NotImplementedError()
|
||||
|
|
Loading…
Reference in New Issue