From 99a90e43d421369357815b21771f5211c2528667 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 24 Sep 2019 17:16:46 +0200 Subject: [PATCH] update data processors __init__ --- pytorch_transformers/__init__.py | 2 +- pytorch_transformers/data/__init__.py | 6 +- .../data/processors/__init__.py | 3 +- pytorch_transformers/data/processors/glue.py | 78 +++++++++---------- 4 files changed, 45 insertions(+), 44 deletions(-) diff --git a/pytorch_transformers/__init__.py b/pytorch_transformers/__init__.py index 93de6a982b..130a32885a 100644 --- a/pytorch_transformers/__init__.py +++ b/pytorch_transformers/__init__.py @@ -76,7 +76,7 @@ from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CAC from .data import (is_sklearn_available, InputExample, InputFeatures, DataProcessor, - glue_output_modes, glue_convert_examples_to_features, glue_processors) + glue_output_modes, glue_convert_examples_to_features, glue_processors, glue_tasks_num_labels) if is_sklearn_available(): from .data import glue_compute_metrics diff --git a/pytorch_transformers/data/__init__.py b/pytorch_transformers/data/__init__.py index 4522b802ab..e910d6da2e 100644 --- a/pytorch_transformers/data/__init__.py +++ b/pytorch_transformers/data/__init__.py @@ -1,6 +1,6 @@ -from .processors import (InputExample, InputFeatures, DataProcessor, - glue_output_modes, glue_convert_examples_to_features, glue_processors) -from .metrics import is_sklearn_available +from .processors import InputExample, InputFeatures, DataProcessor +from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features +from .metrics import is_sklearn_available if is_sklearn_available(): from .metrics import glue_compute_metrics diff --git a/pytorch_transformers/data/processors/__init__.py b/pytorch_transformers/data/processors/__init__.py index 1a442bf839..af38c54beb 100644 --- a/pytorch_transformers/data/processors/__init__.py +++ b/pytorch_transformers/data/processors/__init__.py @@ -1,2 +1,3 @@ from .utils import InputExample, InputFeatures, DataProcessor -from .glue import output_modes, processors, convert_examples_to_glue_features +from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features + diff --git a/pytorch_transformers/data/processors/glue.py b/pytorch_transformers/data/processors/glue.py index e1376816de..a7d56491e2 100644 --- a/pytorch_transformers/data/processors/glue.py +++ b/pytorch_transformers/data/processors/glue.py @@ -22,45 +22,7 @@ from .utils import DataProcessor, InputExample, InputFeatures logger = logging.getLogger(__name__) -GLUE_TASKS_NUM_LABELS = { - "cola": 2, - "mnli": 3, - "mrpc": 2, - "sst-2": 2, - "sts-b": 1, - "qqp": 2, - "qnli": 2, - "rte": 2, - "wnli": 2, -} - -processors = { - "cola": ColaProcessor, - "mnli": MnliProcessor, - "mnli-mm": MnliMismatchedProcessor, - "mrpc": MrpcProcessor, - "sst-2": Sst2Processor, - "sts-b": StsbProcessor, - "qqp": QqpProcessor, - "qnli": QnliProcessor, - "rte": RteProcessor, - "wnli": WnliProcessor, -} - -output_modes = { - "cola": "classification", - "mnli": "classification", - "mnli-mm": "classification", - "mrpc": "classification", - "sst-2": "classification", - "sts-b": "regression", - "qqp": "classification", - "qnli": "classification", - "rte": "classification", - "wnli": "classification", -} - -def convert_examples_to_glue_features(examples, label_list, max_seq_length, +def glue_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, output_mode, pad_on_left=False, pad_token=0, @@ -427,3 +389,41 @@ class WnliProcessor(DataProcessor): examples.append( InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) return examples + +glue_tasks_num_labels = { + "cola": 2, + "mnli": 3, + "mrpc": 2, + "sst-2": 2, + "sts-b": 1, + "qqp": 2, + "qnli": 2, + "rte": 2, + "wnli": 2, +} + +glue_processors = { + "cola": ColaProcessor, + "mnli": MnliProcessor, + "mnli-mm": MnliMismatchedProcessor, + "mrpc": MrpcProcessor, + "sst-2": Sst2Processor, + "sts-b": StsbProcessor, + "qqp": QqpProcessor, + "qnli": QnliProcessor, + "rte": RteProcessor, + "wnli": WnliProcessor, +} + +glue_output_modes = { + "cola": "classification", + "mnli": "classification", + "mnli-mm": "classification", + "mrpc": "classification", + "sst-2": "classification", + "sts-b": "regression", + "qqp": "classification", + "qnli": "classification", + "rte": "classification", + "wnli": "classification", +}