More tests to Trainer (#6699)

* More tests to Trainer

* Add warning in the doc
This commit is contained in:
Sylvain Gugger 2020-08-25 07:07:36 -04:00 committed by GitHub
parent f5bad031bc
commit abc0202194
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 106 additions and 15 deletions

View File

@ -77,6 +77,7 @@ jobs:
- v0.3-torch_and_tf-{{ checksum "setup.py" }}
- v0.3-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: pip install git+https://github.com/huggingface/nlp
- run: pip install .[sklearn,tf-cpu,torch,testing]
- run: pip install codecov pytest-cov
- save_cache:
@ -103,6 +104,7 @@ jobs:
- v0.3-torch-{{ checksum "setup.py" }}
- v0.3-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: pip install git+https://github.com/huggingface/nlp
- run: pip install .[sklearn,torch,testing]
- save_cache:
key: v0.3-torch-{{ checksum "setup.py" }}
@ -127,6 +129,7 @@ jobs:
- v0.3-tf-{{ checksum "setup.py" }}
- v0.3-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: pip install git+https://github.com/huggingface/nlp
- run: pip install .[sklearn,tf-cpu,testing]
- save_cache:
key: v0.3-tf-{{ checksum "setup.py" }}

View File

@ -206,22 +206,29 @@ class Trainer:
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
**kwargs,
):
if args is None:
logger.info("No `TrainingArguments` passed, using the current path as `output_dir`.")
args = TrainingArguments("tmp_trainer")
self.args = args
# Seed must be set before instantiating the model when using model
set_seed(self.args.seed)
assert (
model is not None or model_init is not None
), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
if model is None and model_init is not None:
model = model_init()
self.model = model.to(args.device) if model is not None else None
if args is None:
logger.info("No `TrainingArguments` passed, using the current path as `output_dir`.")
args = TrainingArguments("tmp_trainer")
self.args = args
self.data_collator = data_collator if data_collator is not None else default_data_collator
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.model_init = model_init
self.compute_metrics = compute_metrics
self.optimizer, self.lr_scheduler = optimizers
if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
raise RuntimeError(
"Passing a `model_init` is incompatible with providing the `optimizers` argument."
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
)
self.tb_writer = tb_writer
if "prediction_loss_only" in kwargs:
warnings.warn(
@ -251,7 +258,6 @@ class Trainer:
"To use comet_ml logging, run `pip/conda install comet_ml` "
"see https://www.comet.ml/docs/python-sdk/huggingface/"
)
set_seed(self.args.seed)
# Create output directory if needed
if self.is_world_process_zero():
os.makedirs(self.args.output_dir, exist_ok=True)
@ -542,12 +548,18 @@ class Trainer:
trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
The trial run or the hyperparameter dictionary for hyperparameter search.
"""
# This might change the seed so needs to run first.
self._hp_search_setup(trial)
# Model re-init
if self.model_init is not None:
# Seed must be set before instantiating the model when using model_init.
set_seed(self.args.seed)
model = self.model_init()
self.model = model.to(self.args.device)
self._hp_search_setup(trial)
# Reinitializes optimizer and scheduler
self.optimizer, self.lr_scheduler = None, None
# Data loader and number of training steps
train_dataloader = self.get_train_dataloader()
@ -788,6 +800,13 @@ class Trainer:
:obj:`compute_objectie`, which defaults to a function returning the evaluation loss when no metric is provided,
the sum of all metrics otherwise.
.. warning::
To use this method, you need to have provided a ``model_init`` when initializing your
:class:`~transformers.Trainer`: we need to reinitialize the model at each new run. This is incompatible
with the ``optimizers`` argument, so you need to subclass :class:`~transformers.Trainer` and override the
method :meth:`~transformers.Trainer.create_optimizer_and_scheduler` for custom optimizer/scheduler.
Args:
hp_space (:obj:`Callable[["optuna.Trial"], Dict[str, float]]`, `optional`):
A function that defines the hyperparameter search space. Will default to
@ -825,20 +844,22 @@ class Trainer:
)
backend = HPSearchBackend(backend)
if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
raise RuntimeError(" You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
if backend == HPSearchBackend.RAY and not is_ray_available():
raise RuntimeError(
" You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
"You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
)
self.hp_search_backend = backend
if self.model_init is None:
raise RuntimeError(
"To use hyperparameter search, you need to pass your model through a model_init function."
)
self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
def _objective(trial):
# To make sure optimizer and lr_scheduler are reset with the new choices of HPs
self.optimizer = None
self.lr_scheduler = None
self.objective = None
self.train(trial=trial)
# If there hasn't been any evaluation during the training loop.

View File

@ -1,5 +1,6 @@
import unittest
import nlp
import numpy as np
from transformers import AutoTokenizer, TrainingArguments, is_torch_available
@ -93,6 +94,17 @@ if is_torch_available():
@require_torch
class TrainerIntegrationTest(unittest.TestCase):
def check_trained_model(self, model, alternate_seed=False):
# Checks a training seeded with learning_rate = 0.1
if alternate_seed:
# With args.seed = 314
self.assertTrue(torch.abs(model.a - 1.0171) < 1e-4)
self.assertTrue(torch.abs(model.b - 1.2494) < 1e-4)
else:
# With default args.seed
self.assertTrue(torch.abs(model.a - 0.6975) < 1e-4)
self.assertTrue(torch.abs(model.b - 1.2415) < 1e-4)
def setUp(self):
# Get the default values (in case they change):
args = TrainingArguments(".")
@ -103,14 +115,12 @@ class TrainerIntegrationTest(unittest.TestCase):
# Checks that training worked, model trained and seed made a reproducible training.
trainer = get_regression_trainer(learning_rate=0.1)
trainer.train()
self.assertTrue(torch.abs(trainer.model.a - 0.6975) < 1e-4)
self.assertTrue(torch.abs(trainer.model.b - 1.2415) < 1e-4)
self.check_trained_model(trainer.model)
# Checks that a different seed gets different (reproducible) results.
trainer = get_regression_trainer(learning_rate=0.1, seed=314)
trainer.train()
self.assertTrue(torch.abs(trainer.model.a - 1.0171) < 1e-4)
self.assertTrue(torch.abs(trainer.model.b - 1.2494) < 1e-4)
self.check_trained_model(trainer.model, alternate_seed=True)
def test_number_of_steps_in_training(self):
# Regular training has n_epochs * len(train_dl) steps
@ -190,6 +200,63 @@ class TrainerIntegrationTest(unittest.TestCase):
x = trainer.eval_dataset.x
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
def test_trainer_with_nlp(self):
np.random.seed(42)
x = np.random.normal(size=(64,)).astype(np.float32)
y = 2.0 * x + 3.0 + np.random.normal(scale=0.1, size=(64,))
train_dataset = nlp.Dataset.from_dict({"input_x": x, "label": y})
# Base training. Should have the same results as test_reproducible_training
model = RegressionModel()
args = TrainingArguments("./regression", learning_rate=0.1)
trainer = Trainer(model, args, train_dataset=train_dataset)
trainer.train()
self.check_trained_model(trainer.model)
# Can return tensors.
train_dataset.set_format(type="torch")
model = RegressionModel()
trainer = Trainer(model, args, train_dataset=train_dataset)
trainer.train()
self.check_trained_model(trainer.model)
# Adding one column not used by the model should have no impact
z = np.random.normal(size=(64,)).astype(np.float32)
train_dataset = nlp.Dataset.from_dict({"input_x": x, "label": y, "extra": z})
model = RegressionModel()
trainer = Trainer(model, args, train_dataset=train_dataset)
trainer.train()
self.check_trained_model(trainer.model)
def test_custom_optimizer(self):
train_dataset = RegressionDataset()
args = TrainingArguments("./regression")
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: 1.0)
trainer = Trainer(model, args, train_dataset=train_dataset, optimizers=(optimizer, lr_scheduler))
trainer.train()
self.assertTrue(torch.abs(trainer.model.a - 1.8950) < 1e-4)
self.assertTrue(torch.abs(trainer.model.b - 2.5656) < 1e-4)
self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0)
def test_model_init(self):
train_dataset = RegressionDataset()
args = TrainingArguments("./regression", learning_rate=0.1)
trainer = Trainer(args=args, train_dataset=train_dataset, model_init=lambda: RegressionModel())
trainer.train()
self.check_trained_model(trainer.model)
# Re-training should restart from scratch, thus lead the same results.
trainer.train()
self.check_trained_model(trainer.model)
# Re-training should restart from scratch, thus lead the same results and new seed should be used.
trainer.args.seed = 314
trainer.train()
self.check_trained_model(trainer.model, alternate_seed=True)
def test_trainer_eval_mrpc(self):
MODEL_ID = "bert-base-cased-finetuned-mrpc"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)