Add SigOpt HPO to the transformers Trainer API (#13572)

* Add SigOpt HPO to transformers.

Signed-off-by: Ding, Ke <ke.ding@intel.com>

* Extend the SigOpt changes to the test code and other affected files.

Signed-off-by: Ding, Ke <ke.ding@intel.com>

* Style.

* Fix style for the SigOpt integration.

Signed-off-by: Ding, Ke <ke.ding@intel.com>

* Add the necessary information to run unit tests on SigOpt.

Co-authored-by: Morgan Funtowicz <funtowiczmo@gmail.com>
Authored by kding1 on 2021-09-23 08:01:51 -07:00, committed by GitHub
parent 62832c962f
commit 6a3a197fcd
10 changed files with 168 additions and 15 deletions

View File

@@ -15,6 +15,7 @@ env:
OMP_NUM_THREADS: 16
MKL_NUM_THREADS: 16
PYTEST_TIMEOUT: 600
SIGOPT_API_TOKEN: ${{ secrets.SIGOPT_API_TOKEN }}
jobs:
run_all_tests_torch_gpu:

View File

@@ -15,6 +15,7 @@ env:
OMP_NUM_THREADS: 16
MKL_NUM_THREADS: 16
PYTEST_TIMEOUT: 600
SIGOPT_API_TOKEN: ${{ secrets.SIGOPT_API_TOKEN }}
jobs:
run_all_tests_torch_gpu:
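
Both self-hosted test workflows now export the SIGOPT_API_TOKEN repository secret so the SigOpt integration tests added below can authenticate. A minimal sketch of how that token is consumed, assuming (as the sigopt client documents) that Connection() falls back to the SIGOPT_API_TOKEN environment variable when no token is passed explicitly, which is why run_hp_search_sigopt further down constructs it with no arguments:

    # Sketch only: relies on the assumption that the SigOpt client reads
    # SIGOPT_API_TOKEN from the environment when no token is given.
    import os
    from sigopt import Connection

    assert os.environ.get("SIGOPT_API_TOKEN"), "set SIGOPT_API_TOKEN first (CI uses the repository secret)"
    conn = Connection()  # authenticates with the token picked up from the environment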

View File

@@ -135,6 +135,7 @@ _deps = [
"sagemaker>=2.31.0",
"scikit-learn",
"sentencepiece>=0.1.91,!=0.1.92",
"sigopt",
"soundfile",
"sphinx-copybutton",
"sphinx-markdown-tables",
@@ -248,8 +249,9 @@ extras["deepspeed"] = deps_list("deepspeed")
extras["fairscale"] = deps_list("fairscale")
extras["optuna"] = deps_list("optuna")
extras["ray"] = deps_list("ray[tune]")
extras["sigopt"] = deps_list("sigopt")
extras["integrations"] = extras["optuna"] + extras["ray"]
extras["integrations"] = extras["optuna"] + extras["ray"]+ extras["sigopt"]
extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
extras["audio"] = deps_list("soundfile")

View File

@@ -130,6 +130,7 @@ _import_structure = {
"is_optuna_available",
"is_ray_available",
"is_ray_tune_available",
"is_sigopt_available",
"is_tensorboard_available",
"is_wandb_available",
],
@@ -1951,6 +1952,7 @@ if TYPE_CHECKING:
is_optuna_available,
is_ray_available,
is_ray_tune_available,
is_sigopt_available,
is_tensorboard_available,
is_wandb_available,
)

View File

@@ -53,6 +53,7 @@ deps = {
"sagemaker": "sagemaker>=2.31.0",
"scikit-learn": "scikit-learn",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
"sigopt": "sigopt",
"soundfile": "soundfile",
"sphinx-copybutton": "sphinx-copybutton",
"sphinx-markdown-tables": "sphinx-markdown-tables",

View File

@@ -83,6 +83,10 @@ def is_ray_tune_available():
return importlib.util.find_spec("ray.tune") is not None
def is_sigopt_available():
return importlib.util.find_spec("sigopt") is not None
def is_azureml_available():
if importlib.util.find_spec("azureml") is None:
return False
@@ -117,6 +121,10 @@ def hp_params(trial):
if isinstance(trial, dict):
return trial
if is_sigopt_available():
if isinstance(trial, dict):
return trial
raise RuntimeError(f"Unknown type for trial {trial.__class__}")
@@ -125,6 +133,8 @@ def default_hp_search_backend():
return "optuna"
elif is_ray_tune_available():
return "ray"
elif is_sigopt_available():
return "sigopt"
def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
@@ -288,6 +298,45 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
return best_run
def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
from sigopt import Connection
conn = Connection()
proxies = kwargs.pop("proxies", None)
if proxies is not None:
conn.set_proxies(proxies)
experiment = conn.experiments().create(
name="huggingface-tune",
parameters=trainer.hp_space(None),
metrics=[dict(name="objective", objective=direction, strategy="optimize")],
parallel_bandwidth=1,
observation_budget=n_trials,
project="huggingface",
)
logger.info(f"created experiment: https://app.sigopt.com/experiment/{experiment.id}")
while experiment.progress.observation_count < experiment.observation_budget:
suggestion = conn.experiments(experiment.id).suggestions().create()
trainer.objective = None
trainer.train(resume_from_checkpoint=None, trial=suggestion)
# If there hasn't been any evaluation during the training loop.
if getattr(trainer, "objective", None) is None:
metrics = trainer.evaluate()
trainer.objective = trainer.compute_objective(metrics)
values = [dict(name="objective", value=trainer.objective)]
obs = conn.experiments(experiment.id).observations().create(suggestion=suggestion.id, values=values)
logger.info(f"[suggestion_id, observation_id]: [{suggestion.id}, {obs.id}]")
experiment = conn.experiments(experiment.id).fetch()
best = list(conn.experiments(experiment.id).best_assignments().fetch().iterate_pages())[0]
best_run = BestRun(best.id, best.value, best.assignments)
return best_run
def get_available_reporting_integrations():
integrations = []
if is_azureml_available():

View File

@@ -51,7 +51,7 @@ from .file_utils import (
is_torchaudio_available,
is_vision_available,
)
from .integrations import is_optuna_available, is_ray_available
from .integrations import is_optuna_available, is_ray_available, is_sigopt_available
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
@@ -511,6 +511,19 @@ def require_ray(test_case):
return test_case
def require_sigopt(test_case):
"""
Decorator marking a test that requires SigOpt.
These tests are skipped when SigOpt isn't installed.
"""
if not is_sigopt_available():
return unittest.skip("test requires SigOpt")(test_case)
else:
return test_case
def require_soundfile(test_case):
"""
Decorator marking a test that requires soundfile

View File

@@ -40,8 +40,10 @@ from .integrations import ( # isort: split
is_fairscale_available,
is_optuna_available,
is_ray_tune_available,
is_sigopt_available,
run_hp_search_optuna,
run_hp_search_ray,
run_hp_search_sigopt,
)
import numpy as np
@@ -231,9 +233,9 @@ class Trainer:
A function that instantiates the model to be used. If provided, each call to
:meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
The function may have zero argument, or a single one containing the optuna/Ray Tune trial object, to be
able to choose different architectures according to hyper parameters (such as layer count, sizes of inner
layers, dropout probabilities etc).
The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to
be able to choose different architectures according to hyper parameters (such as layer count, sizes of
inner layers, dropout probabilities etc).
compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
The function that will be used to compute metrics at evaluation. Must take a
:class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
@@ -869,6 +871,8 @@ class Trainer:
elif self.hp_search_backend == HPSearchBackend.RAY:
params = trial
params.pop("wandb", None)
elif self.hp_search_backend == HPSearchBackend.SIGOPT:
params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()}
for key, value in params.items():
if not hasattr(self.args, key):
@@ -883,6 +887,8 @@
setattr(self.args, key, value)
if self.hp_search_backend == HPSearchBackend.OPTUNA:
logger.info("Trial:", trial.params)
if self.hp_search_backend == HPSearchBackend.SIGOPT:
logger.info(f"SigOpt Assignments: {trial.assignments}")
if self.args.deepspeed:
# Rebuild the deepspeed config to reflect the updated training parameters
from transformers.deepspeed import HfDeepSpeedConfig
@@ -1232,7 +1238,7 @@
self.callback_handler.lr_scheduler = self.lr_scheduler
self.callback_handler.train_dataloader = train_dataloader
self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None
self.state.trial_params = hp_params(trial) if trial is not None else None
self.state.trial_params = hp_params(trial.assignments) if trial is not None else None
# This should be the same if the state has been saved but in case the training arguments changed, it's safer
# to set this after the load.
self.state.max_steps = max_steps
@@ -1524,10 +1530,12 @@
if self.hp_search_backend is not None and trial is not None:
if self.hp_search_backend == HPSearchBackend.OPTUNA:
run_id = trial.number
else:
elif self.hp_search_backend == HPSearchBackend.RAY:
from ray import tune
run_id = tune.get_trial_id()
elif self.hp_search_backend == HPSearchBackend.SIGOPT:
run_id = trial.id
run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
run_dir = os.path.join(self.args.output_dir, run_name)
else:
@@ -1671,9 +1679,9 @@
**kwargs,
) -> BestRun:
"""
Launch an hyperparameter search using ``optuna`` or ``Ray Tune``. The optimized quantity is determined by
:obj:`compute_objective`, which defaults to a function returning the evaluation loss when no metric is
provided, the sum of all metrics otherwise.
Launch an hyperparameter search using ``optuna`` or ``Ray Tune`` or ``SigOpt``. The optimized quantity is
determined by :obj:`compute_objective`, which defaults to a function returning the evaluation loss when no
metric is provided, the sum of all metrics otherwise.
.. warning::
@@ -1686,7 +1694,8 @@
hp_space (:obj:`Callable[["optuna.Trial"], Dict[str, float]]`, `optional`):
A function that defines the hyperparameter search space. Will default to
:func:`~transformers.trainer_utils.default_hp_space_optuna` or
:func:`~transformers.trainer_utils.default_hp_space_ray` depending on your backend.
:func:`~transformers.trainer_utils.default_hp_space_ray` or
:func:`~transformers.trainer_utils.default_hp_space_sigopt` depending on your backend.
compute_objective (:obj:`Callable[[Dict[str, float]], float]`, `optional`):
A function computing the objective to minimize or maximize from the metrics returned by the
:obj:`evaluate` method. Will default to :func:`~transformers.trainer_utils.default_compute_objective`.
@@ -1697,8 +1706,8 @@
pick :obj:`"minimize"` when optimizing the validation loss, :obj:`"maximize"` when optimizing one or
several metrics.
backend(:obj:`str` or :class:`~transformers.training_utils.HPSearchBackend`, `optional`):
The backend to use for hyperparameter search. Will default to optuna or Ray Tune, depending on which
one is installed. If both are installed, will default to optuna.
The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending
on which one is installed. If all are installed, will default to optuna.
kwargs:
Additional keyword arguments passed along to :obj:`optuna.create_study` or :obj:`ray.tune.run`. For
more information see:
@@ -1707,6 +1716,7 @@
<https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html>`__
- the documentation of `tune.run
<https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run>`__
- the documentation of `sigopt <https://app.sigopt.com/docs/endpoints/experiments/create>`__
Returns:
:class:`transformers.trainer_utils.BestRun`: All the information about the best run.
@@ -1718,6 +1728,7 @@
"At least one of optuna or ray should be installed. "
"To install optuna run `pip install optuna`."
"To install ray run `pip install ray[tune]`."
"To install sigopt run `pip install sigopt`."
)
backend = HPSearchBackend(backend)
if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
@@ -1726,6 +1737,8 @@
raise RuntimeError(
"You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
)
if backend == HPSearchBackend.SIGOPT and not is_sigopt_available():
raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.")
self.hp_search_backend = backend
if self.model_init is None:
raise RuntimeError(
@@ -1736,8 +1749,12 @@
self.hp_name = hp_name
self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
best_run = run_hp_search(self, n_trials, direction, **kwargs)
backend_dict = {
HPSearchBackend.OPTUNA: run_hp_search_optuna,
HPSearchBackend.RAY: run_hp_search_ray,
HPSearchBackend.SIGOPT: run_hp_search_sigopt,
}
best_run = backend_dict[backend](self, n_trials, direction, **kwargs)
self.hp_search_backend = None
return best_run
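
One detail worth noting in the HPSearchBackend.SIGOPT branch of _hp_search_setup above: SigOpt returns categorical suggestions as strings (the default space declares the batch sizes as categorical string values), so string assignments are cast to int before being written onto the TrainingArguments. A toy illustration with an invented assignments dict:

    # Toy illustration of the cast performed in the SIGOPT branch; the dict is made up.
    assignments = {"learning_rate": 3e-5, "per_device_train_batch_size": "16", "seed": 7}
    params = {k: int(v) if isinstance(v, str) else v for k, v in assignments.items()}
    print(params)  # {'learning_rate': 3e-05, 'per_device_train_batch_size': 16, 'seed': 7}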

View File

@@ -198,14 +198,29 @@ def default_hp_space_ray(trial) -> Dict[str, float]:
}
def default_hp_space_sigopt(trial):
return [
{"bounds": {"min": 1e-6, "max": 1e-4}, "name": "learning_rate", "type": "double", "transformamtion": "log"},
{"bounds": {"min": 1, "max": 6}, "name": "num_train_epochs", "type": "int"},
{"bounds": {"min": 1, "max": 40}, "name": "seed", "type": "int"},
{
"categorical_values": ["4", "8", "16", "32", "64"],
"name": "per_device_train_batch_size",
"type": "categorical",
},
]
class HPSearchBackend(ExplicitEnum):
OPTUNA = "optuna"
RAY = "ray"
SIGOPT = "sigopt"
default_hp_space = {
HPSearchBackend.OPTUNA: default_hp_space_optuna,
HPSearchBackend.RAY: default_hp_space_ray,
HPSearchBackend.SIGOPT: default_hp_space_sigopt,
}
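
When no hp_space is passed to hyperparameter_search, the Trainer falls back to this table through the new enum member. A short sketch of that lookup, using only names defined in this hunk:

    # Resolving the default search space for the new backend.
    from transformers.trainer_utils import HPSearchBackend, default_hp_space

    backend = HPSearchBackend("sigopt")      # the new enum member
    space = default_hp_space[backend](None)  # default_hp_space_sigopt ignores its trial argument
    print([p["name"] for p in space])
    # ['learning_rate', 'num_train_epochs', 'seed', 'per_device_train_batch_size']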

View File

@@ -50,6 +50,7 @@ from transformers.testing_utils import (
require_optuna,
require_ray,
require_sentencepiece,
require_sigopt,
require_tokenizers,
require_torch,
require_torch_gpu,
@@ -1522,3 +1523,54 @@ class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
with ray_start_client_server():
assert ray.util.client.ray.is_connected()
self.ray_hyperparameter_search()
@require_torch
@require_sigopt
class TrainerHyperParameterSigOptIntegrationTest(unittest.TestCase):
def setUp(self):
args = TrainingArguments(".")
self.n_epochs = args.num_train_epochs
self.batch_size = args.train_batch_size
def test_hyperparameter_search(self):
class MyTrialShortNamer(TrialShortNamer):
DEFAULTS = {"a": 0, "b": 0}
def hp_space(trial):
return [
{"bounds": {"min": -4, "max": 4}, "name": "a", "type": "int"},
{"bounds": {"min": -4, "max": 4}, "name": "b", "type": "int"},
]
def model_init(trial):
if trial is not None:
a = trial.assignments["a"]
b = trial.assignments["b"]
else:
a = 0
b = 0
config = RegressionModelConfig(a=a, b=b, double_output=False)
return RegressionPreTrainedModel(config)
def hp_name(trial):
return MyTrialShortNamer.shortname(trial.assignments)
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer(
output_dir=tmp_dir,
learning_rate=0.1,
logging_steps=1,
evaluation_strategy=IntervalStrategy.EPOCH,
save_strategy=IntervalStrategy.EPOCH,
num_train_epochs=4,
disable_tqdm=True,
load_best_model_at_end=True,
logging_dir="runs",
run_name="test",
model_init=model_init,
)
trainer.hyperparameter_search(
direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="sigopt", n_trials=4
)
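
The new TrainerHyperParameterSigOptIntegrationTest runs only when torch and sigopt are both available (hence the require_torch and require_sigopt decorators) and needs a valid SIGOPT_API_TOKEN, which is why both self-hosted workflows now export that secret. Locally, something like pytest -k TrainerHyperParameterSigOptIntegrationTest should select it once those prerequisites are in place.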