updated example template (#12365)

Bhadresh Savani 2021-06-26 04:50:30 +01:00 committed by GitHub
parent 539ee456d4
commit 9a7545943d
1 changed file with 20 additions and 16 deletions


@@ -27,6 +27,7 @@ import sys
 from dataclasses import dataclass, field
 from typing import Optional

+import datasets
 from datasets import load_dataset

 import transformers
@@ -226,16 +227,19 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
+
+    log_level = training_args.get_process_log_level()
+    logger.setLevel(log_level)
+    datasets.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.enable_default_handler()
+    transformers.utils.logging.enable_explicit_format()

     # Log on each process the small summary:
     logger.warning(
         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
-    # Set the verbosity to info of the Transformers logger (on main process only):
-    if training_args.should_log:
-        transformers.utils.logging.set_verbosity_info()
     logger.info(f"Training/evaluation parameters {training_args}")

     # Set seed before initializing model.
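The replacement above routes everything through one per-process level: TrainingArguments.get_process_log_level() (available in transformers 4.8, the release this template update targets) returns INFO on the main process and WARNING on replicas unless overridden via --log_level / --log_level_replica. A minimal sketch of the pattern, runnable outside the template; the output_dir value is a placeholder:

    # Sketch of the per-process logging setup; assumes transformers>=4.8,
    # where TrainingArguments.get_process_log_level() is available.
    import logging
    import sys

    import datasets
    import transformers
    from transformers import TrainingArguments

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger = logging.getLogger(__name__)

    training_args = TrainingArguments(output_dir="out")  # placeholder output_dir

    log_level = training_args.get_process_log_level()  # INFO on main, WARNING on replicas
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)      # keep the datasets lib in sync
    transformers.utils.logging.set_verbosity(log_level)  # and the transformers lib
    logger.info("visible on the main process only by default")
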
@@ -252,7 +256,7 @@ def main():
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        raw_datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
     else:
         data_files = {}
         if data_args.train_file is not None:
@@ -266,7 +270,7 @@ def main():
             extension = data_args.test_file.split(".")[-1]
         if extension == "txt":
             extension = "text"
-        datasets = load_dataset(extension, data_files=data_files)
+        raw_datasets = load_dataset(extension, data_files=data_files)

     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
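The variable rename here is not cosmetic: this same commit adds `import datasets` at the top of the file for the `datasets.utils.logging` call, so keeping a local named `datasets` would shadow the module after the first assignment. A short sketch of the failure mode and both load paths; the dataset name and file paths are placeholders:

    import datasets
    from datasets import load_dataset

    # Hub path ("imdb" is a placeholder dataset name):
    raw_datasets = load_dataset("imdb")

    # Local-file path: the extension doubles as the loader name ("csv", "json", "text"):
    # raw_datasets = load_dataset("csv", data_files={"train": "train.csv"})

    # Had the result been assigned to `datasets`, the module would be shadowed
    # and the next line would raise AttributeError on a DatasetDict:
    datasets.utils.logging.set_verbosity_warning()
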
@@ -348,20 +352,20 @@ def main():
     # Preprocessing the datasets.
     # First we tokenize all the texts.
     if training_args.do_train:
-        column_names = datasets["train"].column_names
+        column_names = raw_datasets["train"].column_names
     elif training_args.do_eval:
-        column_names = datasets["validation"].column_names
+        column_names = raw_datasets["validation"].column_names
     elif training_args.do_predict:
-        column_names = datasets["test"].column_names
+        column_names = raw_datasets["test"].column_names
     text_column_name = "text" if "text" in column_names else column_names[0]

     def tokenize_function(examples):
         return tokenizer(examples[text_column_name], padding="max_length", truncation=True)

     if training_args.do_train:
-        if "train" not in datasets:
+        if "train" not in raw_datasets:
             raise ValueError("--do_train requires a train dataset")
-        train_dataset = datasets["train"]
+        train_dataset = raw_datasets["train"]
         if data_args.max_train_samples is not None:
             # Select Sample from Dataset
             train_dataset = train_dataset.select(range(data_args.max_train_samples))
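The `max_train_samples` guard plus `.select(range(n))` is the template's debugging shortcut: it truncates the split before the (elided) `.map(tokenize_function, batched=True)` call tokenizes it. A self-contained sketch of that truncate-then-tokenize flow, using a toy in-memory dataset and bert-base-uncased as stand-ins:

    from datasets import Dataset
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder checkpoint
    raw = Dataset.from_dict({"text": ["one", "two", "three", "four"]})

    max_train_samples = 2
    train_dataset = raw.select(range(max_train_samples))  # keep only the first n rows

    def tokenize_function(examples):
        # Same shape as the template: pad/truncate every example to max_length.
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    train_dataset = train_dataset.map(tokenize_function, batched=True)
    print(len(train_dataset))  # 2
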
@@ -375,9 +379,9 @@ def main():
         )

     if training_args.do_eval:
-        if "validation" not in datasets:
+        if "validation" not in raw_datasets:
             raise ValueError("--do_eval requires a validation dataset")
-        eval_dataset = datasets["validation"]
+        eval_dataset = raw_datasets["validation"]
         # Selecting samples from dataset
         if data_args.max_eval_samples is not None:
             eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
@@ -391,9 +395,9 @@ def main():
         )

     if training_args.do_predict:
-        if "test" not in datasets:
+        if "test" not in raw_datasets:
             raise ValueError("--do_predict requires a test dataset")
-        predict_dataset = datasets["test"]
+        predict_dataset = raw_datasets["test"]
         # Selecting samples from dataset
         if data_args.max_predict_samples is not None:
             predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
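The eval and predict branches repeat the train branch's guard-then-truncate shape verbatim. For readers who want the shared pattern in one place, a hypothetical helper (not part of the template) could look like this:

    # Hypothetical helper, not in the template; shows the shared guard shape.
    def require_split(raw_datasets, split, flag):
        if split not in raw_datasets:
            raise ValueError(f"{flag} requires a {split} dataset")
        return raw_datasets[split]

    # eval_dataset = require_split(raw_datasets, "validation", "--do_eval")
    # predict_dataset = require_split(raw_datasets, "test", "--do_predict")
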
@@ -754,7 +758,7 @@ def main():

     # Preprocessing the datasets.
     # First we tokenize all the texts.
-    column_names = datasets["train"].column_names
+    column_names = raw_datasets["train"].column_names
     text_column_name = "text" if "text" in column_names else column_names[0]

     padding = "max_length" if args.pad_to_max_length else False
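This last hunk sits in the no-Trainer branch of the template (note `args` from argparse rather than `data_args`). When `pad_to_max_length` is off, `padding=False` leaves examples unpadded at tokenization time and padding moves into the batch collator. A sketch of that downstream behavior, assuming DataCollatorWithPadding and PyTorch are available:

    # Dynamic padding sketch: pad per batch instead of to a global max_length.
    from transformers import AutoTokenizer, DataCollatorWithPadding

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder checkpoint
    collator = DataCollatorWithPadding(tokenizer)

    features = [tokenizer("short"), tokenizer("a somewhat longer example sentence")]
    batch = collator(features)
    print(batch["input_ids"].shape)  # padded only to the longest sequence in this batch
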