start from 1.12, torch_ccl is renamed as oneccl_bindings_for_pytorch … (#18229)

* start from 1.12, torch_ccl is renamed as oneccl_bindings_for_pytorch and should import it before use

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* add doc for perf_train_cpu_many

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* update doc

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi 2022-07-27 23:15:41 +08:00 committed by GitHub
parent e87ac9d18b
commit 2b81f72be9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 128 additions and 5 deletions

View File

@ -68,6 +68,8 @@
title: Training on many GPUs
- local: perf_train_cpu
title: Training on CPU
- local: perf_train_cpu_many
title: Training on many CPUs
- local: perf_train_tpu
title: Training on TPUs
- local: perf_train_special

View File

@ -0,0 +1,92 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
-->
# Efficient Training on Multiple CPUs
When training on a single CPU is too slow, we can use multiple CPUs. This guide focuses on PyTorch-based DDP enabling distributed CPU training efficiently.
## Intel® oneCCL Bindings for PyTorch
[Intel® oneCCL](https://github.com/oneapi-src/oneCCL) (collective communications library) is a library for efficient distributed deep learning training implementing such collectives like allreduce, allgather, alltoall. For more information on oneCCL, please refer to the [oneCCL documentation](https://spec.oneapi.com/versions/latest/elements/oneCCL/source/index.html) and [oneCCL specification](https://spec.oneapi.com/versions/latest/elements/oneCCL/source/index.html).
Module `oneccl_bindings_for_pytorch` (`torch_ccl` before version 1.12) implements PyTorch C10D ProcessGroup API and can be dynamically loaded as external ProcessGroup and only works on Linux platform now
Check more detailed information for [oneccl_bind_pt](https://github.com/intel/torch-ccl).
### Intel® oneCCL Bindings for PyTorch installation:
Wheel files are available for the following Python versions:
| Extension Version | Python 3.6 | Python 3.7 | Python 3.8 | Python 3.9 | Python 3.10 |
| :---------------: | :--------: | :--------: | :--------: | :--------: | :---------: |
| 1.12.0 | | √ | √ | √ | √ |
| 1.11.0 | | √ | √ | √ | √ |
| 1.10.0 | √ | √ | √ | √ | |
```
pip install oneccl_bind_pt=={pytorch_version} -f https://software.intel.com/ipex-whl-stable
```
where `{pytorch_version}` should be your PyTorch version, for instance 1.12.0.
Check more approaches for [oneccl_bind_pt installation](https://github.com/intel/torch-ccl).
### Usage in Trainer
To enable multi CPU distributed training in the Trainer with the ccl backend, users should add **`--xpu_backend ccl`** in the command arguments.
Let's see an example with the [question-answering example](https://github.com/huggingface/transformers/tree/main/examples/pytorch/question-answering)
The following command enables training with 2 processes on one Xeon node, with one process running per one socket. The variables OMP_NUM_THREADS/CCL_WORKER_COUNT can be tuned for optimal performance.
```shell script
export CCL_WORKER_COUNT=1
export MASTER_ADDR=127.0.0.1
mpirun -n 2 -genv OMP_NUM_THREADS=23 \
python3 run_qa.py \
--model_name_or_path bert-large-uncased \
--dataset_name squad \
--do_train \
--do_eval \
--per_device_train_batch_size 12 \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--max_seq_length 384 \
--doc_stride 128 \
--output_dir /tmp/debug_squad/ \
--no_cuda \
--xpu_backend ccl
```
The following command enables training with a total of four processes on two Xeons (node0 and node1, taking node0 as the main process), ppn (processes per node) is set to 2, with one process running per one socket. The variables OMP_NUM_THREADS/CCL_WORKER_COUNT can be tuned for optimal performance.
In node0, you need to create a configuration file which contains the IP addresses of each node (for example hostfile) and pass that configuration file path as an argument.
```shell script
cat hostfile
xxx.xxx.xxx.xxx #node0 ip
xxx.xxx.xxx.xxx #node1 ip
```
Now, run the following command in node0 and **4DDP** will be enabled in node0 and node1:
```shell script
export CCL_WORKER_COUNT=1
export MASTER_ADDR=xxx.xxx.xxx.xxx #node0 ip
mpirun -f hostfile -n 4 -ppn 2 \
-genv OMP_NUM_THREADS=23 \
python3 run_qa.py \
--model_name_or_path bert-large-uncased \
--dataset_name squad \
--do_train \
--do_eval \
--per_device_train_batch_size 12 \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--max_seq_length 384 \
--doc_stride 128 \
--output_dir /tmp/debug_squad/ \
--no_cuda \
--xpu_backend ccl
```

View File

@ -34,6 +34,7 @@ from .trainer_utils import (
from .utils import (
ExplicitEnum,
cached_property,
ccl_version,
get_full_repo_name,
is_accelerate_available,
is_sagemaker_dp_enabled,
@ -44,6 +45,7 @@ from .utils import (
is_torch_tf32_available,
is_torch_tpu_available,
logging,
requires_backends,
torch_required,
)
@ -1301,11 +1303,17 @@ class TrainingArguments:
"CPU distributed training backend is not properly set. "
"Please set '--xpu_backend' to either 'mpi' or 'ccl'."
)
if self.xpu_backend == "ccl" and int(os.environ.get("CCL_WORKER_COUNT", 0)) < 1:
raise ValueError(
"CPU distributed training backend is ccl. but CCL_WORKER_COUNT is not correctly set. "
"Please use like 'export CCL_WORKER_COUNT = 1' to set."
)
if self.xpu_backend == "ccl":
requires_backends(self, "oneccl_bind_pt")
if ccl_version >= "1.12":
import oneccl_bindings_for_pytorch # noqa: F401
else:
import torch_ccl # noqa: F401
if int(os.environ.get("CCL_WORKER_COUNT", 0)) < 1:
raise ValueError(
"CPU distributed training backend is ccl. but CCL_WORKER_COUNT is not correctly set. "
"Please use like 'export CCL_WORKER_COUNT = 1' to set."
)
# Try to get launch configuration from environment variables set by MPI launcher - works for Intel MPI, OpenMPI and MVAPICH
rank = get_int_from_env(["RANK", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "MV2_COMM_WORLD_RANK"], 0)

View File

@ -87,6 +87,7 @@ from .import_utils import (
DummyObject,
OptionalDependencyNotAvailable,
_LazyModule,
ccl_version,
is_accelerate_available,
is_apex_available,
is_bitsandbytes_available,

View File

@ -246,6 +246,16 @@ try:
except importlib_metadata.PackageNotFoundError:
_librosa_available = False
ccl_version = "N/A"
_is_ccl_available = (
importlib.util.find_spec("torch_ccl") is not None
or importlib.util.find_spec("oneccl_bindings_for_pytorch") is not None
)
try:
ccl_version = importlib_metadata.version("oneccl_bind_pt")
logger.debug(f"Successfully imported oneccl_bind_pt version {ccl_version}")
except importlib_metadata.PackageNotFoundError:
_is_ccl_available = False
# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
@ -636,6 +646,10 @@ def torch_only_method(fn):
return wrapper
def is_ccl_available():
return _is_ccl_available
# docstyle-ignore
DATASETS_IMPORT_ERROR = """
{0} requires the 🤗 Datasets library but it was not found in your environment. You can install it with:
@ -854,6 +868,11 @@ ACCELERATE_IMPORT_ERROR = """
`pip install accelerate`
"""
# docstyle-ignore
CCL_IMPORT_ERROR = """
{0} requires the torch ccl library but it was not found in your environment. You can install it with pip:
`pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable`
"""
BACKENDS_MAPPING = OrderedDict(
[
@ -882,6 +901,7 @@ BACKENDS_MAPPING = OrderedDict(
("vision", (is_vision_available, VISION_IMPORT_ERROR)),
("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
("oneccl_bind_pt", (is_ccl_available, CCL_IMPORT_ERROR)),
]
)