Add hyperparameter search to Trainer (#6576)
* Add optuna hyperparameter search to Trainer
* @julien-c suggestions
* Make compute_objective an arg function
* Formatting
* Rework to make it easier to add ray
* Formatting
* Initial support for Ray
* Formatting
* Polish and finalize
* Add trial id to checkpoint with Ray
* Smaller default
* Use GPU in ray if available
* Formatting
* Fix test
* Update install instruction
* Address review comments
* Formatting post-merge

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
parent dd522da004
commit 3a7fdd3f52
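For context, the feature added here is meant to be used roughly as follows. This is an illustrative sketch, not part of the commit; the model name is arbitrary and `train_dataset`/`eval_dataset` stand in for whatever datasets the user already has:

    from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

    def model_init():
        # Each trial starts from a freshly instantiated model.
        return AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

    trainer = Trainer(
        model_init=model_init,
        args=TrainingArguments(output_dir="hp_search", evaluate_during_training=True),
        train_dataset=train_dataset,  # placeholder: any dataset Trainer accepts
        eval_dataset=eval_dataset,    # placeholder
    )

    # Requires `pip install optuna` (or `pip install 'ray[tune]'` for the Ray backend).
    best_run = trainer.hyperparameter_search(n_trials=10, direction="minimize")
    print(best_run.run_id, best_run.objective, best_run.hyperparameters)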
src/transformers/__init__.py
@@ -92,7 +92,13 @@ from .file_utils import (
 from .hf_argparser import HfArgumentParser

 # Integrations
-from .integrations import is_comet_available, is_tensorboard_available, is_wandb_available
+from .integrations import (
+    is_comet_available,
+    is_optuna_available,
+    is_ray_available,
+    is_tensorboard_available,
+    is_wandb_available,
+)

 # Model Cards
 from .modelcard import ModelCard
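Because the new helpers are pulled into the package root above, they can also be checked from user code; a small sketch, assuming this revision is installed:

    from transformers import is_optuna_available, is_ray_available

    print(is_optuna_available(), is_ray_available())  # e.g. True False if only optuna is installed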
src/transformers/integrations.py
@@ -35,6 +35,20 @@ except ImportError:
 except ImportError:
     _has_tensorboard = False

+try:
+    import optuna  # noqa: F401
+
+    _has_optuna = True
+except (ImportError):
+    _has_optuna = False
+
+try:
+    import ray  # noqa: F401
+
+    _has_ray = True
+except (ImportError):
+    _has_ray = False
+

 def is_wandb_available():
     return _has_wandb
@@ -46,3 +60,18 @@ def is_comet_available():

 def is_tensorboard_available():
     return _has_tensorboard
+
+
+def is_optuna_available():
+    return _has_optuna
+
+
+def is_ray_available():
+    return _has_ray
+
+
+def default_hp_search_backend():
+    if is_optuna_available():
+        return "optuna"
+    elif is_ray_available():
+        return "ray"
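A short sketch of how the backend detection is intended to be consumed (optuna wins when both libraries are installed, since it is checked first):

    from transformers.integrations import default_hp_search_backend

    backend = default_hp_search_backend()
    if backend is None:
        # Neither library is present; hyperparameter_search() will raise with install instructions.
        print("Install optuna (`pip install optuna`) or Ray Tune (`pip install 'ray[tune]'`).")
    else:
        print(f"Hyperparameter search will default to the {backend} backend.")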
src/transformers/trainer.py
@@ -21,10 +21,27 @@ from tqdm.auto import tqdm, trange

 from .data.data_collator import DataCollator, default_data_collator
 from .file_utils import is_nlp_available, is_torch_tpu_available
-from .integrations import is_comet_available, is_tensorboard_available, is_wandb_available
+from .integrations import (
+    default_hp_search_backend,
+    is_comet_available,
+    is_optuna_available,
+    is_ray_available,
+    is_tensorboard_available,
+    is_wandb_available,
+)
 from .modeling_utils import PreTrainedModel
 from .optimization import AdamW, get_linear_schedule_with_warmup
-from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput, set_seed
+from .trainer_utils import (
+    PREFIX_CHECKPOINT_DIR,
+    BestRun,
+    EvalPrediction,
+    HPSearchBackend,
+    PredictionOutput,
+    TrainOutput,
+    default_compute_objective,
+    default_hp_space,
+    set_seed,
+)
 from .training_args import TrainingArguments
@@ -62,6 +79,12 @@ if is_wandb_available():
 if is_comet_available():
     import comet_ml

+if is_optuna_available():
+    import optuna
+
+if is_ray_available():
+    from ray import tune
+
 logger = logging.getLogger(__name__)
@@ -140,10 +163,11 @@ class Trainer:
     optimized for 🤗 Transformers.

     Args:
-        model (:class:`~transformers.PreTrainedModel`):
-            The model to train, evaluate or use for predictions.
-        args (:class:`~transformers.TrainingArguments`):
-            The arguments to tweak for training.
+        model (:class:`~transformers.PreTrainedModel`, `optional`):
+            The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed.
+        args (:class:`~transformers.TrainingArguments`, `optional`):
+            The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments`
+            with the ``output_dir`` set to a directory named `tmp_trainer` in the current directory if not provided.
         data_collator (:obj:`DataCollator`, `optional`, defaults to :func:`~transformers.default_data_collator`):
             The function to use to form a batch from a list of elements of :obj:`train_dataset` or
             :obj:`eval_dataset`.
@@ -151,8 +175,11 @@ class Trainer:
             The dataset to use for training. If it is an :obj:`nlp.Dataset`, columns not accepted by the
             ``model.forward()`` method are automatically removed.
         eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
             The dataset to use for evaluation. If it is an :obj:`nlp.Dataset`, columns not accepted by the
             ``model.forward()`` method are automatically removed.
+        model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
+            A function that instantiates the model to be used. If provided, each call to
+            :meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
         compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
             The function that will be used to compute metrics at evaluation. Must take a
             :class:`~transformers.EvalPrediction` and return a dictionary mapping strings to metric values.
@@ -168,21 +195,31 @@ class Trainer:

     def __init__(
         self,
-        model: PreTrainedModel,
-        args: TrainingArguments,
+        model: PreTrainedModel = None,
+        args: TrainingArguments = None,
         data_collator: Optional[DataCollator] = None,
         train_dataset: Optional[Dataset] = None,
         eval_dataset: Optional[Dataset] = None,
+        model_init: Callable[[], PreTrainedModel] = None,
         compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
         tb_writer: Optional["SummaryWriter"] = None,
         optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
         **kwargs,
     ):
-        self.model = model.to(args.device)
+        assert (
+            model is not None or model_init is not None
+        ), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
+        if model is None and model_init is not None:
+            model = model_init()
+        self.model = model.to(args.device) if model is not None else None
+        if args is None:
+            logger.info("No `TrainingArguments` passed, using the current path as `output_dir`.")
+            args = TrainingArguments("tmp_trainer")
         self.args = args
         self.data_collator = data_collator if data_collator is not None else default_data_collator
         self.train_dataset = train_dataset
         self.eval_dataset = eval_dataset
+        self.model_init = model_init
         self.compute_metrics = compute_metrics
         self.optimizer, self.lr_scheduler = optimizers
         self.tb_writer = tb_writer
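A minimal sketch of the resulting construction rules: at least one of ``model`` or ``model_init`` must be given, while a missing ``args`` is replaced with ``TrainingArguments("tmp_trainer")``:

    from transformers import Trainer

    try:
        Trainer()  # neither `model` nor `model_init` provided
    except AssertionError as err:
        print(err)  # "You must provide a model to use `Trainer`, ..."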
@@ -242,6 +279,7 @@ class Trainer:
         self.epoch = None
         if self.args.fp16 and _use_native_amp:
             self.scaler = torch.cuda.amp.GradScaler()
+        self.hp_search_backend = None

     def _remove_unused_columns(self, dataset: "nlp.Dataset", description: Optional[str] = None):
         if not self.args.remove_unused_columns:
@@ -462,7 +500,38 @@ class Trainer:
         """
         return len(dataloader.dataset)

-    def train(self, model_path: Optional[str] = None):
+    def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
+        """ HP search setup code """
+        if self.hp_search_backend is None or trial is None:
+            return
+        params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial
+        for key, value in params.items():
+            if not hasattr(self.args, key):
+                raise AttributeError(
+                    f"Trying to set {key} in the hyperparameter search but there is no corresponding field in `TrainingArguments`."
+                )
+            old_attr = getattr(self.args, key, None)
+            # Casting value to the proper type
+            if old_attr is not None:
+                value = type(old_attr)(value)
+            setattr(self.args, key, value)
+        if self.hp_search_backend == HPSearchBackend.OPTUNA:
+            logger.info("Trial:", trial.params)
+
+    def _report_to_hp_search(
+        self, trial: Union["optuna.Trial", Dict[str, Any]], epoch: int, metrics: Dict[str, float]
+    ):
+        if self.hp_search_backend is None or trial is None:
+            return
+        self.objective = self.compute_objective(metrics)
+        if self.hp_search_backend == HPSearchBackend.OPTUNA:
+            trial.report(self.objective, epoch)
+            if trial.should_prune():
+                raise optuna.TrialPruned()
+        elif self.hp_search_backend == HPSearchBackend.RAY:
+            tune.report(objective=self.objective, **metrics)
+
+    def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
         """
         Main training entry point.

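To make the setup step concrete, here is a hedged, standalone sketch of the same casting logic applied to a hypothetical Ray trial (with Ray, ``trial`` is already a plain dict of sampled values; the numbers are made up):

    from transformers import TrainingArguments

    args = TrainingArguments("tmp_trainer")
    trial = {"learning_rate": 3.2e-5, "num_train_epochs": 4, "seed": 17.8, "per_device_train_batch_size": 16}

    for key, value in trial.items():
        # Same cast as _hp_search_setup: the float sampled by `tune.uniform(1, 40)` for "seed" becomes an int again.
        old_attr = getattr(args, key, None)
        if old_attr is not None:
            value = type(old_attr)(value)
        setattr(args, key, value)

    print(args.seed, args.per_device_train_batch_size)  # 17 16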
@@ -470,7 +539,17 @@ class Trainer:
             model_path (:obj:`str`, `optional`):
                 Local path to the model if the model to train has been instantiated from a local path. If present,
                 training will resume from the optimizer/scheduler states loaded here.
+            trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
+                The trial run or the hyperparameter dictionary for hyperparameter search.
         """
+        # Model re-init
+        if self.model_init is not None:
+            model = self.model_init()
+            self.model = model.to(self.args.device)
+
+        self._hp_search_setup(trial)
+
         # Data loader and number of training steps
         train_dataloader = self.get_train_dataloader()
         if self.args.max_steps > 0:
             t_total = self.args.max_steps
@@ -561,9 +640,8 @@ class Trainer:
         tr_loss = 0.0
         logging_loss = 0.0
         model.zero_grad()
-        train_iterator = trange(
-            epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=not self.is_local_process_zero()
-        )
+        disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
+        train_iterator = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
         for epoch in train_iterator:
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)
@@ -572,9 +650,9 @@ class Trainer:
                 parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                     self.args.device
                 )
-                epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_process_zero())
+                epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=disable_tqdm)
             else:
-                epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_process_zero())
+                epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=disable_tqdm)

             # Reset the past mems state at the beginning of each epoch if necessary.
             if self.args.past_index >= 0:
@@ -631,7 +709,8 @@ class Trainer:
                        self.log(logs)

                    if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
-                        self.evaluate()
+                        metrics = self.evaluate()
+                        self._report_to_hp_search(trial, epoch, metrics)

                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a reference
@@ -643,7 +722,15 @@ class Trainer:
                        else:
                            assert model is self.model, f"Model {model} should be a reference to self.model"
                        # Save model checkpoint
-                        output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")
+                        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
+                        if self.hp_search_backend is not None and trial is not None:
+                            run_id = (
+                                trial.number
+                                if self.hp_search_backend == HPSearchBackend.OPTUNA
+                                else tune.get_trial_id()
+                            )
+                            checkpoint_folder += f"-run-{run_id}"
+                        output_dir = os.path.join(self.args.output_dir, checkpoint_folder)

                        self.save_model(output_dir)

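As an illustration of the checkpoint naming above (step and run id are made up), each trial now writes to its own folder instead of overwriting ``checkpoint-<step>``:

    import os

    PREFIX_CHECKPOINT_DIR = "checkpoint"
    global_step, run_id = 500, 3  # e.g. an optuna trial number or a Ray trial id

    checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{global_step}" + f"-run-{run_id}"
    print(os.path.join("output_dir", checkpoint_folder))  # output_dir/checkpoint-500-run-3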
@@ -683,6 +770,108 @@ class Trainer:
         logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
         return TrainOutput(self.global_step, tr_loss / self.global_step)

+    def hyperparameter_search(
+        self,
+        hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
+        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
+        n_trials: int = 20,
+        timeout: int = 1800,
+        n_jobs: int = 1,
+        direction: str = "minimize",
+        backend: Optional[Union["str", HPSearchBackend]] = None,
+        **kwargs
+    ) -> BestRun:
+        """
+        Launch a hyperparameter search using ``optuna`` or ``Ray Tune``. The optimized quantity is determined by
+        ``compute_objective``, which defaults to the evaluation loss when no other metric is provided and to the sum
+        of all metrics otherwise (you can change that behavior by passing your own ``compute_objective``).
+
+        Args:
+            hp_space (:obj:`Callable[["optuna.Trial"], Dict[str, float]]`, `optional`):
+                A function that defines the hyperparameter search space. Will default to
+                :func:`~transformers.trainer_utils.default_hp_space_optuna` or
+                :func:`~transformers.trainer_utils.default_hp_space_ray` depending on your backend.
+            compute_objective (:obj:`Callable[[Dict[str, float]], float]`, `optional`):
+                A function computing the objective to minimize or maximize from the metrics returned by the
+                :obj:`evaluate` method. Will default to :func:`~transformers.trainer_utils.default_compute_objective`.
+            n_trials (:obj:`int`, `optional`, defaults to 20):
+                The number of trial runs to test.
+            direction (:obj:`str`, `optional`, defaults to :obj:`"minimize"`):
+                Whether to optimize for a greater or lower objective. Can be :obj:`"minimize"` or :obj:`"maximize"`;
+                you should pick :obj:`"minimize"` when optimizing the validation loss and :obj:`"maximize"` when
+                optimizing one or several metrics.
+            backend (:obj:`str` or :class:`~transformers.training_utils.HPSearchBackend`, `optional`):
+                The backend to use for hyperparameter search. Will default to optuna or Ray Tune, depending on which
+                one is installed. If both are installed, will default to optuna.
+            kwargs:
+                Additional keyword arguments passed along to :obj:`optuna.create_study` or :obj:`ray.tune.run`. For
+                more information see:
+
+                - the documentation of `optuna.create_study <https://optuna.readthedocs.io/en/stable/reference/alias_generated/optuna.create_study.html#optuna.create_study>`__
+                - the documentation of `tune.run <https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run>`__
+
+        Returns:
+            :class:`transformers.trainer_utils.BestRun`: All the information about the best run.
+        """
+        if backend is None:
+            backend = default_hp_search_backend()
+            if backend is None:
+                raise RuntimeError(
+                    "At least one of optuna or ray should be installed. "
+                    "To install optuna run `pip install optuna`."
+                    "To install ray run `pip install ray[tune]`."
+                )
+        backend = HPSearchBackend(backend)
+        if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
+            raise RuntimeError(" You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
+        if backend == HPSearchBackend.RAY and not is_ray_available():
+            raise RuntimeError(
+                " You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
+            )
+        self.hp_search_backend = backend
+
+        self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
+        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
+
+        def _objective(trial):
+            # To make sure optimizer and lr_scheduler are reset with the new choices of HPs
+            self.optimizer = None
+            self.lr_scheduler = None
+            self.objective = None
+            self.train(trial=trial)
+            # If there hasn't been any evaluation during the training loop.
+            if getattr(self, "objective", None) is None:
+                metrics = self.evaluate()
+                self.objective = self.compute_objective(metrics)
+            return self.objective
+
+        if self.hp_search_backend == HPSearchBackend.OPTUNA:
+            timeout = kwargs.pop("timeout", None)
+            n_jobs = kwargs.pop("n_jobs", 1)
+            study = optuna.create_study(direction=direction, **kwargs)
+            study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
+            best_trial = study.best_trial
+            best_run = BestRun(str(best_trial.number), best_trial.value, best_trial.params)
+        elif self.hp_search_backend == HPSearchBackend.RAY:
+            # The TensorBoard writer does not pickle so we have to remove it (if it exists) while doing the ray hp
+            # search.
+            _tb_writer = self.tb_writer
+            self.tb_writer = None
+            # Setup default `resources_per_trial` and `reporter`.
+            if "resources_per_trial" not in kwargs and self.args.n_gpu > 0:
+                kwargs["resources_per_trial"] = {"gpu": self.args.n_gpu}
+            if "reporter" not in kwargs:
+                from ray.tune import CLIReporter
+
+                kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
+            analysis = tune.run(_objective, config=self.hp_space(None), num_samples=n_trials, **kwargs)
+            best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
+            best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
+            self.tb_writer = _tb_writer
+
+        self.hp_search_backend = None
+        return best_run
+
     def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
         """
         Log :obj:`logs` on the various objects watching training.
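Building on the earlier sketch, a customized search might look like the following; the metric name ``eval_acc`` and the search ranges are illustrative, and ``trainer`` is the instance built with ``model_init`` above:

    def my_hp_space(trial):
        # Same contract as default_hp_space_optuna: return TrainingArguments fields to override.
        return {
            "learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-5, log=True),
            "num_train_epochs": trial.suggest_int("num_train_epochs", 2, 4),
        }

    def my_objective(metrics):
        # Maximize a single metric instead of the default loss / sum-of-metrics behavior.
        return metrics["eval_acc"]

    best_run = trainer.hyperparameter_search(
        hp_space=my_hp_space,
        compute_objective=my_objective,
        direction="maximize",
        n_trials=20,
    )
    # With the Ray backend, extra kwargs are forwarded to `ray.tune.run`, e.g.:
    # trainer.hyperparameter_search(backend="ray", n_trials=8, resources_per_trial={"gpu": 1})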
@@ -1020,8 +1209,9 @@ class Trainer:
         if self.args.past_index >= 0:
             self._past = None

+        disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
         samples_count = 0
-        for inputs in tqdm(dataloader, desc=description):
+        for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
             loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
             batch_size = inputs[list(inputs.keys())[0]].shape[0]
             samples_count += batch_size
src/transformers/trainer_utils.py
@@ -1,9 +1,11 @@
 import random
-from typing import Dict, NamedTuple, Optional
+from typing import Any, Dict, NamedTuple, Optional

 import numpy as np

 from .file_utils import is_tf_available, is_torch_available
+from .integrations import is_ray_available
 from .tokenization_utils_base import ExplicitEnum


 def set_seed(seed: int):
@@ -53,3 +55,70 @@ class TrainOutput(NamedTuple):


 PREFIX_CHECKPOINT_DIR = "checkpoint"
+
+
+class BestRun(NamedTuple):
+    """
+    The best run found by a hyperparameter search (see :class:`~transformers.Trainer.hyperparameter_search`).
+
+    Parameters:
+        run_id (:obj:`str`):
+            The id of the best run (if models were saved, the corresponding checkpoint will be in the folder ending
+            with run-{run_id}).
+        objective (:obj:`float`):
+            The objective that was obtained for this run.
+        hyperparameters (:obj:`Dict[str, Any]`):
+            The hyperparameters picked to get this run.
+    """
+
+    run_id: str
+    objective: float
+    hyperparameters: Dict[str, Any]
+
+
+def default_compute_objective(metrics: Dict[str, float]) -> float:
+    """
+    The default objective to maximize/minimize when doing a hyperparameter search. It is the evaluation loss if no
+    metrics are provided to the :class:`~transformers.Trainer`, the sum of all metrics otherwise.
+
+    Args:
+        metrics (:obj:`Dict[str, float]`): The metrics returned by the evaluate method.
+
+    Return:
+        :obj:`float`: The objective to minimize or maximize
+    """
+    loss = metrics.pop("eval_loss", None)
+    _ = metrics.pop("epoch", None)
+    return loss if len(metrics) == 0 else sum(metrics.values())
+
+
+def default_hp_space_optuna(trial) -> Dict[str, float]:
+    return {
+        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
+        "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5),
+        "seed": trial.suggest_int("seed", 1, 40),
+        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [4, 8, 16, 32, 64]),
+    }
+
+
+def default_hp_space_ray(trial) -> Dict[str, float]:
+    assert is_ray_available(), "This function needs ray installed: `pip install ray[tune]`"
+    from ray import tune
+
+    return {
+        "learning_rate": tune.loguniform(1e-6, 1e-4),
+        "num_train_epochs": tune.choice(range(1, 6)),
+        "seed": tune.uniform(1, 40),
+        "per_device_train_batch_size": tune.choice([4, 8, 16, 32, 64]),
+    }
+
+
+class HPSearchBackend(ExplicitEnum):
+    OPTUNA = "optuna"
+    RAY = "ray"
+
+
+default_hp_space = {
+    HPSearchBackend.OPTUNA: default_hp_space_optuna,
+    HPSearchBackend.RAY: default_hp_space_ray,
+}
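A quick illustration of how the default objective behaves (note that it pops keys from the dict it receives):

    from transformers.trainer_utils import default_compute_objective

    # Only the evaluation loss is present: the objective is that loss.
    print(default_compute_objective({"eval_loss": 0.35, "epoch": 2.0}))  # 0.35

    # Other metrics are present: the objective is their sum and eval_loss is dropped.
    print(default_compute_objective({"eval_loss": 0.35, "eval_acc": 0.91, "eval_f1": 0.88}))  # ~1.79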
src/transformers/training_args.py
@@ -114,6 +114,9 @@ class TrainingArguments:
             at the next training step under the keyword argument ``mems``.
         run_name (:obj:`str`, `optional`):
             A descriptor for the run. Notably used for wandb logging.
+        disable_tqdm (:obj:`bool`, `optional`):
+            Whether or not to disable the tqdm progress bars. Will default to :obj:`True` if the logging level is set
+            to warn or lower (default), :obj:`False` otherwise.
         remove_unused_columns (:obj:`bool`, `optional`, defaults to :obj:`True`):
             If using `nlp.Dataset` datasets, whether or not to automatically remove the columns unused by the model
             forward method.
@@ -238,6 +241,13 @@ class TrainingArguments:
     run_name: Optional[str] = field(
         default=None, metadata={"help": "An optional descriptor for the run. Notably used for wandb logging."}
     )
+    disable_tqdm: Optional[bool] = field(
+        default=None, metadata={"help": "Whether or not to disable the tqdm progress bars."}
+    )
+
+    def __post_init__(self):
+        if self.disable_tqdm is None:
+            self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN

     remove_unused_columns: Optional[bool] = field(
         default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."}