Enable torchdynamo with torch_tensorrt(fx path) (#17765)

* enable fx2trt

* Update perf_train_gpu_one.mdx

* Update perf_train_gpu_one.mdx

* add lib check

* update

* format

* update

* fix import check

* fix isort

* improve doc

* refactor ctx manager

* fix isort

* black format

* isort fix

* fix format

* update args

* update black

* cleanups

* Update perf_train_gpu_one.mdx

* code refactor

* code refactor to init

* remove redundancy

* isort

* replace self.args with args

Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
Wei 2022-07-13 09:43:28 -07:00 committed by GitHub
parent 37aeb5787a
commit 7ea6ccc2b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 88 additions and 22 deletions

View File

@ -11,7 +11,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
# Efficient Training on a Single GPU
This guide focuses on training large models efficiently on a single GPU. These approaches are still valid if you have access to a machine with multiple GPUs but you will also have access to additional methods outlined in the [multi-GPU section](perf_train_gpu_many).
This guide focuses on training large models efficiently on a single GPU. These approaches are still valid if you have access to a machine with multiple GPUs but you will also have access to additional methods outlined in the [multi-GPU section](perf_train_gpu_many).
In this section we have a look at a few tricks to reduce the memory footprint and speed up training for large models and how they are integrated in the [`Trainer`] and [🤗 Accelerate](https://huggingface.co/docs/accelerate/). Each method can improve speed or memory usage which is summarized in the table below:
@ -367,7 +367,7 @@ Samples/second: 10.09
GPU memory occupied: 7275 MB.
```
We can see that with these tweaks we use about half the GPU memory as at the beginning while also being slightly faster.
We can see that with these tweaks we use about half the GPU memory as at the beginning while also being slightly faster.
### BF16
If you have access to a Ampere or newer hardware you can use bf16 for your training and evaluation. While bf16 has a worse precision than fp16, it has a much much bigger dynamic range. Therefore, if in the past you were experiencing overflow issues while training the model, bf16 will prevent this from happening most of the time. Remember that in fp16 the biggest number you can have is `65535` and any number above that will overflow. A bf16 number can be as large as `3.39e+38` (!) which is about the same as fp32 - because both have 8-bits used for the numerical range.
@ -394,7 +394,7 @@ Like all cases with reduced precision this may or may not be satisfactory for yo
If you're already using fp16 or bf16 mixed precision it may help with the throughput as well.
You can enable this mode in the 🤗 Trainer with:
You can enable this mode in the 🤗 Trainer with:
```python
TrainingArguments(tf32=True)
```
@ -654,7 +654,7 @@ https://github.com/huggingface/transformers/blob/master/src/transformers/trainer
## Choice of GPU
Sometimes, even when applying all the above tweaks the throughput on a given GPU might still not be good enough. One easy solution is to change the type of GPU. For example switching from let's say a K80 (which you typically get on Google Colab) to a fancier GPU such as the V100 or A100. Although they are more expensive they are usually more cost effective than cheaper GPUs due to their larger memory and faster architecture.
Sometimes, even when applying all the above tweaks the throughput on a given GPU might still not be good enough. One easy solution is to change the type of GPU. For example switching from let's say a K80 (which you typically get on Google Colab) to a fancier GPU such as the V100 or A100. Although they are more expensive they are usually more cost effective than cheaper GPUs due to their larger memory and faster architecture.
Now, let's take a step back and discuss what we should optimize for when scaling the training of large models.
@ -718,3 +718,15 @@ For some applications, such as pretraining large language models, applying all t
Another use case for training on many GPUs is if the model does not fit on a single GPU with all the mentioned tricks. There are still more methods we can apply although life starts to get a bit more complicated. This usually involves some form of pipeline or tensor parallelism where the model itself is distributed across several GPUs. One can also make use of DeepSpeed which implements some of these parallelism strategies along with some more optimization to reduce the memory footprint such as partitioning the optimizer states. You can read more about this in the ["Multi-GPU training" section](perf_train_gpu_many).
## Inference with torchdynamo
TorchDynamo is a new tracer that uses Pythons frame evaluation API to automatically create FX traces from existing PyTorch programs. After capturing the FX graph, different backends can be deployed to lower the graph to an optimized engine. One solution is using the [TensorRT](https://developer.nvidia.com/tensorrt) or NVFuser as backend. You can choose one option below for performance boost.
```
TrainingArguments(torchdynamo="eager") #enable eager model GPU. No performance boost
TrainingArguments(torchdynamo="nvfuser") #enable nvfuser
TrainingArguments(torchdynamo="fx2trt") #enable tensorRT fp32
TrainingArguments(torchdynamo="fx2trt-f16") #enable tensorRT fp16
```
This feature involves 3 different libraries. To install them, please follow the instructions below:
- [Torchdynamo installation](https://github.com/pytorch/torchdynamo#requirements-and-setup)
- [Functorch installation](https://github.com/pytorch/functorch#install)
- [Torch-TensorRT(FX) installation](https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst#installation)

View File

@ -71,6 +71,7 @@ from .utils import (
is_torch_available,
is_torch_bf16_cpu_available,
is_torch_bf16_gpu_available,
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torchaudio_available,
@ -499,6 +500,11 @@ def require_torchdynamo(test_case):
return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)
def require_torch_tensorrt_fx(test_case):
"""Decorator marking a test that requires Torch-TensorRT FX"""
return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case)
def require_torch_gpu(test_case):
"""Decorator marking a test that requires CUDA and PyTorch."""
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)

View File

@ -141,6 +141,7 @@ from .utils import (
is_ipex_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_tensorrt_fx_available,
is_torch_tpu_available,
is_torchdynamo_available,
logging,
@ -598,6 +599,35 @@ class Trainer:
# very last
self._memory_tracker.stop_and_update_metrics()
# torchdynamo
if args.torchdynamo:
if not is_torchdynamo_available():
raise RuntimeError("Torchdynamo is not installed.")
import torchdynamo
from torchdynamo.optimizations import backends
from torchdynamo.optimizations.training import aot_autograd_speedup_strategy
def get_ctx():
# Normal
if args.torchdynamo == "eager":
return torchdynamo.optimize("eager")
elif args.torchdynamo == "nvfuser":
return torchdynamo.optimize(aot_autograd_speedup_strategy)
# TensorRT
if args.torchdynamo in ["fx2trt-fp16", "fx2trt"]:
if not is_torch_tensorrt_fx_available():
raise RuntimeError("Torch-TensorRT FX path is not installed.")
if args.torchdynamo == "fx2trt-fp16":
return torchdynamo.optimize(backends.fx2trt_compiler_fp16)
elif args.torchdynamo == "fx2trt":
return torchdynamo.optimize(backends.fx2trt_compiler)
else:
raise RuntimeError(f"Torchdynamo backend {args.torchdynamo} is not supported.")
self.ctx_manager_torchdynamo = get_ctx()
else:
self.ctx_manager_torchdynamo = contextlib.nullcontext()
def add_callback(self, callback):
"""
Add a callback to the current list of [`~transformer.TrainerCallback`].
@ -2291,16 +2321,7 @@ class Trainer:
"""
A helper wrapper that creates an appropriate context manager for `torchdynamo`.
"""
ctx_manager = contextlib.nullcontext()
if is_torchdynamo_available():
import torchdynamo
from torchdynamo.optimizations.training import aot_autograd_speedup_strategy
if self.args.torchdynamo == "eager":
ctx_manager = torchdynamo.optimize("eager")
elif self.args.torchdynamo == "nvfuser":
ctx_manager = torchdynamo.optimize(aot_autograd_speedup_strategy)
return ctx_manager
return self.ctx_manager_torchdynamo
def autocast_smart_context_manager(self):
"""

View File

@ -935,7 +935,7 @@ class TrainingArguments:
" are two options - eager and nvfuser. Eager defaults to pytorch eager and is useful for debugging."
" nvfuser path uses AOT Autograd and nvfuser compiler to optimize the models."
),
"choices": ["eager", "nvfuser"],
"choices": ["eager", "nvfuser", "fx2trt", "fx2trt-fp16"],
},
)
ray_scope: Optional[str] = field(

View File

@ -132,6 +132,7 @@ from .import_utils import (
is_torch_fx_available,
is_torch_fx_proxy,
is_torch_onnx_dict_inputs_support_available,
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torchaudio_available,

View File

@ -418,6 +418,12 @@ def is_torchdynamo_available():
return importlib.util.find_spec("torchdynamo") is not None
def is_torch_tensorrt_fx_available():
if importlib.util.find_spec("torch_tensorrt") is None:
return False
return importlib.util.find_spec("torch_tensorrt.fx") is not None
def is_datasets_available():
return _datasets_available

View File

@ -62,6 +62,7 @@ from transformers.testing_utils import (
require_torch_gpu,
require_torch_multi_gpu,
require_torch_non_multi_gpu,
require_torch_tensorrt_fx,
require_torch_tf32,
require_torch_up_to_2_gpus,
require_torchdynamo,
@ -1796,6 +1797,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
@require_torch_non_multi_gpu
@require_torchdynamo
@require_torch_tensorrt_fx
def test_torchdynamo_full_eval(self):
# torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
n_gpus = get_gpu_count()
@ -1824,6 +1826,21 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
metrics = trainer.evaluate()
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
# 4. TorchDynamo fx2trt
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt")
metrics = trainer.evaluate()
t1 = metrics["eval_loss"]
t2 = original_eval_loss
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
# 5. TorchDynamo fx2trt-fp16
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt-fp16")
metrics = trainer.evaluate()
t1 = metrics["eval_loss"]
t2 = original_eval_loss
# fp16 has accuracy accuracy degradation
self.assertLess(np.max(np.abs(t1 - t2)), 1e-3)
@require_torch_non_multi_gpu
@require_torchdynamo
def test_torchdynamo_memory(self):
@ -1849,7 +1866,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
mod = MyModule()
# 1. Default - without TorchDynamo
# 1. without TorchDynamo (eager baseline)
a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
a.grad = None
trainer = CustomTrainer(model=mod)
@ -1857,16 +1874,15 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
for _ in range(10):
orig_loss = trainer.training_step(mod, {"x": a})
torch.cuda.reset_peak_memory_stats()
orig_loss = trainer.training_step(mod, {"x": a})
orig_peak_mem = torch.cuda.max_memory_allocated()
del trainer
# Reset the peak for another measurement
# resets
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
orig_loss = trainer.training_step(mod, {"x": a})
orig_peak_mem = torch.cuda.max_memory_allocated()
del trainer
# 2. TorchDynamo nvfuser
a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
a.grad = None
@ -1876,7 +1892,11 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
for _ in range(10):
loss = trainer.training_step(mod, {"x": a})
# resets
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
loss = trainer.training_step(mod, {"x": a})
peak_mem = torch.cuda.max_memory_allocated()
del trainer