[deepspeed] zero inference (#14253)

* [deepspeed] zero inference

* only z3 makes sense for inference

* fix and style

* docs

* rework

* fix test

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* responding to suggestions

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Stas Bekman 2021-11-23 14:09:15 -08:00 committed by GitHub
parent 69e16abf98
commit 956a483173
6 changed files with 149 additions and 38 deletions

docs/source/main_classes/deepspeed.rst

@@ -46,6 +46,20 @@ won't be possible on a single GPU.
   parts of DeepSpeed like ``zero.Init`` for ZeRO stage 3 and higher. To tap into this feature read the docs on
   :ref:`deepspeed-non-trainer-integration`.

What is integrated:

Training:

1. DeepSpeed ZeRO training supports the full ZeRO stages 1, 2 and 3 with ZeRO-Infinity (CPU and NVMe offload).

Inference:

1. DeepSpeed ZeRO Inference supports ZeRO stage 3 with ZeRO-Infinity. It uses the same ZeRO protocol as training, but
   it doesn't use an optimizer or an lr scheduler, and only stage 3 is relevant. For more details see:
   :ref:`deepspeed-zero-inference`.

There is also DeepSpeed Inference - this is a totally different technology, which uses Tensor Parallelism instead of
ZeRO (coming soon).
@@ -1628,6 +1642,47 @@ larger multi-dimensional shape, this means that the parameter is partitioned and

.. _deepspeed-zero-inference:


ZeRO Inference
=======================================================================================================================

ZeRO Inference uses the same config as ZeRO-3 Training. You just don't need the optimizer and scheduler sections. In
fact, you can leave these sections in the config file if you want to share it with training; they will just be
ignored.
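
For orientation, here is a minimal sketch of what such a config might look like, reduced to its inference-relevant
parts (the values are illustrative, not taken from this commit):

.. code-block:: python

    # a hypothetical minimal ZeRO-3 config for inference, expressed as a Python dict
    ds_config = {
        "fp16": {"enabled": True},
        "zero_optimization": {
            "stage": 3,
            # optional ZeRO-Infinity offload of the params to CPU
            "offload_param": {"device": "cpu", "pin_memory": True},
        },
        "train_micro_batch_size_per_gpu": 1,
        # an ``optimizer`` or ``lr_scheduler`` section could remain here - it would simply be ignored
    }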
Otherwise you just need to pass the usual :class:`~transformers.TrainingArguments` arguments. For example:

.. code-block:: bash

    deepspeed --num_gpus=2 your_program.py <normal cl args> --do_eval --deepspeed ds_config.json

The only important thing is that you need to use a ZeRO-3 configuration, since ZeRO-2 provides no benefit for
inference: only ZeRO-3 shards the model parameters, whereas ZeRO-1 and ZeRO-2 shard only the optimizer states and
gradients.
Here is an example of running ``run_translation.py`` under DeepSpeed deploying all available GPUs:
.. code-block:: bash

    deepspeed examples/pytorch/translation/run_translation.py \
    --deepspeed tests/deepspeed/ds_config_zero3.json \
    --model_name_or_path t5-small --output_dir output_dir \
    --do_eval --max_eval_samples 50 --warmup_steps 50 \
    --max_source_length 128 --val_max_target_length 128 \
    --overwrite_output_dir --per_device_eval_batch_size 4 \
    --predict_with_generate --dataset_config "ro-en" --fp16 \
    --source_lang en --target_lang ro --dataset_name wmt16 \
    --source_prefix "translate English to Romanian: "
Since inference doesn't need the additional large amounts of memory used by the optimizer states and the gradients,
you should be able to fit much larger batches and/or sequence lengths onto the same hardware.

Additionally, DeepSpeed is currently developing a related product called DeepSpeed-Inference, which has no
relationship to the ZeRO technology, but instead uses tensor parallelism to scale models that can't fit onto a single
GPU. This is a work in progress and we will provide the integration once that product is complete.
Filing Issues
=======================================================================================================================

setup.py

@@ -97,7 +97,7 @@ _deps = [
    "cookiecutter==1.7.2",
    "dataclasses",
    "datasets",
    "deepspeed>=0.5.3",
    "deepspeed>=0.5.7",
    "docutils==0.16.0",
    "fairscale>0.3",
    "faiss-cpu",

src/transformers/deepspeed.py

@@ -111,6 +111,29 @@ class HfDeepSpeedConfig:
            return default
        return config.get(ds_key, default)

    def del_config_sub_tree(self, ds_key_long, must_exist=False):
        """
        Deletes a sub-section of the config file if it's found.

        Unless ``must_exist`` is :obj:`True` the section doesn't have to exist.
        """
        config = self.config

        # find the config node of interest if it exists
        nodes = ds_key_long.split(".")
        for node in nodes:
            parent_config = config
            config = config.get(node)
            if config is None:
                if must_exist:
                    raise ValueError(f"Can't find {ds_key_long} entry in the config: {self.config}")
                else:
                    return

        # if found remove it
        if parent_config is not None:
            parent_config.pop(node)

    def is_true(self, ds_key_long):
        """
        Returns :obj:`True`/:obj:`False` only if the value is set, always :obj:`False` otherwise. So use this method to
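
As a usage sketch (the config dict below is a made-up example, not part of this commit), ``del_config_sub_tree``
silently drops the training-only sections when present:

.. code-block:: python

    from transformers.deepspeed import HfDeepSpeedConfig

    # hypothetical config that re-uses a training file for inference
    ds_config = {
        "zero_optimization": {"stage": 3},
        "optimizer": {"type": "AdamW", "params": {"lr": 3e-5}},
        "lr_scheduler": {"type": "WarmupLR"},
    }

    hf_ds_config = HfDeepSpeedConfig(ds_config)
    hf_ds_config.del_config_sub_tree("optimizer")
    hf_ds_config.del_config_sub_tree("lr_scheduler")
    # a missing key is fine too, since must_exist defaults to False
    hf_ds_config.del_config_sub_tree("scheduler")
    assert "optimizer" not in hf_ds_config.config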
@@ -280,30 +303,10 @@ def deepspeed_config():
        return None


def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):
def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps):
    """
    Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args.

    If ``resume_from_checkpoint`` was passed then an attempt to resume from a previously saved checkpoint will be made.

    Args:
        trainer: Trainer object
        num_training_steps: per single gpu
        resume_from_checkpoint: path to a checkpoint if to resume from after normal DeepSpeedEngine load

    Returns: model, optimizer, lr_scheduler
    A convenience wrapper that deals with optimizer and lr scheduler configuration.
    """
    import deepspeed
    from deepspeed.utils import logger as ds_logger

    model = trainer.model

    args = trainer.args
    hf_deepspeed_config = args.hf_deepspeed_config
    hf_deepspeed_config.trainer_config_finalize(args, model, num_training_steps)

    # resume config update - some bits like `model` and `num_training_steps` only become available during train
    config = hf_deepspeed_config.config

    # Optimizer + Scheduler
@@ -351,13 +354,54 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):
    else:
        lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

    # keep for quick debug:
    # from pprint import pprint; pprint(config)

    return optimizer, lr_scheduler

    # set the Deepspeed log level consistent with the trainer


def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False):
    """
    Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args.

    If ``resume_from_checkpoint`` was passed then an attempt to resume from a previously saved checkpoint will be made.

    Args:
        trainer: Trainer object
        num_training_steps: per single gpu
        resume_from_checkpoint: path to a checkpoint if to resume from after normal DeepSpeedEngine load
        inference: launch in inference mode (no optimizer and no lr scheduler)

    Returns: model, optimizer, lr_scheduler
    """
    import deepspeed
    from deepspeed.utils import logger as ds_logger

    model = trainer.model
    args = trainer.args

    # resume config update - some bits like `model` and `num_training_steps` only become available during train
    hf_deepspeed_config = args.hf_deepspeed_config
    hf_deepspeed_config.trainer_config_finalize(args, model, num_training_steps)
    config = hf_deepspeed_config.config

    # set the Deepspeed log level consistent with the Trainer
    ds_logger.setLevel(args.get_process_log_level())

    model_parameters = filter(lambda p: p.requires_grad, model.parameters())

    if inference:
        # only Z3 makes sense for the inference
        if not hf_deepspeed_config.is_zero3():
            raise ValueError("ZeRO inference only makes sense with ZeRO Stage 3 - please adjust your config")

        # in case the training config is re-used for inference
        hf_deepspeed_config.del_config_sub_tree("optimizer")
        hf_deepspeed_config.del_config_sub_tree("lr_scheduler")
        optimizer, lr_scheduler = None, None
        model_parameters = None
    else:
        optimizer, lr_scheduler = deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps)
        model_parameters = filter(lambda p: p.requires_grad, model.parameters())

    # keep for quick debug:
    # from pprint import pprint; pprint(config)

    model, optimizer, _, lr_scheduler = deepspeed.initialize(
        model=model,
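
Condensed, the new inference path boils down to something like the following sketch (this is not the integration's
verbatim code - the real path goes through ``Trainer`` and its config finalization, and is meant to run under the
``deepspeed`` launcher):

.. code-block:: python

    import deepspeed
    from transformers import AutoModelForSeq2SeqLM

    # illustrative ZeRO-3 config; meant to be launched via: deepspeed --num_gpus=1 this_script.py
    ds_config = {"zero_optimization": {"stage": 3}, "train_micro_batch_size_per_gpu": 1}

    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

    # inference mode: no optimizer, no lr scheduler and no trainable-parameter list
    engine, _, _, _ = deepspeed.initialize(model=model, config_params=ds_config)
    engine.module.eval()  # the wrapped model, ready for forward-only calls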

src/transformers/dependency_versions_table.py

@@ -8,7 +8,7 @@ deps = {
    "cookiecutter": "cookiecutter==1.7.2",
    "dataclasses": "dataclasses",
    "datasets": "datasets",
    "deepspeed": "deepspeed>=0.5.3",
    "deepspeed": "deepspeed>=0.5.7",
    "docutils": "docutils==0.16.0",
    "fairscale": "fairscale>0.3",
    "faiss-cpu": "faiss-cpu",

src/transformers/trainer.py

@@ -2229,15 +2229,12 @@ class Trainer:
            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
            # from the checkpoint eventually
            deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None)
            deepspeed_engine, _, _ = deepspeed_init(
                self, num_training_steps=0, resume_from_checkpoint=None, inference=True
            )
            self.model = deepspeed_engine.module
            self.model_wrapped = deepspeed_engine
            self.deepspeed = deepspeed_engine

            # XXX: we don't need optim/sched for inference, but this needs to be sorted out, since
            # for example the Z3-optimizer is a must for zero3 to work even for inference - what we
            # don't need is the deepspeed basic optimizer which is self.optimizer.optimizer
            deepspeed_engine.optimizer.optimizer = None
            deepspeed_engine.lr_scheduler = None

        model = self._wrap_model(self.model, training=False)
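
From the user's side this path is triggered simply by evaluating without training. A hypothetical minimal script
(``model`` and ``eval_dataset`` construction omitted, and ``ds_config_zero3.json`` is a stand-in path):

.. code-block:: python

    from transformers import Trainer, TrainingArguments

    # meant to be launched via: deepspeed --num_gpus=2 this_script.py
    training_args = TrainingArguments(
        output_dir="output_dir",
        do_eval=True,
        deepspeed="ds_config_zero3.json",  # stand-in path to a ZeRO-3 config
    )
    trainer = Trainer(model=model, args=training_args, eval_dataset=eval_dataset)
    metrics = trainer.evaluate()  # deepspeed_init(..., inference=True) runs under the hood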

tests/deepspeed/test_deepspeed.py

@@ -697,11 +697,10 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
    def test_basic_distributed(self, stage):
        self.run_and_check(stage=stage, distributed=True)

    @parameterized.expand(stages)
    def test_do_eval_no_train(self, stage):
        # we should not fail if train is skipped
    def test_do_eval_no_train(self):
        # testing only zero3 since zero2 makes no sense with inference
        self.run_and_check(
            stage=stage,
            stage=ZERO3,
            eval_steps=1,
            distributed=False,
            do_train=False,
@@ -755,6 +754,22 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
        self.do_checks(output_dir, do_train=do_train, do_eval=do_eval)

    @require_torch_multi_gpu
    @parameterized.expand(["fp16", "fp32"])
    def test_inference(self, dtype):
        # this is just inference, so no optimizer should be loaded
        # it only works for z3 (makes no sense with z1-z2)
        fp16 = True if dtype == "fp16" else False
        self.run_and_check(
            stage=ZERO3,
            model_name=T5_TINY,
            distributed=True,
            do_train=False,
            do_eval=True,
            quality_checks=False,
            fp16=fp16,
        )

    def do_checks(self, output_dir, do_train=True, do_eval=True, quality_checks=True):
        if do_train: