Introduce configured_state arg for accelerator_config (#29781)

* Introduce configured_state

* Include note on tuning

* Allow for users to have defined a state already

* Include tests

* Add note on hpam tune

* Guard a bit better

* Update src/transformers/training_args.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Finish rebase

* Finish rebase

* Guard carefully

* Fixup test

* Refactor

* Fin refactor

* Comment

* Update wrt feedback

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Zach Mueller 2024-05-20 09:21:40 -04:00 committed by GitHub
parent bb48e92186
commit 92d1d97c05
3 changed files with 110 additions and 52 deletions

View File

@ -1250,6 +1250,10 @@ class AcceleratorConfig:
Whether to use non-blocking CUDA calls to help minimize synchronization during
distributed training with prepared `DataLoader` inputs being moved to device.
Best if used with `pin_memory=True` in the `TrainingArguments`.
use_configured_state (`bool`, *optional*, defaults to `False`):
Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined
before calling `TrainingArguments`. If `True`, an `Accelerator` or `PartialState`
must be initialized. May lead to issues when using sweeps or hyperparameter tuning.
""" """
@ -1312,6 +1316,13 @@ class AcceleratorConfig:
" The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`." " The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`."
}, },
) )
use_configured_state: bool = field(
default=False,
metadata={
"help": "Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined before calling `TrainingArguments`."
"If `True`, an `Accelerator` or `PartialState` must be initialized. May lead to issues using sweeps or hyperparameter tuning."
},
)
@classmethod
def from_json_file(cls, json_file):
@ -1331,6 +1342,9 @@ class AcceleratorConfig:
def to_dict(self):
return copy.deepcopy(self.__dict__)
def pop(self, key, default=None):
return self.__dict__.pop(key, default)
class LayerWiseDummyOptimizer(torch.optim.Optimizer):
"""

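For reference, here is a minimal sketch of how the new `use_configured_state` option is meant to be used: the distributed state is configured up front, and `TrainingArguments` is told to reuse it instead of resetting it. The output directory name below is illustrative, not from the diff.

```python
# Minimal sketch: configure the state *before* TrainingArguments, then opt in.
from accelerate import PartialState
from transformers import TrainingArguments

state = PartialState()  # pre-configured process/device state (could also come from Accelerator())

args = TrainingArguments(
    output_dir="tmp_trainer",  # illustrative path
    accelerator_config={"use_configured_state": True},  # reuse `state` instead of resetting it
)
```

With the default `use_configured_state=False`, any existing `AcceleratorState`/`PartialState` is reset as before; as the docstring warns, reusing a configured state can interfere with hyperparameter sweeps.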
View File

@ -572,6 +572,10 @@ class TrainingArguments:
training results are fully reproducible using a different sampling technique. While seed-to-seed results
may differ, on average the differences are negligible when using multiple different seeds to compare. Should
also be run with [`~utils.set_seed`] for the best results.
- use_configured_state (`bool`, *optional*, defaults to `False`):
Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined before calling `TrainingArguments`.
If `True`, an `Accelerator` or `PartialState` must be initialized. Note that by doing so, this could lead to issues
with hyperparameter tuning.
label_smoothing_factor (`float`, *optional*, defaults to 0.0):
The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
@ -1635,6 +1639,39 @@ class TrainingArguments:
if version.parse(version.parse(torch.__version__).base_version) == version.parse("2.0.0") and self.fp16:
raise ValueError("--optim adamw_torch_fused with --fp16 requires PyTorch>2.0")
# We need to setup the accelerator config here *before* the first call to `self.device`
if is_accelerate_available():
if not isinstance(self.accelerator_config, (AcceleratorConfig)):
if self.accelerator_config is None:
self.accelerator_config = AcceleratorConfig()
elif isinstance(self.accelerator_config, dict):
self.accelerator_config = AcceleratorConfig(**self.accelerator_config)
# Check that a user didn't pass in the class instantiator
# such as `accelerator_config = AcceleratorConfig`
elif isinstance(self.accelerator_config, type):
raise NotImplementedError(
"Tried passing in a callable to `accelerator_config`, but this is not supported. "
"Please pass in a fully constructed `AcceleratorConfig` object instead."
)
else:
self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
if self.dispatch_batches is not None:
warnings.warn(
"Using `--dispatch_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use"
" `--accelerator_config {'dispatch_batches':VALUE} instead",
FutureWarning,
)
self.accelerator_config.dispatch_batches = self.dispatch_batches
if self.split_batches is not None:
warnings.warn(
"Using `--split_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use"
" `--accelerator_config {'split_batches':VALUE} instead",
FutureWarning,
)
self.accelerator_config.split_batches = self.split_batches
if (
self.framework == "pt"
and is_torch_available()
@ -1873,37 +1910,6 @@ class TrainingArguments:
os.environ[f"{prefix}USE_ORIG_PARAMS"] = str(self.fsdp_config.get("use_orig_params", "true")).lower() os.environ[f"{prefix}USE_ORIG_PARAMS"] = str(self.fsdp_config.get("use_orig_params", "true")).lower()
if is_accelerate_available():
if not isinstance(self.accelerator_config, (AcceleratorConfig)):
if self.accelerator_config is None:
self.accelerator_config = AcceleratorConfig()
elif isinstance(self.accelerator_config, dict):
self.accelerator_config = AcceleratorConfig(**self.accelerator_config)
# Check that a user didn't pass in the class instantiator
# such as `accelerator_config = AcceleratorConfig`
elif isinstance(self.accelerator_config, type):
raise NotImplementedError(
"Tried passing in a callable to `accelerator_config`, but this is not supported. "
"Please pass in a fully constructed `AcceleratorConfig` object instead."
)
else:
self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
if self.dispatch_batches is not None:
warnings.warn(
"Using `--dispatch_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use"
" `--accelerator_config {'dispatch_batches':VALUE} instead",
FutureWarning,
)
self.accelerator_config.dispatch_batches = self.dispatch_batches
if self.split_batches is not None:
warnings.warn(
"Using `--split_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use"
" `--accelerator_config {'split_batches':VALUE} instead",
FutureWarning,
)
self.accelerator_config.split_batches = self.split_batches
if self.tpu_metrics_debug:
warnings.warn(
"using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
@ -2056,32 +2062,62 @@ class TrainingArguments:
f"Using the `Trainer` with `PyTorch` requires `accelerate>={ACCELERATE_MIN_VERSION}`: " f"Using the `Trainer` with `PyTorch` requires `accelerate>={ACCELERATE_MIN_VERSION}`: "
"Please run `pip install transformers[torch]` or `pip install accelerate -U`" "Please run `pip install transformers[torch]` or `pip install accelerate -U`"
) )
# We delay the init of `PartialState` to the end for clarity
accelerator_state_kwargs = {"enabled": True, "use_configured_state": False}
if isinstance(self.accelerator_config, AcceleratorConfig):
accelerator_state_kwargs["use_configured_state"] = self.accelerator_config.pop(
"use_configured_state", False
)
if accelerator_state_kwargs["use_configured_state"]:
if PartialState._shared_state == {}:
raise ValueError(
"Passing `'use_configured_state':True` to the AcceleratorConfig requires a pre-configured "
"`AcceleratorState` or `PartialState` to be defined before calling `TrainingArguments`. "
)
# We rely on `PartialState` to yell if there's issues here (which it will)
self.distributed_state = PartialState(cpu=self.use_cpu)
if self.deepspeed and self.distributed_state.distributed_type != DistributedType.DEEPSPEED:
raise RuntimeError(
"Tried to use an already configured `Accelerator` or `PartialState` that was not initialized for DeepSpeed, "
"but also passed in a `deepspeed` configuration to the `TrainingArguments`. Please set "
"`use_configured_state:False` instead or setup your `Accelerator` or `PartialState` properly."
)
else:
AcceleratorState._reset_state(reset_partial_state=True)
self.distributed_state = None
if not self.use_ipex and "ACCELERATE_USE_IPEX" not in os.environ:
os.environ["ACCELERATE_USE_IPEX"] = "false"
self._n_gpu = 1
if self.use_cpu or strtobool(os.environ.get("ACCELERATE_USE_CPU", "False")):
self.distributed_state = PartialState(cpu=True, backend=self.ddp_backend)
accelerator_state_kwargs["cpu"] = True
accelerator_state_kwargs["backend"] = self.ddp_backend
self._n_gpu = 0
elif is_sagemaker_mp_enabled():
accelerator_state_kwargs["enabled"] = False
local_rank = smp.local_rank()
device = torch.device("cuda", local_rank)
self._n_gpu = 1
torch.cuda.set_device(device)
elif is_sagemaker_dp_enabled():
self.distributed_state = PartialState(_use_sagemaker_dp=True)
self._n_gpu = 1
accelerator_state_kwargs["_use_sagemaker_dp"] = True
elif self.deepspeed:
# Need to do similar for Accelerator init
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
del os.environ["ACCELERATE_USE_DEEPSPEED"]
self._n_gpu = 1
accelerator_state_kwargs["use_deepspeed"] = True
accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout)
else:
self.distributed_state = PartialState(
backend=self.ddp_backend, timeout=timedelta(seconds=self.ddp_timeout)
)
self._n_gpu = 1
accelerator_state_kwargs["backend"] = self.ddp_backend
accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout)
# Now we pop everything
if accelerator_state_kwargs.pop("enabled", False) and not accelerator_state_kwargs.pop(
"use_configured_state", False
):
# We need to patch this env var when enabling to detect deepspeed
use_deepspeed = accelerator_state_kwargs.pop("use_deepspeed", False)
if use_deepspeed:
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
self.distributed_state = PartialState(**accelerator_state_kwargs)
if use_deepspeed:
del os.environ["ACCELERATE_USE_DEEPSPEED"]
if not is_sagemaker_mp_enabled():
device = self.distributed_state.device
self.local_rank = self.distributed_state.local_process_index
@ -2108,23 +2144,17 @@ class TrainingArguments:
"Either you do not have an MPS-enabled device on this machine or MacOS version is not 12.3+ " "Either you do not have an MPS-enabled device on this machine or MacOS version is not 12.3+ "
"or current PyTorch install was not built with MPS enabled." "or current PyTorch install was not built with MPS enabled."
) )
if device.type == "mps": if self.use_cpu:
self._n_gpu = 1
elif self.use_cpu:
device = torch.device("cpu") device = torch.device("cpu")
self._n_gpu = 0
elif is_torch_xpu_available(): elif is_torch_xpu_available():
device = torch.device("xpu:0") device = torch.device("xpu:0")
torch.xpu.set_device(device) torch.xpu.set_device(device)
self._n_gpu = 1
elif is_torch_mlu_available(): elif is_torch_mlu_available():
device = torch.device("mlu:0") device = torch.device("mlu:0")
torch.mlu.set_device(device) torch.mlu.set_device(device)
self._n_gpu = 1
elif is_torch_npu_available(): elif is_torch_npu_available():
device = torch.device("npu:0") device = torch.device("npu:0")
torch.npu.set_device(device) torch.npu.set_device(device)
self._n_gpu = 1
else: else:
# if n_gpu is > 1 we'll use nn.DataParallel. # if n_gpu is > 1 we'll use nn.DataParallel.
# If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0` # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`

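The deprecation warnings above point at the same mechanism; a sketch of the replacement for the standalone `dispatch_batches`/`split_batches` arguments (values and output directory are illustrative):

```python
# Sketch of the migration suggested by the deprecation warnings above:
# pass the options through `accelerator_config` instead of the old flags.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="tmp_trainer",  # illustrative path
    accelerator_config={"dispatch_batches": False, "split_batches": True},
)
```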
View File

@ -131,6 +131,10 @@ if is_torch_available():
# for version specific tests in TrainerIntegrationTest
require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
if is_accelerate_available():
from accelerate import Accelerator
from accelerate.state import AcceleratorState
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
@ -3266,6 +3270,16 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.split_batches, True)
def test_accelerator_custom_state(self):
AcceleratorState._reset_state(reset_partial_state=True)
with tempfile.TemporaryDirectory() as tmp_dir:
with self.assertRaises(ValueError) as cm:
_ = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config={"use_configured_state": True})
self.assertIn("requires a pre-configured `AcceleratorState` or `PartialState`", str(cm.exception))
_ = Accelerator()
_ = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config={"use_configured_state": True})
AcceleratorState._reset_state(reset_partial_state=True)
@require_accelerate_version_min_0_28
def test_accelerator_config_from_dict_grad_accum_num_steps(self):
with tempfile.TemporaryDirectory() as tmp_dir:
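For completeness, a sketch of the guard the new test exercises: requesting a pre-configured state before one exists raises a `ValueError`, while initializing an `Accelerator` first lets the same call succeed. The output directory is illustrative, and plain `TrainingArguments` stands in for the `RegressionTrainingArguments` helper used in the test.

```python
# Sketch of the behavior covered by test_accelerator_custom_state above.
from accelerate import Accelerator
from accelerate.state import AcceleratorState
from transformers import TrainingArguments

AcceleratorState._reset_state(reset_partial_state=True)  # ensure no state is configured yet
try:
    TrainingArguments(output_dir="tmp_trainer", accelerator_config={"use_configured_state": True})
except ValueError as err:
    print(err)  # complains that a pre-configured AcceleratorState/PartialState is required

Accelerator()  # configure the shared state, then the same call goes through
TrainingArguments(output_dir="tmp_trainer", accelerator_config={"use_configured_state": True})
```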