From 6a3a197fcd1f3ac9472cbe05e373ea11839f5d5d Mon Sep 17 00:00:00 2001
From: kding1
Date: Thu, 23 Sep 2021 08:01:51 -0700
Subject: [PATCH] Add SigOpt HPO to transformers trainer api (#13572)

* add sigopt hpo to transformers.

Signed-off-by: Ding, Ke

* extend sigopt changes to test code and others..

Signed-off-by: Ding, Ke

* Style.

* fix style for sigopt integration.

Signed-off-by: Ding, Ke

* Add necessary information to run unittests on SigOpt.

Co-authored-by: Morgan Funtowicz
---
 .github/workflows/self-nightly-scheduled.yml   |  1 +
 .github/workflows/self-scheduled.yml           |  1 +
 setup.py                                       |  4 +-
 src/transformers/__init__.py                   |  2 +
 src/transformers/dependency_versions_table.py  |  1 +
 src/transformers/integrations.py               | 49 +++++++++++++++++
 src/transformers/testing_utils.py              | 15 +++++-
 src/transformers/trainer.py                    | 43 ++++++++++-----
 src/transformers/trainer_utils.py              | 15 ++++++
 tests/test_trainer.py                          | 52 +++++++++++++++++++
 10 files changed, 168 insertions(+), 15 deletions(-)

diff --git a/.github/workflows/self-nightly-scheduled.yml b/.github/workflows/self-nightly-scheduled.yml
index 7e3a48695c..6f76e9e8a3 100644
--- a/.github/workflows/self-nightly-scheduled.yml
+++ b/.github/workflows/self-nightly-scheduled.yml
@@ -15,6 +15,7 @@ env:
   OMP_NUM_THREADS: 16
   MKL_NUM_THREADS: 16
   PYTEST_TIMEOUT: 600
+  SIGOPT_API_TOKEN: ${{ secrets.SIGOPT_API_TOKEN }}

 jobs:
   run_all_tests_torch_gpu:
diff --git a/.github/workflows/self-scheduled.yml b/.github/workflows/self-scheduled.yml
index 3abab64aed..1ecca8f54b 100644
--- a/.github/workflows/self-scheduled.yml
+++ b/.github/workflows/self-scheduled.yml
@@ -15,6 +15,7 @@ env:
   OMP_NUM_THREADS: 16
   MKL_NUM_THREADS: 16
   PYTEST_TIMEOUT: 600
+  SIGOPT_API_TOKEN: ${{ secrets.SIGOPT_API_TOKEN }}

 jobs:
   run_all_tests_torch_gpu:
diff --git a/setup.py b/setup.py
index fa4df87acb..6568d83c62 100644
--- a/setup.py
+++ b/setup.py
@@ -135,6 +135,7 @@ _deps = [
     "sagemaker>=2.31.0",
     "scikit-learn",
     "sentencepiece>=0.1.91,!=0.1.92",
+    "sigopt",
     "soundfile",
     "sphinx-copybutton",
     "sphinx-markdown-tables",
@@ -248,8 +249,9 @@ extras["deepspeed"] = deps_list("deepspeed")
 extras["fairscale"] = deps_list("fairscale")
 extras["optuna"] = deps_list("optuna")
 extras["ray"] = deps_list("ray[tune]")
+extras["sigopt"] = deps_list("sigopt")

-extras["integrations"] = extras["optuna"] + extras["ray"]
+extras["integrations"] = extras["optuna"] + extras["ray"] + extras["sigopt"]

 extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
 extras["audio"] = deps_list("soundfile")
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 66415e60cb..75dee596dc 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -130,6 +130,7 @@ _import_structure = {
         "is_optuna_available",
         "is_ray_available",
         "is_ray_tune_available",
+        "is_sigopt_available",
         "is_tensorboard_available",
         "is_wandb_available",
     ],
@@ -1951,6 +1952,7 @@ if TYPE_CHECKING:
         is_optuna_available,
         is_ray_available,
         is_ray_tune_available,
+        is_sigopt_available,
         is_tensorboard_available,
         is_wandb_available,
     )
diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py
index 488aaa8372..ef39663714 100644
--- a/src/transformers/dependency_versions_table.py
+++ b/src/transformers/dependency_versions_table.py
@@ -53,6 +53,7 @@ deps = {
     "sagemaker": "sagemaker>=2.31.0",
     "scikit-learn": "scikit-learn",
     "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
+    "sigopt": "sigopt",
     "soundfile": "soundfile",
     "sphinx-copybutton": "sphinx-copybutton",
     "sphinx-markdown-tables": "sphinx-markdown-tables",
diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py
index f3453d0a6e..cb5836f0e2 100644
--- a/src/transformers/integrations.py
+++ b/src/transformers/integrations.py
@@ -83,6 +83,10 @@ def is_ray_tune_available():
     return importlib.util.find_spec("ray.tune") is not None


+def is_sigopt_available():
+    return importlib.util.find_spec("sigopt") is not None
+
+
 def is_azureml_available():
     if importlib.util.find_spec("azureml") is None:
         return False
@@ -117,6 +121,10 @@ def hp_params(trial):
         if isinstance(trial, dict):
             return trial

+    if is_sigopt_available():
+        if isinstance(trial, dict):
+            return trial
+
     raise RuntimeError(f"Unknown type for trial {trial.__class__}")


@@ -125,6 +133,8 @@ def default_hp_search_backend():
         return "optuna"
     elif is_ray_tune_available():
         return "ray"
+    elif is_sigopt_available():
+        return "sigopt"


 def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
@@ -288,6 +298,45 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
     return best_run


+def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
+
+    from sigopt import Connection
+
+    conn = Connection()
+    proxies = kwargs.pop("proxies", None)
+    if proxies is not None:
+        conn.set_proxies(proxies)
+
+    experiment = conn.experiments().create(
+        name="huggingface-tune",
+        parameters=trainer.hp_space(None),
+        metrics=[dict(name="objective", objective=direction, strategy="optimize")],
+        parallel_bandwidth=1,
+        observation_budget=n_trials,
+        project="huggingface",
+    )
+    logger.info(f"created experiment: https://app.sigopt.com/experiment/{experiment.id}")
+
+    while experiment.progress.observation_count < experiment.observation_budget:
+        suggestion = conn.experiments(experiment.id).suggestions().create()
+        trainer.objective = None
+        trainer.train(resume_from_checkpoint=None, trial=suggestion)
+        # If there hasn't been any evaluation during the training loop.
+        if getattr(trainer, "objective", None) is None:
+            metrics = trainer.evaluate()
+            trainer.objective = trainer.compute_objective(metrics)
+
+        values = [dict(name="objective", value=trainer.objective)]
+        obs = conn.experiments(experiment.id).observations().create(suggestion=suggestion.id, values=values)
+        logger.info(f"[suggestion_id, observation_id]: [{suggestion.id}, {obs.id}]")
+        experiment = conn.experiments(experiment.id).fetch()
+
+    best = list(conn.experiments(experiment.id).best_assignments().fetch().iterate_pages())[0]
+    best_run = BestRun(best.id, best.value, best.assignments)
+
+    return best_run
+
+
 def get_available_reporting_integrations():
     integrations = []
     if is_azureml_available():
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 0cc42f7f9c..593bb469b6 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -51,7 +51,7 @@ from .file_utils import (
     is_torchaudio_available,
     is_vision_available,
 )
-from .integrations import is_optuna_available, is_ray_available
+from .integrations import is_optuna_available, is_ray_available, is_sigopt_available


 SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
@@ -511,6 +511,19 @@ def require_ray(test_case):
     return test_case


+def require_sigopt(test_case):
+    """
+    Decorator marking a test that requires SigOpt.
+
+    These tests are skipped when SigOpt isn't installed.
+ + """ + if not is_sigopt_available(): + return unittest.skip("test requires SigOpt")(test_case) + else: + return test_case + + def require_soundfile(test_case): """ Decorator marking a test that requires soundfile diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index d39a24bf46..21811cd6de 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -40,8 +40,10 @@ from .integrations import ( # isort: split is_fairscale_available, is_optuna_available, is_ray_tune_available, + is_sigopt_available, run_hp_search_optuna, run_hp_search_ray, + run_hp_search_sigopt, ) import numpy as np @@ -231,9 +233,9 @@ class Trainer: A function that instantiates the model to be used. If provided, each call to :meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function. - The function may have zero argument, or a single one containing the optuna/Ray Tune trial object, to be - able to choose different architectures according to hyper parameters (such as layer count, sizes of inner - layers, dropout probabilities etc). + The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to + be able to choose different architectures according to hyper parameters (such as layer count, sizes of + inner layers, dropout probabilities etc). compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`): The function that will be used to compute metrics at evaluation. Must take a :class:`~transformers.EvalPrediction` and return a dictionary string to metric values. @@ -869,6 +871,8 @@ class Trainer: elif self.hp_search_backend == HPSearchBackend.RAY: params = trial params.pop("wandb", None) + elif self.hp_search_backend == HPSearchBackend.SIGOPT: + params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()} for key, value in params.items(): if not hasattr(self.args, key): @@ -883,6 +887,8 @@ class Trainer: setattr(self.args, key, value) if self.hp_search_backend == HPSearchBackend.OPTUNA: logger.info("Trial:", trial.params) + if self.hp_search_backend == HPSearchBackend.SIGOPT: + logger.info(f"SigOpt Assignments: {trial.assignments}") if self.args.deepspeed: # Rebuild the deepspeed config to reflect the updated training parameters from transformers.deepspeed import HfDeepSpeedConfig @@ -1232,7 +1238,7 @@ class Trainer: self.callback_handler.lr_scheduler = self.lr_scheduler self.callback_handler.train_dataloader = train_dataloader self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None - self.state.trial_params = hp_params(trial) if trial is not None else None + self.state.trial_params = hp_params(trial.assignments) if trial is not None else None # This should be the same if the state has been saved but in case the training arguments changed, it's safer # to set this after the load. 
         self.state.max_steps = max_steps
@@ -1524,10 +1530,12 @@ class Trainer:
         if self.hp_search_backend is not None and trial is not None:
             if self.hp_search_backend == HPSearchBackend.OPTUNA:
                 run_id = trial.number
-            else:
+            elif self.hp_search_backend == HPSearchBackend.RAY:
                 from ray import tune

                 run_id = tune.get_trial_id()
+            elif self.hp_search_backend == HPSearchBackend.SIGOPT:
+                run_id = trial.id
             run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
             run_dir = os.path.join(self.args.output_dir, run_name)
         else:
@@ -1671,9 +1679,9 @@ class Trainer:
         **kwargs,
     ) -> BestRun:
         """
-        Launch an hyperparameter search using ``optuna`` or ``Ray Tune``. The optimized quantity is determined by
-        :obj:`compute_objective`, which defaults to a function returning the evaluation loss when no metric is
-        provided, the sum of all metrics otherwise.
+        Launch an hyperparameter search using ``optuna`` or ``Ray Tune`` or ``SigOpt``. The optimized quantity is
+        determined by :obj:`compute_objective`, which defaults to a function returning the evaluation loss when no
+        metric is provided, the sum of all metrics otherwise.

         .. warning::

@@ -1686,7 +1694,8 @@ class Trainer:
             hp_space (:obj:`Callable[["optuna.Trial"], Dict[str, float]]`, `optional`):
                 A function that defines the hyperparameter search space. Will default to
                 :func:`~transformers.trainer_utils.default_hp_space_optuna` or
-                :func:`~transformers.trainer_utils.default_hp_space_ray` depending on your backend.
+                :func:`~transformers.trainer_utils.default_hp_space_ray` or
+                :func:`~transformers.trainer_utils.default_hp_space_sigopt` depending on your backend.
             compute_objective (:obj:`Callable[[Dict[str, float]], float]`, `optional`):
                 A function computing the objective to minimize or maximize from the metrics returned by the
                 :obj:`evaluate` method. Will default to :func:`~transformers.trainer_utils.default_compute_objective`.
@@ -1697,8 +1706,8 @@ class Trainer:
                 pick :obj:`"minimize"` when optimizing the validation loss, :obj:`"maximize"` when optimizing one or
                 several metrics.
             backend(:obj:`str` or :class:`~transformers.training_utils.HPSearchBackend`, `optional`):
-                The backend to use for hyperparameter search. Will default to optuna or Ray Tune, depending on which
-                one is installed. If both are installed, will default to optuna.
+                The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending
+                on which one is installed. If all are installed, will default to optuna.
             kwargs:
                 Additional keyword arguments passed along to :obj:`optuna.create_study` or :obj:`ray.tune.run`. For
                 more information see:

                 - the documentation of `optuna.create_study `__
                 - the documentation of `tune.run `__
+                - the documentation of `sigopt `__

         Returns:
             :class:`transformers.trainer_utils.BestRun`: All the information about the best run.
         """
         if backend is None:
             backend = default_hp_search_backend()
             if backend is None:
                 raise RuntimeError(
                     "At least one of optuna or ray should be installed. "
                     "To install optuna run `pip install optuna`."
                     "To install ray run `pip install ray[tune]`."
+                    "To install sigopt run `pip install sigopt`."
                 )
         backend = HPSearchBackend(backend)
         if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
             raise RuntimeError(
                 "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
             )
+        if backend == HPSearchBackend.SIGOPT and not is_sigopt_available():
+            raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.")
         self.hp_search_backend = backend
         if self.model_init is None:
             raise RuntimeError(
@@ -1736,8 +1749,12 @@ class Trainer:
         self.hp_name = hp_name
         self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

-        run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
-        best_run = run_hp_search(self, n_trials, direction, **kwargs)
+        backend_dict = {
+            HPSearchBackend.OPTUNA: run_hp_search_optuna,
+            HPSearchBackend.RAY: run_hp_search_ray,
+            HPSearchBackend.SIGOPT: run_hp_search_sigopt,
+        }
+        best_run = backend_dict[backend](self, n_trials, direction, **kwargs)
         self.hp_search_backend = None
         return best_run

diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index ea4a8739d8..3c6a2b85b6 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -198,14 +198,29 @@ def default_hp_space_ray(trial) -> Dict[str, float]:
     }


+def default_hp_space_sigopt(trial):
+    return [
+        {"bounds": {"min": 1e-6, "max": 1e-4}, "name": "learning_rate", "type": "double", "transformation": "log"},
+        {"bounds": {"min": 1, "max": 6}, "name": "num_train_epochs", "type": "int"},
+        {"bounds": {"min": 1, "max": 40}, "name": "seed", "type": "int"},
+        {
+            "categorical_values": ["4", "8", "16", "32", "64"],
+            "name": "per_device_train_batch_size",
+            "type": "categorical",
+        },
+    ]
+
+
 class HPSearchBackend(ExplicitEnum):
     OPTUNA = "optuna"
     RAY = "ray"
+    SIGOPT = "sigopt"


 default_hp_space = {
     HPSearchBackend.OPTUNA: default_hp_space_optuna,
     HPSearchBackend.RAY: default_hp_space_ray,
+    HPSearchBackend.SIGOPT: default_hp_space_sigopt,
 }
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index 0e513a30eb..72fda10c8f 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -50,6 +50,7 @@ from transformers.testing_utils import (
     require_optuna,
     require_ray,
     require_sentencepiece,
+    require_sigopt,
     require_tokenizers,
     require_torch,
     require_torch_gpu,
@@ -1522,3 +1523,54 @@ class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
         with ray_start_client_server():
             assert ray.util.client.ray.is_connected()
             self.ray_hyperparameter_search()
+
+
+@require_torch
+@require_sigopt
+class TrainerHyperParameterSigOptIntegrationTest(unittest.TestCase):
+    def setUp(self):
+        args = TrainingArguments(".")
+        self.n_epochs = args.num_train_epochs
+        self.batch_size = args.train_batch_size
+
+    def test_hyperparameter_search(self):
+        class MyTrialShortNamer(TrialShortNamer):
+            DEFAULTS = {"a": 0, "b": 0}
+
+        def hp_space(trial):
+            return [
+                {"bounds": {"min": -4, "max": 4}, "name": "a", "type": "int"},
+                {"bounds": {"min": -4, "max": 4}, "name": "b", "type": "int"},
+            ]
+
+        def model_init(trial):
+            if trial is not None:
+                a = trial.assignments["a"]
+                b = trial.assignments["b"]
+            else:
+                a = 0
+                b = 0
+            config = RegressionModelConfig(a=a, b=b, double_output=False)
+
+            return RegressionPreTrainedModel(config)
+
+        def hp_name(trial):
+            return MyTrialShortNamer.shortname(trial.assignments)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            trainer = get_regression_trainer(
+                output_dir=tmp_dir,
+                learning_rate=0.1,
+                logging_steps=1,
+                evaluation_strategy=IntervalStrategy.EPOCH,
+                save_strategy=IntervalStrategy.EPOCH,
+                num_train_epochs=4,
+                disable_tqdm=True,
+                load_best_model_at_end=True,
+                logging_dir="runs",
+                run_name="test",
+                model_init=model_init,
+            )
+            trainer.hyperparameter_search(
+                direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="sigopt", n_trials=4
+            )
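
For reference, a minimal end-to-end sketch of the API this patch enables (not part of the diff above). It assumes `sigopt` is installed and SIGOPT_API_TOKEN is exported, as the CI workflows now provide; the checkpoint, toy dataset, and parameter bounds are illustrative placeholders. The search space follows the list-of-parameters convention used by `default_hp_space_sigopt` and the new test.

# Illustrative sketch only: checkpoint, dataset, and bounds are placeholders.
import torch
from torch.utils.data import Dataset

from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

checkpoint = "distilbert-base-uncased"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(checkpoint)


class ToyDataset(Dataset):
    """Tiny labeled dataset so the sketch is self-contained."""

    def __init__(self, n=16):
        texts = ["great movie"] * (n // 2) + ["terrible movie"] * (n // 2)
        self.encodings = tokenizer(texts, truncation=True, padding="max_length", max_length=16)
        self.labels = [1] * (n // 2) + [0] * (n // 2)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item


def model_init(trial):
    # Called once per SigOpt suggestion; architecture choices could be read
    # from `trial.assignments` if they were part of the search space.
    return AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)


def sigopt_hp_space(trial):
    # SigOpt takes a list of parameter definitions (see default_hp_space_sigopt).
    return [
        {"bounds": {"min": 1e-6, "max": 1e-4}, "name": "learning_rate", "type": "double"},
        {"bounds": {"min": 1, "max": 3}, "name": "num_train_epochs", "type": "int"},
    ]


trainer = Trainer(
    args=TrainingArguments(output_dir="sigopt_hpo", evaluation_strategy="epoch", disable_tqdm=True),
    train_dataset=ToyDataset(),
    eval_dataset=ToyDataset(),
    model_init=model_init,
)

# Runs n_trials SigOpt suggestions; each one retrains the model via model_init
# and reports the evaluation loss (the default objective) back as an observation.
best_run = trainer.hyperparameter_search(
    direction="minimize",
    backend="sigopt",
    hp_space=sigopt_hp_space,
    n_trials=4,
)
print(best_run.run_id, best_run.objective, best_run.hyperparameters)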