Deprecate old data/metrics functions (#8420)

This commit is contained in:
Sylvain Gugger 2020-11-09 12:10:09 -05:00 committed by GitHub
parent d4d1fbfc5a
commit 52040517b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 101 additions and 6 deletions

View File

@ -1,5 +1,6 @@
import os
import time
import warnings
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional, Union
@ -69,6 +70,12 @@ class GlueDataset(Dataset):
mode: Union[str, Split] = Split.train,
cache_dir: Optional[str] = None,
):
warnings.warn(
"This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
"library. You can have a look at this example script for pointers: "
"https://github.com/huggingface/transformers/blob/master/examples/text-classification/run_glue.py",
FutureWarning,
)
self.args = args
self.processor = glue_processors[args.task_name]()
self.output_mode = glue_output_modes[args.task_name]

View File

@ -19,7 +19,8 @@ logger = logging.get_logger(__name__)
# Deprecation message template for the datasets below; "{0}" is filled with an example-script URL through
# ``.format(...)`` at each ``warnings.warn`` call site.
# NOTE(review): this span comes from a rendered diff, so the first literal below looks like the removed
# (pre-change) line shown next to its two-line replacement -- confirm against the real file before relying on
# the concatenated value.
DEPRECATION_WARNING = (
"This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets library."
"This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
"library. You can have a look at this example script for pointers: {0}"
)
@ -36,7 +37,12 @@ class TextDataset(Dataset):
overwrite_cache=False,
cache_dir: Optional[str] = None,
):
warnings.warn(DEPRECATION_WARNING, FutureWarning)
warnings.warn(
DEPRECATION_WARNING.format(
"https://github.com/huggingface/transformers/blob/master/examples/language-modeling/run_mlm.py"
),
FutureWarning,
)
assert os.path.isfile(file_path), f"Input file path {file_path} not found"
block_size = block_size - tokenizer.num_special_tokens_to_add(pair=False)
@ -101,7 +107,12 @@ class LineByLineTextDataset(Dataset):
"""
def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int):
warnings.warn(DEPRECATION_WARNING, FutureWarning)
warnings.warn(
DEPRECATION_WARNING.format(
"https://github.com/huggingface/transformers/blob/master/examples/language-modeling/run_mlm.py"
),
FutureWarning,
)
assert os.path.isfile(file_path), f"Input file path {file_path} not found"
# Here, we do not cache the features, operating under the assumption
# that we will soon use fast multithreaded tokenizers from the
@ -128,7 +139,12 @@ class LineByLineWithRefDataset(Dataset):
"""
def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, ref_path: str):
warnings.warn(DEPRECATION_WARNING, FutureWarning)
warnings.warn(
DEPRECATION_WARNING.format(
"https://github.com/huggingface/transformers/blob/master/examples/language-modeling/run_mlm_wwm.py"
),
FutureWarning,
)
assert os.path.isfile(file_path), f"Input file path {file_path} not found"
# The message must name the reference file that was checked, not the input file.
assert os.path.isfile(ref_path), f"Ref file path {ref_path} not found"
# Here, we do not cache the features, operating under the assumption
@ -165,7 +181,12 @@ class LineByLineWithSOPTextDataset(Dataset):
"""
def __init__(self, tokenizer: PreTrainedTokenizer, file_dir: str, block_size: int):
warnings.warn(DEPRECATION_WARNING, FutureWarning)
warnings.warn(
DEPRECATION_WARNING.format(
"https://github.com/huggingface/transformers/blob/master/examples/language-modeling/run_mlm.py"
),
FutureWarning,
)
assert os.path.isdir(file_dir)
logger.info(f"Creating features from dataset file folder at {file_dir}")
self.examples = []
@ -315,7 +336,12 @@ class TextDatasetForNextSentencePrediction(Dataset):
short_seq_probability=0.1,
nsp_probability=0.5,
):
warnings.warn(DEPRECATION_WARNING, FutureWarning)
warnings.warn(
DEPRECATION_WARNING.format(
"https://github.com/huggingface/transformers/blob/master/examples/language-modeling/run_mlm.py"
),
FutureWarning,
)
assert os.path.isfile(file_path), f"Input file path {file_path} not found"
self.block_size = block_size - tokenizer.num_special_tokens_to_add(pair=True)

View File

@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from ...file_utils import is_sklearn_available, requires_sklearn
@ -23,12 +25,21 @@ if is_sklearn_available():
from scipy.stats import pearsonr, spearmanr
# Message passed to ``warnings.warn`` by every metric helper in this module.
# It must be a plain ``str``: a trailing comma after the last literal would turn the parenthesized
# expression into a 1-tuple and the warning would display the tuple repr instead of the text.
DEPRECATION_WARNING = (
    "This metric will be removed from the library soon, metrics should be handled with the 🤗 Datasets "
    "library. You can have a look at this example script for pointers: "
    "https://github.com/huggingface/transformers/blob/master/examples/text-classification/run_glue.py"
)
def simple_accuracy(preds, labels):
    """Return the mean of the elementwise comparison ``preds == labels``.

    A ``FutureWarning`` carrying ``DEPRECATION_WARNING`` is emitted first, then ``requires_sklearn`` is called
    with this function. The comparison result is assumed to expose ``.mean()`` (e.g. NumPy arrays) --
    TODO confirm with callers.
    """
    warnings.warn(DEPRECATION_WARNING, FutureWarning)
    requires_sklearn(simple_accuracy)
    matches = preds == labels
    return matches.mean()
def acc_and_f1(preds, labels):
warnings.warn(DEPRECATION_WARNING, FutureWarning)
requires_sklearn(acc_and_f1)
acc = simple_accuracy(preds, labels)
f1 = f1_score(y_true=labels, y_pred=preds)
@ -40,6 +51,7 @@ def acc_and_f1(preds, labels):
def pearson_and_spearman(preds, labels):
warnings.warn(DEPRECATION_WARNING, FutureWarning)
requires_sklearn(pearson_and_spearman)
pearson_corr = pearsonr(preds, labels)[0]
spearman_corr = spearmanr(preds, labels)[0]
@ -51,6 +63,7 @@ def pearson_and_spearman(preds, labels):
def glue_compute_metrics(task_name, preds, labels):
warnings.warn(DEPRECATION_WARNING, FutureWarning)
requires_sklearn(glue_compute_metrics)
assert len(preds) == len(labels), f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}"
if task_name == "cola":
@ -80,6 +93,7 @@ def glue_compute_metrics(task_name, preds, labels):
def xnli_compute_metrics(task_name, preds, labels):
warnings.warn(DEPRECATION_WARNING, FutureWarning)
requires_sklearn(xnli_compute_metrics)
assert len(preds) == len(labels), f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}"
if task_name == "xnli":

View File

@ -16,6 +16,7 @@
""" GLUE processors and helpers """
import os
import warnings
from dataclasses import asdict
from enum import Enum
from typing import List, Optional, Union
@ -31,6 +32,12 @@ if is_tf_available():
logger = logging.get_logger(__name__)
# Deprecation message template shared by the GLUE helpers in this module; "{0}" is replaced through
# ``.format(...)`` with the kind of object being deprecated ("function" or "processor").
DEPRECATION_WARNING = (
    "This {0} will be removed from the library soon, "
    + "preprocessing should be handled with the 🤗 Datasets library. "
    + "You can have a look at this example script for pointers: "
    + "https://github.com/huggingface/transformers/blob/master/examples/text-classification/run_glue.py"
)
def glue_convert_examples_to_features(
examples: Union[List[InputExample], "tf.data.Dataset"],
@ -57,6 +64,7 @@ def glue_convert_examples_to_features(
``InputFeatures`` which can be fed to the model.
"""
warnings.warn(DEPRECATION_WARNING.format("function"), FutureWarning)
if is_tf_available() and isinstance(examples, tf.data.Dataset):
if task is None:
raise ValueError("When calling glue_convert_examples_to_features from TF, the task parameter is required.")
@ -162,6 +170,10 @@ class OutputMode(Enum):
class MrpcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
def __init__(self, *args, **kwargs):
    """Forward all arguments to the parent initializer, then emit the processor deprecation ``FutureWarning``."""
    super().__init__(*args, **kwargs)
    deprecation_message = DEPRECATION_WARNING.format("processor")
    warnings.warn(deprecation_message, FutureWarning)
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(
@ -205,6 +217,10 @@ class MrpcProcessor(DataProcessor):
class MnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version)."""
def __init__(self, *args, **kwargs):
    """Forward all arguments to the parent initializer, then emit the processor deprecation ``FutureWarning``."""
    super().__init__(*args, **kwargs)
    deprecation_message = DEPRECATION_WARNING.format("processor")
    warnings.warn(deprecation_message, FutureWarning)
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(
@ -247,6 +263,10 @@ class MnliProcessor(DataProcessor):
class MnliMismatchedProcessor(MnliProcessor):
"""Processor for the MultiNLI Mismatched data set (GLUE version)."""
def __init__(self, *args, **kwargs):
    """Initialize through ``MnliProcessor``, which is responsible for the deprecation warning.

    All positional and keyword arguments are forwarded unchanged to the parent initializer.
    """
    # MnliProcessor.__init__ already emits the processor deprecation FutureWarning; calling
    # warnings.warn again here would show the same message twice for a single instantiation.
    super().__init__(*args, **kwargs)
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched")
@ -259,6 +279,10 @@ class MnliMismatchedProcessor(MnliProcessor):
class ColaProcessor(DataProcessor):
"""Processor for the CoLA data set (GLUE version)."""
def __init__(self, *args, **kwargs):
    """Forward all arguments to the parent initializer, then emit the processor deprecation ``FutureWarning``."""
    super().__init__(*args, **kwargs)
    deprecation_message = DEPRECATION_WARNING.format("processor")
    warnings.warn(deprecation_message, FutureWarning)
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(
@ -302,6 +326,10 @@ class ColaProcessor(DataProcessor):
class Sst2Processor(DataProcessor):
"""Processor for the SST-2 data set (GLUE version)."""
def __init__(self, *args, **kwargs):
    """Forward all arguments to the parent initializer, then emit the processor deprecation ``FutureWarning``."""
    super().__init__(*args, **kwargs)
    deprecation_message = DEPRECATION_WARNING.format("processor")
    warnings.warn(deprecation_message, FutureWarning)
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(
@ -344,6 +372,10 @@ class Sst2Processor(DataProcessor):
class StsbProcessor(DataProcessor):
"""Processor for the STS-B data set (GLUE version)."""
def __init__(self, *args, **kwargs):
    """Forward all arguments to the parent initializer, then emit the processor deprecation ``FutureWarning``."""
    super().__init__(*args, **kwargs)
    deprecation_message = DEPRECATION_WARNING.format("processor")
    warnings.warn(deprecation_message, FutureWarning)
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(
@ -386,6 +418,10 @@ class StsbProcessor(DataProcessor):
class QqpProcessor(DataProcessor):
"""Processor for the QQP data set (GLUE version)."""
def __init__(self, *args, **kwargs):
    """Forward all arguments to the parent initializer, then emit the processor deprecation ``FutureWarning``."""
    super().__init__(*args, **kwargs)
    deprecation_message = DEPRECATION_WARNING.format("processor")
    warnings.warn(deprecation_message, FutureWarning)
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(
@ -434,6 +470,10 @@ class QqpProcessor(DataProcessor):
class QnliProcessor(DataProcessor):
"""Processor for the QNLI data set (GLUE version)."""
def __init__(self, *args, **kwargs):
    """Forward all arguments to the parent initializer, then emit the processor deprecation ``FutureWarning``."""
    super().__init__(*args, **kwargs)
    deprecation_message = DEPRECATION_WARNING.format("processor")
    warnings.warn(deprecation_message, FutureWarning)
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(
@ -476,6 +516,10 @@ class QnliProcessor(DataProcessor):
class RteProcessor(DataProcessor):
"""Processor for the RTE data set (GLUE version)."""
def __init__(self, *args, **kwargs):
    """Forward all arguments to the parent initializer, then emit the processor deprecation ``FutureWarning``."""
    super().__init__(*args, **kwargs)
    deprecation_message = DEPRECATION_WARNING.format("processor")
    warnings.warn(deprecation_message, FutureWarning)
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(
@ -518,6 +562,10 @@ class RteProcessor(DataProcessor):
class WnliProcessor(DataProcessor):
"""Processor for the WNLI data set (GLUE version)."""
def __init__(self, *args, **kwargs):
    """Forward all arguments to the parent initializer, then emit the processor deprecation ``FutureWarning``."""
    super().__init__(*args, **kwargs)
    deprecation_message = DEPRECATION_WARNING.format("processor")
    warnings.warn(deprecation_message, FutureWarning)
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(