[Deepspeed] ZeRO-Infinity integration plus config revamp (#11418)

* adding Z-inf

* revamp config process

* up version requirement

* wip

* massive rewrite

* cleanup

* cleanup

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* consistent json commas

* act on suggestions

* leave this feature for 0.3.16

* style

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Author: Stas Bekman
Date: 2021-04-26 10:40:32 -07:00, committed by GitHub
Parent: 0661abc545
Commit: bc2571e61c
10 changed files with 896 additions and 503 deletions
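In short: instead of duplicating values between the command line and the DeepSpeed config file, the config file may now carry "auto" placeholders that the Trainer resolves from its own arguments. A minimal sketch of the resulting workflow, assuming deepspeed is installed; the path and hyperparameters are illustrative only:

from transformers import TrainingArguments

# ds_config.json may now contain "auto" placeholders, e.g.
# "optimizer": {"type": "AdamW", "params": {"lr": "auto", ...}}
args = TrainingArguments(
    output_dir="output",
    learning_rate=3e-5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    deepspeed="ds_config.json",  # creates DeepSpeedConfigHF in __post_init__
)

# the processed config is queryable right away, before any model exists
print(args.deepspeed_config_hf.is_zero3())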

File diff suppressed because it is too large.

View File: setup.py

@@ -90,7 +90,7 @@ _deps = [
"cookiecutter==1.7.2",
"dataclasses",
"datasets",
"deepspeed>=0.3.14",
"deepspeed>=0.3.15",
"docutils==0.16.0",
"fairscale>0.3",
"faiss-cpu",

View File: src/transformers/dependency_versions_table.py

@@ -7,7 +7,7 @@ deps = {
"cookiecutter": "cookiecutter==1.7.2",
"dataclasses": "dataclasses",
"datasets": "datasets",
"deepspeed": "deepspeed>=0.3.14",
"deepspeed": "deepspeed>=0.3.15",
"docutils": "docutils==0.16.0",
"fairscale": "fairscale>0.3",
"faiss-cpu": "faiss-cpu",

View File: src/transformers/integrations.py

@@ -19,8 +19,8 @@ import io
import json
import numbers
import os
import sys
import tempfile
import weakref
from copy import deepcopy
from pathlib import Path
@@ -269,74 +269,180 @@ def rewrite_logs(d):
return new_d
_is_deepspeed_zero3_enabled = None
def _is_true(config, key):
if config is None:
return False
return bool(config.get(key))
def _set_if_auto(config, key, val):
if config is None:
return
if config.get(key) == "auto":
config[key] = val
class DeepSpeedConfigHF:
"""
This object contains Deepspeed configuration and can be quickly queried for things like zero stage.
We store a ``weakref`` of this object in the module's globals to be able to access the config from areas where the
Trainer is not available (e.g. ``from_pretrained`` and ``_get_resized_embeddings``).
The ``DeepSpeedConfigHF`` object is meant to be created during ``TrainingArguments`` object creation and has the
same lifespan as the latter.
"""
def __init__(self, args):
self.config = None
self.stage = 0
self.offload = False
dep_version_check("deepspeed")
self.config_process(args)
# set global weakref object
deepspeed_config_hf_set(self)
def is_zero2(self):
return self.stage == 2
def is_zero3(self):
return self.stage == 3
def is_offload(self):
return self.offload
def config_process(self, args):
"""
1. load json if the ``args.deepspeed`` is a path
2. replace any ``auto`` values in the config with the correct or recommended value
This is done as early as possible, before the model is created, so that ``is_deepspeed_zero3_enabled`` can be
queried, and so that the early deepspeed config object is accessible during ``zero.Init()``, which needs to know
whether fp16 is enabled, the dtype, etc.
"""
config_file_or_dict = args.deepspeed
if isinstance(config_file_or_dict, dict):
# Don't modify user's data should they want to reuse it (e.g. in tests), because once we
# modified it, it will not be accepted here again, since `auto` values would have been overridden
config = deepcopy(config_file_or_dict)
elif isinstance(config_file_or_dict, str):
with io.open(config_file_or_dict, "r", encoding="utf-8") as f:
config = json.load(f)
else:
raise ValueError("expecting either a path to a config file or a pre-populated dict")
self.config = config
# DeepSpeed does:
# train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps
train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
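# e.g. with world_size=2, per_device_train_batch_size=8 and gradient_accumulation_steps=4:
# train_batch_size = 2 * 8 * 4 = 64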
_set_if_auto(config, "train_micro_batch_size_per_gpu", args.per_device_train_batch_size)
_set_if_auto(config, "gradient_accumulation_steps", args.gradient_accumulation_steps)
_set_if_auto(config, "train_batch_size", train_batch_size)
_set_if_auto(config, "gradient_clipping", args.max_grad_norm)
# zero
config_zero = config.get("zero_optimization", {})
self.stage = config_zero.get("stage", 0)
config_optim = config.get("optimizer", {})
if config_optim != {}:
config_optim_params = config_optim.get("params")
_set_if_auto(config_optim_params, "lr", args.learning_rate)
_set_if_auto(config_optim_params, "betas", [args.adam_beta1, args.adam_beta2])
_set_if_auto(config_optim_params, "eps", args.adam_epsilon)
_set_if_auto(config_optim_params, "weight_decay", args.weight_decay)
config_sched = config.get("scheduler", {})
if config_sched != {}:
config_sched_params = config_sched.get("params")
_set_if_auto(config_sched_params, "warmup_min_lr", 0)
_set_if_auto(config_sched_params, "warmup_max_lr", args.learning_rate)
_set_if_auto(config_sched_params, "warmup_num_steps", args.warmup_steps)
# total_num_steps - will get set in deepspeed_init
# fp16
if args.fp16:
fp16_backend = "apex" if args.fp16_backend == "apex" else "amp"
else:
fp16_backend = None
# amp: similar to the pytorch native amp - it has a bunch of optional params but we won't set
# any here unless the user did the work
config_fp16 = config.get("fp16")
# XXX: at the moment fp16 can't be False, but the fp32 solution is in the works - once it's PR'ed and
# merged and a new release is made, delete the next line and uncomment the one after it
_set_if_auto(config_fp16, "enabled", True)
# _set_if_auto(config_fp16, "enabled", fp16_backend == "amp")
# apex: delegates amp work to apex (which needs to be available), but it cannot be used with any
# ZeRO features, so it is probably best avoided.
config_amp = config.get("amp")
_set_if_auto(config_amp, "enabled", fp16_backend == "apex")
_set_if_auto(config_amp, "opt_level", args.fp16_opt_level)
config_zero = config.get("zero_optimization", {})
if self.is_zero2():
self.offload = _is_true(config_zero, "cpu_offload")
elif self.is_zero3():
offload_devices = ["cpu", "nvme"]
if config_zero.get("offload_optimizer", {}).get("device") in offload_devices:
self.offload = True
if config_zero.get("offload_param", {}).get("device") in offload_devices:
self.offload = True
def config_finalize(self, args, model, num_training_steps):
"""
This stage is run after we have the model and know num_training_steps.
Now we can complete the configuration process.
"""
config = self.config
# zero
config_zero = config.get("zero_optimization", {})
if self.is_zero3():
# automatically assign the optimal config values based on model config
hidden_size = model.config.hidden_size
_set_if_auto(config_zero, "reduce_bucket_size", hidden_size * hidden_size)
_set_if_auto(config_zero, "stage3_prefetch_bucket_size", 0.9 * hidden_size * hidden_size)
_set_if_auto(config_zero, "stage3_param_persistence_threshold", 10 * hidden_size)
# scheduler
config_sched = config.get("scheduler", {})
config_sched_params = config_sched.get("params", {})
_set_if_auto(config_sched_params, "total_num_steps", num_training_steps)
# keep the config object global to be able to access it anywhere during TrainingArguments life-cycle
_deepspeed_config_hf_weak_ref = None
def deepspeed_config_hf_set(deepspeed_config_hf_obj):
# this is a special weakref global object to allow us to get to Deepspeed config from APIs
# that don't have an easy way to get to the Deepspeed config outside of the Trainer domain.
global _deepspeed_config_hf_weak_ref
# will go away automatically when DeepSpeedConfigHF is destroyed (when TrainingArguments is destroyed)
_deepspeed_config_hf_weak_ref = weakref.ref(deepspeed_config_hf_obj)
def is_deepspeed_zero3_enabled():
"""
This function answers the question of whether DeepSpeed is going to be used and run using ZeRO Stage 3.
It includes an auto-discovery method, see comments in the code for details.
Returns: ``True`` if it was either explicitly enabled via ``deepspeed_zero3_enable(True)`` or if the auto-detector
was able to derive that the ``Trainer`` will be running via DeepSpeed ZeRO stage 3.
"""
global _is_deepspeed_zero3_enabled
if _is_deepspeed_zero3_enabled is None:
_is_deepspeed_zero3_enabled = False
# Try to auto-discover if we are about to use DeepSpeed with ZeRO3 enabled. This will only
# work for scripts using the cli to pass --deepspeed ds_config.json. If cmd args aren't used,
# then to get the model efficiently loaded across multiple GPUs one has to explicitly call
# deepspeed_zero3_enable(True) **before** instantiating a model object
if "--deepspeed" in sys.argv:
idx = sys.argv.index("--deepspeed")
ds_config = sys.argv[idx + 1]
if not os.path.exists(ds_config):
raise ValueError("--deepspeed requires a valid path to a config file")
config = deepspeed_parse_config(ds_config)
if (
"zero_optimization" in config
and "stage" in config["zero_optimization"]
and config["zero_optimization"]["stage"] == 3
):
_is_deepspeed_zero3_enabled = True
return _is_deepspeed_zero3_enabled
def deepspeed_zero3_enable(enable=True):
"""
``is_deepspeed_zero3_enabled()`` tries to derive automatically if DeepSpeed ZeRO 3 is going to be used by looking
at ``sys.argv``, which may or may not contain information about where to find the DeepSpeed config if any.
This function allows for explicit enabling/disabling of this global flag.
Args:
enable: if set to ``True`` will make ``is_deepspeed_zero3_enabled()`` return ``True``
"""
global _is_deepspeed_zero3_enabled
_is_deepspeed_zero3_enabled = enable
def deepspeed_parse_config(ds_config):
    """
    If ``ds_config`` isn't already a dict, read it from the config file.
    If it's already a dict, return a copy of it, so that we can freely modify it.
    """
    dep_version_check("deepspeed")
    if isinstance(ds_config, dict):
        # Don't modify user's data should they want to reuse it (e.g. in tests), because once we
        # modified it, it will not be accepted here again, since some config params must not be set by users
        config = deepcopy(ds_config)
    elif isinstance(ds_config, str):
        with io.open(ds_config, "r", encoding="utf-8") as f:
            config = json.load(f)
    else:
        raise ValueError("expecting either a path to a config file or a pre-populated dict")
    return config
def is_deepspeed_zero3_enabled():
    if _deepspeed_config_hf_weak_ref is not None and _deepspeed_config_hf_weak_ref() is not None:
        return _deepspeed_config_hf_weak_ref().is_zero3()
    else:
        return False
def deepspeed_config():
if _deepspeed_config_hf_weak_ref is not None and _deepspeed_config_hf_weak_ref() is not None:
return _deepspeed_config_hf_weak_ref().config
else:
return None
def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):
@@ -355,41 +461,16 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):
"""
import deepspeed
args = trainer.args
model = trainer.model
config = deepspeed_parse_config(args.deepspeed)
deepspeed_config_hf = trainer.args.deepspeed_config_hf
deepspeed_config_hf.config_finalize(trainer.args, model, num_training_steps)
# The following code translates the relevant trainer cl args into the DS config
# First, to ensure that there is no mismatch between cl args values and presets in the config
# file, we ask that the following not be set in the ds config file:
# - "train_batch_size",
# - "train_micro_batch_size_per_gpu",
# - "gradient_accumulation_steps"
bs_keys = ["train_batch_size", "train_micro_batch_size_per_gpu"]
if len([x for x in bs_keys if x in config.keys()]):
raise ValueError(
f"Do not include {bs_keys} entries in the ds config file, as they will be set via --per_device_train_batch_size or its default"
)
if "gradient_accumulation_steps" in config.keys():
raise ValueError(
"Do not include gradient_accumulation_steps entries in the ds config file, as they will be set via --gradient_accumulation_steps or its default"
)
# DeepSpeed does:
# train_batch_size = n_gpus * train_micro_batch_size_per_gpu * gradient_accumulation_steps
# therefore we just need to set:
config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
if "gradient_clipping" in config:
logger.info("Keeping the `gradient_clipping` config intact, ignoring any gradient clipping-specific cl args")
else: # override only if the ds config doesn't already have this section
config["gradient_clipping"] = args.max_grad_norm
# resume config update - some bits like `model` and `num_training_steps` only become available during train
config = deepspeed_config_hf.config
# Optimizer + Scheduler
# Currently support combos:
# Currently supported combos:
# 1. DS scheduler + DS optimizer: Yes
# 2. HF scheduler + HF optimizer: Yes
# 3. DS scheduler + HF optimizer: Yes
@@ -402,36 +483,16 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):
# 4. HF scheduler + DS optimizer: No
optimizer = None
if "optimizer" in config:
logger.info("Updating the `scheduler` config with other command line arguments")
# to avoid inconsistent values of lr and warm up steps the command line args override config
params = dict(
lr=args.learning_rate,
betas=[args.adam_beta1, args.adam_beta2],
eps=args.adam_epsilon,
weight_decay=args.weight_decay,
)
for k, v in params.items():
if k in config["optimizer"]["params"]:
logger.info(f"setting optimizer.params.{k} to {v}")
config["optimizer"]["params"][k] = v
else: # override only if the ds config doesn't already have this section
if (
"zero_optimization" in config
and "cpu_offload" in config["zero_optimization"]
and config["zero_optimization"]["cpu_offload"] is True
):
if "optimizer" not in config:
if deepspeed_config_hf.is_offload():
raise ValueError("ZeRO Offload can only work with DeepSpeed optimizers")
else:
# ds supports Adam, OneBitAdam, and Lamb optimizers and can import other optimizers from torch.
# But trainer uses AdamW by default.
# To use other optimizers so using a different scheduler requires voiding warranty with: `zero_allow_untested_optimizer`
trainer.create_optimizer()
optimizer = trainer.optimizer
# flag that this is non-native optimizer
config["zero_allow_untested_optimizer"] = True
# ds supports Adam, OneBitAdam, and Lamb optimizers and can import other optimizers from torch.
# But trainer uses AdamW by default.
trainer.create_optimizer()
optimizer = trainer.optimizer
# To use other optimizers requires voiding warranty with: `zero_allow_untested_optimizer`
config["zero_allow_untested_optimizer"] = True
# DS schedulers (deepspeed/runtime/lr_schedules.py):
#
@@ -442,25 +503,7 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):
# WarmupLR | constant_with_warmup | get_constant_schedule_with_warmup | w/ warmup_min_lr=0
# WarmupDecayLR| linear | get_linear_schedule_with_warmup |
lr_scheduler = None
if "scheduler" in config:
logger.info("Updating the `scheduler` config with other command line arguments")
# the user won't easily know the correct num_training_steps should they use WarmupDecayLR,
# so let's set it to the correct value
if config["scheduler"]["type"] == "WarmupDecayLR":
logger.info(f"setting scheduler.params.total_num_steps to {num_training_steps}")
config["scheduler"]["params"]["total_num_steps"] = num_training_steps
# to avoid inconsistent values of lr and warmup steps the command line args override config
params = dict(
warmup_max_lr=args.learning_rate,
warmup_num_steps=args.warmup_steps,
)
for k, v in params.items():
if k in config["scheduler"]["params"]:
logger.info(f"setting scheduler.params.{k} to {v}")
config["scheduler"]["params"][k] = v
else: # override only if the ds config doesn't already have this section
if "scheduler" not in config:
if "optimizer" in config:
# to make this option work, we need to init DS optimizer first, then init HF scheduler,
# then pass the HF scheduler to DS init, which is not possible at the moment
@@ -469,43 +512,6 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):
trainer.create_scheduler(num_training_steps=num_training_steps)
lr_scheduler = trainer.lr_scheduler
# fp16
if trainer.fp16_backend is not None:
# Deepspeed has 2 possible fp16 config entries:
# - `fp16`: for the native amp - it has a bunch of optional params but we won't set any here unless the user did the work
# - `amp`: which delegates amp work to apex (which needs to be available), but it cannot be used with any ZeRO features, so it is probably best avoided.
if trainer.fp16_backend == "apex":
if "amp" in config:
logger.info("Keeping the `amp` config intact, ignoring any amp-specific cl args")
else:
config["amp"] = {
"enabled": True,
"opt_level": args.fp16_opt_level,
}
elif trainer.fp16_backend == "amp":
if "fp16" in config:
logger.info("Keeping the `fp16` config intact, ignoring any fp16-specific cl args")
else:
config["fp16"] = {
"enabled": True,
}
# zero
if "zero_optimization" in config:
zero = config["zero_optimization"]
# now we know for sure if zero3 is enabled
deepspeed_zero3_enable(zero.get("stage") == 3)
# automatically assign the optimal config values based on model config
hidden_size = model.config.hidden_size
if zero.get("reduce_bucket_size") == 0:
zero["reduce_bucket_size"] = hidden_size * hidden_size
if zero.get("stage3_prefetch_bucket_size") == 0:
zero["stage3_prefetch_bucket_size"] = 0.9 * hidden_size * hidden_size
if zero.get("stage3_param_persistence_threshold") == 0:
zero["stage3_param_persistence_threshold"] = 10 * hidden_size
# keep for quick debug:
# from pprint import pprint; pprint(config)
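The weakref-based globals above replace the old sys.argv auto-discovery. A minimal sketch of how the query functions track the TrainingArguments life-cycle, mirroring the test_config_object test further down; it assumes deepspeed is installed:

from transformers import TrainingArguments
from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled

# creating TrainingArguments processes the config and sets the module-global weakref
args = TrainingArguments(output_dir="out", deepspeed={"zero_optimization": {"stage": 3}})
assert is_deepspeed_zero3_enabled()
assert deepspeed_config()["zero_optimization"]["stage"] == 3

# once the TrainingArguments object is garbage-collected, the weakref goes stale
del args
assert not is_deepspeed_zero3_enabled()
assert deepspeed_config() is None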

View File: src/transformers/modeling_utils.py

@@ -1122,7 +1122,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
import deepspeed
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
# this immediately partitions the model to avoid the overhead in time and memory copying it on CPU or each GPU first
# this immediately partitions the model across all gpus, to avoid the overhead in time
# and memory copying it on CPU or each GPU first
# XXX: param_dict will be added in deepspeed==0.3.16 and probably replaced by deepspeed_config
# with deepspeed.zero.Init(param_dict=deepspeed_config()):
with deepspeed.zero.Init():
model = cls(config, *model_args, **model_kwargs)
else:
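For reference, a minimal sketch of the same pattern outside of from_pretrained, assuming deepspeed is installed and the script is started via a distributed launcher:

import deepspeed
import torch.nn as nn

deepspeed.init_distributed()

# parameters are partitioned across the participating GPUs at construction time,
# instead of first materializing the full model on CPU or on every GPU
with deepspeed.zero.Init():
    model = nn.Linear(4096, 4096)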

View File: src/transformers/training_args.py

@@ -70,9 +70,6 @@ class TrainingArguments:
<https://docs.python.org/3/library/argparse.html#module-argparse>`__ arguments that can be specified on the command
line.
Parameters:
output_dir (:obj:`str`):
The output directory where the model predictions and checkpoints will be written.
@@ -625,6 +622,14 @@
elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
if self.deepspeed:
# - must be run very last in arg parsing, since it will use a lot of these settings.
# - must be run before the model is created.
from transformers.integrations import DeepSpeedConfigHF
# will be used later by the Trainer (leave self.deepspeed unmodified in case a user relies on it not to be modified)
self.deepspeed_config_hf = DeepSpeedConfigHF(self)
def __repr__(self):
# We override the default repr to remove deprecated arguments from the repr. This method should be removed once
# those deprecated arguments are removed from TrainingArguments. (TODO: v5)

View File: tests/deepspeed/ds_config_zero2.json

@@ -1,6 +1,6 @@
{
"fp16": {
"enabled": true,
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
@@ -8,6 +8,25 @@
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
@@ -19,25 +38,10 @@
"cpu_offload": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": 3e-5,
"betas": [0.8, 0.999],
"eps": 1e-8,
"weight_decay": 3e-7
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 3e-5,
"warmup_num_steps": 500
}
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
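Every "auto" above is resolved by config_process from the corresponding command-line argument. A toy illustration of the _set_if_auto mechanics, with the helper copied from the integrations.py hunk above and example values:

def _set_if_auto(config, key, val):
    if config is None:
        return
    if config.get(key) == "auto":
        config[key] = val

params = {"lr": "auto", "eps": 1e-8}
_set_if_auto(params, "lr", 3e-5)   # replaced: the value was the literal string "auto"
_set_if_auto(params, "eps", 1e-7)  # left alone: the user set it explicitly
assert params == {"lr": 3e-5, "eps": 1e-8}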

View File: tests/deepspeed/ds_config_zero3.json

@@ -1,6 +1,6 @@
{
"fp16": {
"enabled": true,
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
@@ -8,41 +8,50 @@
"min_loss_scale": 1
},
"zero_optimization": {
"stage": 3,
"cpu_offload": true,
"cpu_offload_params": true,
"cpu_offload_use_pin_memory" : true,
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e14,
"reduce_bucket_size": 0,
"stage3_prefetch_bucket_size": 0,
"stage3_param_persistence_threshold": 0,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_fp16_weights_on_model_save": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": 3e-5,
"betas": [0.8, 0.999],
"eps": 1e-8,
"weight_decay": 3e-7
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 3e-5,
"warmup_num_steps": 500
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e14,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_fp16_weights_on_model_save": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
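The "auto" bucket and threshold values in the zero_optimization block are filled in by config_finalize once the model is known, using the formulas from the integrations.py hunk above. For example, for a hypothetical model with hidden_size=1024:

hidden_size = 1024  # model.config.hidden_size

reduce_bucket_size = hidden_size * hidden_size                 # 1,048,576
stage3_prefetch_bucket_size = 0.9 * hidden_size * hidden_size  # 943,718.4
stage3_param_persistence_threshold = 10 * hidden_size          # 10,240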

View File: tests/deepspeed/test_deepspeed.py

@@ -42,7 +42,7 @@ with ExtendSysPath(f"{bindir}/.."):
from test_trainer import TrainerIntegrationCommon # noqa
if is_torch_available():
from test_trainer import get_regression_trainer # noqa
from test_trainer import RegressionModelConfig, RegressionPreTrainedModel, get_regression_trainer # noqa
set_seed(42)
@@ -66,6 +66,10 @@ def require_deepspeed(test_case):
return test_case
if is_deepspeed_available():
from deepspeed.utils import logger as deepspeed_logger # noqa
from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled # noqa
ZERO2 = "zero2"
ZERO3 = "zero3"
stages = [ZERO2, ZERO3]
@@ -115,12 +119,6 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
with io.open(self.ds_config_file[ZERO3], "r", encoding="utf-8") as f:
self.ds_config_dict[ZERO3] = json.load(f)
def tearDown(self):
# XXX: Fixme - this is a temporary band-aid since this global variable impacts other tests
import transformers
transformers.integrations._is_deepspeed_zero3_enabled = None
def get_config_dict(self, stage):
"""As the tests modify the dict, always make a copy"""
config = deepcopy(self.ds_config_dict[stage])
@@ -173,25 +171,65 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
trainer = get_regression_trainer(local_rank=0, deepspeed=ds_config_zero2_dict)
with self.assertRaises(Exception) as context:
trainer.train()
self.assertTrue("HF scheduler + DeepSpeed optimizer combination is not possible" in str(context.exception))
self.assertTrue(
"HF scheduler + DeepSpeed optimizer combination is not possible" in str(context.exception),
f"got exception: {context.exception}",
)
def test_hf_optimizer_with_offload(self):
    # must not allow non-DS optimizer when using ZERO-offload
    with mockenv_context(**self.dist_env_1_gpu):
        ds_config_zero2_dict = self.get_config_dict(ZERO2)
        del ds_config_zero2_dict["optimizer"]  # force default HF Trainer optimizer
        ds_config_zero2_dict["zero_optimization"]["cpu_offload"] = True
        # sanity check - should the default config change
        assert (
            "cpu_offload" in ds_config_zero2_dict["zero_optimization"]
            and ds_config_zero2_dict["zero_optimization"]["cpu_offload"] is True
        ), "ensure the config is set up correctly"
        trainer = get_regression_trainer(local_rank=0, deepspeed=ds_config_zero2_dict)
        with self.assertRaises(Exception) as context:
            trainer.train()
        self.assertTrue("ZeRO Offload can only work with DeepSpeed optimizers" in str(context.exception))
def test_stage3_nvme_offload(self):
    with mockenv_context(**self.dist_env_1_gpu):
        # this actually doesn't have to be on NVMe, any storage will do since this test only
        # runs a simple check that we can use some directory as if it were NVMe
        nvme_path = self.get_auto_remove_tmp_dir()
        nvme_config = dict(device="nvme", nvme_path=nvme_path)
        ds_config_zero3_dict = self.get_config_dict(ZERO3)
        ds_config_zero3_dict["zero_optimization"]["offload_optimizer"] = nvme_config
        ds_config_zero3_dict["zero_optimization"]["offload_param"] = nvme_config
        trainer = get_regression_trainer(local_rank=0, deepspeed=ds_config_zero3_dict)
        with CaptureLogger(deepspeed_logger) as cs:
            trainer.train()
        self.assertIn("DeepSpeed info", cs.out, "expected DeepSpeed logger output but got none")
# --- These tests need to run on both zero stages --- #
@parameterized.expand(stages)
def test_fp32(self, stage):
ds_config_dict = self.get_config_dict(stage)
ds_config_dict["fp16"]["enabled"] = False # force non-fp16 mode
# XXX: do we go via from_pretrained in zero 3 here? need to test zero.Init(dtype=torch.float)
# XXX: rewrite this test once fp32 is supported by DeepSpeed
with mockenv_context(**self.dist_env_1_gpu):
trainer = get_regression_trainer(local_rank=0, deepspeed=ds_config_dict)
with self.assertRaises(Exception) as context:
trainer.train()
self.assertIn(
"ZeRO is only supported if fp16 is enabled",
str(context.exception),
f"got exception: {context.exception}",
)
@parameterized.expand(stages)
def test_hf_optimizer_with_offload(self, stage):
# must not allow non-DS optimizer when using ZERO-offload
ds_config_dict = self.get_config_dict(stage)
del ds_config_dict["optimizer"] # force default HF Trainer optimizer
# force cpu offload
if stage == "stage2":
ds_config_dict["zero_optimization"]["cpu_offload"] = True
elif stage == "stage3":
ds_config_dict["zero_optimization"]["offload_optimizer"]["device"] = "cpu"
with mockenv_context(**self.dist_env_1_gpu):
trainer = get_regression_trainer(local_rank=0, deepspeed=ds_config_dict)
with self.assertRaises(Exception) as context:
trainer.train()
self.assertIn(
"ZeRO Offload can only work with DeepSpeed optimizers",
str(context.exception),
f"got exception: {context.exception}",
)
@parameterized.expand(stages)
def test_fake_notebook_no_launcher(self, stage):
# this setup emulates a notebook where a launcher needs to be emulated by hand
@@ -199,14 +237,12 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
# note that unittest resets sys.stdout each test, so `CaptureStd` will work here to capture
# DeepSpeed log if this test happens to run first in this pytest worker. But it will fail if
# it's run not as a first test as `sys.stdout` will no longer be the same. So we either have
# to reset `logger.handlers[0].setStream(sys.stdout)` or directly capture from the logger.
from deepspeed.utils import logger
with CaptureLogger(logger) as cs:
with mockenv_context(**self.dist_env_1_gpu):
trainer = get_regression_trainer(local_rank=0, deepspeed=self.ds_config_file[stage])
# to reset `deepspeed_logger.handlers[0].setStream(sys.stdout)` or directly capture from the deepspeed_logger.
with mockenv_context(**self.dist_env_1_gpu):
trainer = get_regression_trainer(local_rank=0, deepspeed=self.ds_config_file[stage])
with CaptureLogger(deepspeed_logger) as cs:
trainer.train()
assert "DeepSpeed info" in cs.out, "expected DeepSpeed logger output but got none"
self.assertIn("DeepSpeed info", cs.out, "expected DeepSpeed logger output but got none")
@parameterized.expand(stages)
def test_early_get_last_lr(self, stage):
@@ -425,6 +461,38 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(b, b1)
self.check_trainer_state_are_the_same(state, state1)
def test_config_object(self):
# test that we can switch from zero2 to zero3 in the same process for example
# test is_zero, etc.
output_dir = self.get_auto_remove_tmp_dir()
kwargs = dict(output_dir=output_dir, train_len=8)
with mockenv_context(**self.dist_env_1_gpu):
ds_config_zero3_dict = self.get_config_dict("zero3")
ds_config_zero2_dict = self.get_config_dict("zero2")
trainer = get_regression_trainer(deepspeed=ds_config_zero3_dict, **kwargs)
self.assertTrue(is_deepspeed_zero3_enabled())
# test that we can repeat that, this time also running train()
trainer = get_regression_trainer(deepspeed=ds_config_zero3_dict, **kwargs)
trainer.train()
self.assertTrue(is_deepspeed_zero3_enabled())
# test zero3 is disabled
trainer = get_regression_trainer(deepspeed=ds_config_zero2_dict, **kwargs)
self.assertFalse(is_deepspeed_zero3_enabled())
# check config obj
config = deepspeed_config()
self.assertTrue(bool(config), "Deepspeed config should be accessible")
del trainer
# now that the trainer is gone, the weakref'ed global should be dead and we shouldn't get anything here
config = deepspeed_config()
self.assertFalse(is_deepspeed_zero3_enabled())
self.assertFalse(bool(config), "Deepspeed config should not be accessible")
@slow
@require_deepspeed
@@ -557,6 +625,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
--adafactor
--source_lang en
--target_lang ro
--report_to none
""".split()
args.extend(["--source_prefix", '"translate English to Romanian: "'])
@@ -626,6 +695,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
--num_train_epochs 1
--warmup_steps 8
--block_size 128
--report_to none
""".split()
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_{stage}.json".split()

View File: tests/test_trainer.py

@@ -213,16 +213,21 @@ if is_torch_available():
label_names = kwargs.get("label_names", None)
train_dataset = RegressionDataset(length=train_len, label_names=label_names)
eval_dataset = RegressionDataset(length=eval_len, label_names=label_names)
if pretrained:
    config = RegressionModelConfig(a=a, b=b, double_output=double_output)
    model = RegressionPreTrainedModel(config)
else:
    model = RegressionModel(a=a, b=b, double_output=double_output)
model_init = kwargs.pop("model_init", None)
if model_init is not None:
    model = None
else:
    if pretrained:
        config = RegressionModelConfig(a=a, b=b, double_output=double_output)
        model = RegressionPreTrainedModel(config)
    else:
        model = RegressionModel(a=a, b=b, double_output=double_output)
compute_metrics = kwargs.pop("compute_metrics", None)
data_collator = kwargs.pop("data_collator", None)
optimizers = kwargs.pop("optimizers", (None, None))
output_dir = kwargs.pop("output_dir", "./regression")
model_init = kwargs.pop("model_init", None)
args = RegressionTrainingArguments(output_dir, a=a, b=b, **kwargs)
return Trainer(