Refactor Trainer, add `save_only_model` arg and simplify FSDP integration (#27652)

* add code changes

1. Refactor FSDP
2. Add `--save_only_model` option: when checkpointing, whether to save only the model or also the optimizer, scheduler & rng state (see the sketch after this list).
3. Bump up the minimum `accelerate` version to `0.21.0`
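
A minimal sketch of how the new option can be used; apart from `save_only_model`, every argument below (output dir, save strategy, step count) is an illustrative assumption:

```python
from transformers import TrainingArguments

# Keep only the model weights in each checkpoint and skip the optimizer,
# scheduler and RNG state. Such checkpoints cannot be used to resume
# training; load them with `from_pretrained` instead.
training_args = TrainingArguments(
    output_dir="outputs",     # assumed output location
    save_strategy="steps",
    save_steps=500,
    save_only_model=True,     # the new option introduced here
)
```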

* quality

* fix quality?

* Revert "fix quality?"

This reverts commit 149330a6ab.

* fix fsdp doc strings

* fix quality

* Update src/transformers/training_args.py

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* please fix the quality issue 😅

* Apply suggestions from code review

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* address comment

* simplify conditional check as per the comment

* update documentation

---------

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Sourab Mangrulkar 2023-11-24 11:40:52 +05:30 committed by GitHub
parent b8db265bc6
commit a761d6e9a0
8 changed files with 163 additions and 173 deletions


@ -426,8 +426,7 @@ To read more about it and the benefits, check out the [Fully Sharded Data Parall
We have integrated the latest PyTorch's Fully Sharded Data Parallel (FSDP) training feature.
All you need to do is enable it through the config.
**Required PyTorch version for FSDP support**: PyTorch Nightly (or 1.12.0 if you read this after it has been released)
as the model saving with FSDP activated is only available with recent fixes.
**Required PyTorch version for FSDP support**: PyTorch >=2.1.0
**Usage**:
@ -440,6 +439,8 @@ as the model saving with FSDP activated is only available with recent fixes.
- SHARD_GRAD_OP : Shards optimizer states + gradients across data parallel workers/GPUs.
For this, add `--fsdp shard_grad_op` to the command line arguments.
- NO_SHARD : No sharding. For this, add `--fsdp no_shard` to the command line arguments.
- HYBRID_SHARD : Apply `FULL_SHARD` within a node, and replicate parameters across nodes. For this, add `--fsdp hybrid_shard` to the command line arguments.
- HYBRID_SHARD_ZERO2 : Apply `SHARD_GRAD_OP` within a node, and replicate parameters across nodes. For this, add `--fsdp hybrid_shard_zero2` to the command line arguments.
- To offload the parameters and gradients to the CPU,
add `--fsdp "full_shard offload"` or `--fsdp "shard_grad_op offload"` to the command line arguments.
- To automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`,
@ -449,18 +450,18 @@ as the model saving with FSDP activated is only available with recent fixes.
- Remaining FSDP config is passed via `--fsdp_config <path_to_fsdp_config.json>`. It is either the location of an
FSDP json config file (e.g., `fsdp_config.json`) or an already loaded json file as a `dict` (see the sketch after this list).
- If auto wrapping is enabled, you can use either a transformer based auto wrap policy or a size based auto wrap policy.
- For transformer based auto wrap policy, it is recommended to specify `fsdp_transformer_layer_cls_to_wrap` in the config file. If not specified, the default value is `model._no_split_modules` when available.
- For transformer based auto wrap policy, it is recommended to specify `transformer_layer_cls_to_wrap` in the config file. If not specified, the default value is `model._no_split_modules` when available.
This specifies the list of transformer layer class names (case-sensitive) to wrap, e.g., [`BertLayer`], [`GPTJBlock`], [`T5Block`] ....
This is important because submodules that share weights (e.g., the embedding layer) should not end up in different FSDP wrapped units.
Using this policy, wrapping happens for each block containing Multi-Head Attention followed by a couple of MLP layers.
Remaining layers, including the shared embeddings, are conveniently wrapped in the same outermost FSDP unit.
Therefore, use this for transformer based models.
- For size based auto wrap policy, please add `fsdp_min_num_params` in the config file.
- For size based auto wrap policy, please add `min_num_params` in the config file.
It specifies FSDP's minimum number of parameters for auto wrapping.
- `fsdp_backward_prefetch` can be specified in the config file. It controls when to prefetch next set of parameters.
- `backward_prefetch` can be specified in the config file. It controls when to prefetch next set of parameters.
`backward_pre` and `backward_post` are available options.
For more information, refer to `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`
- `fsdp_forward_prefetch` can be specified in the config file. It controls when to prefetch next set of parameters.
- `forward_prefetch` can be specified in the config file. It controls when to prefetch next set of parameters.
If `"True"`, FSDP explicitly prefetches the next upcoming all-gather while executing in the forward pass.
- `limit_all_gathers` can be specified in the config file.
If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight all-gathers.
@ -468,6 +469,20 @@ as the model saving with FSDP activated is only available with recent fixes.
If `"True"`, FSDP activation checkpointing is a technique to reduce memory usage by clearing activations of
certain layers and recomputing them during a backward pass. Effectively, this trades extra computation time
for reduced memory usage.
- `use_orig_params` can be specified in the config file.
If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable parameters. Useful in cases such as parameter-efficient fine-tuning. This also makes it possible to have different optimizer param groups. This should be `True` when creating the optimizer object before preparing/wrapping the model with FSDP.
Please refer to this [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019).
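As a sketch of how the options above fit together, the config can also be passed to `TrainingArguments` as an already loaded `dict`; the concrete values below are illustrative assumptions, not recommended defaults:
```python
from transformers import TrainingArguments

# Illustrative FSDP config using the option names described above.
fsdp_config = {
    # Transformer based auto wrap policy: wrap each of these layer classes
    # (case-sensitive) in its own FSDP unit; use "min_num_params" instead
    # for a size based policy.
    "transformer_layer_cls_to_wrap": ["BertLayer"],
    # When to prefetch the next set of parameters during the backward pass.
    "backward_prefetch": "backward_pre",
    # Synchronize the CPU thread to limit the number of in-flight all-gathers.
    "limit_all_gathers": True,
}

training_args = TrainingArguments(
    output_dir="outputs",
    fsdp="full_shard auto_wrap",  # sharding strategy plus automatic recursive wrapping
    fsdp_config=fsdp_config,
)
```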
**Saving and loading**
Saving full intermediate checkpoints with `FULL_STATE_DICT` as the `state_dict_type` and CPU offloading on rank 0 takes a lot of time and often results in NCCL timeout errors due to indefinite hanging during broadcasting. At the end of training, however, we want the whole model state dict rather than the sharded state dict, which is only compatible with FSDP. Use the (default) `SHARDED_STATE_DICT` `state_dict_type` to save intermediate checkpoints and optimizer states in this format, as recommended by the PyTorch team.
Saving the final checkpoint in the transformers format, using the default `safetensors` format, requires the changes below.
```python
if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

trainer.save_model(script_args.output_dir)
```
**A few caveats to be aware of**
- it is incompatible with `generate`, and thus with `--predict_with_generate`
@ -492,15 +507,15 @@ Pass `--fsdp "full shard"` along with following changes to be made in `--fsdp_co
https://github.com/pytorch/xla/blob/master/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py).
- `xla_fsdp_grad_ckpt`. When `True`, uses gradient checkpointing over each nested XLA FSDP wrapped layer.
This setting can only be used when the xla flag is set to true, and an auto wrapping policy is specified through
`fsdp_min_num_params` or `fsdp_transformer_layer_cls_to_wrap`.
`min_num_params` or `transformer_layer_cls_to_wrap`.
- You can use either a transformer based auto wrap policy or a size based auto wrap policy.
- For transformer based auto wrap policy, it is recommended to specify `fsdp_transformer_layer_cls_to_wrap` in the config file. If not specified, the default value is `model._no_split_modules` when available.
- For transformer based auto wrap policy, it is recommended to specify `transformer_layer_cls_to_wrap` in the config file. If not specified, the default value is `model._no_split_modules` when available.
This specifies the list of transformer layer class names (case-sensitive) to wrap, e.g., [`BertLayer`], [`GPTJBlock`], [`T5Block`] ....
This is important because submodules that share weights (e.g., the embedding layer) should not end up in different FSDP wrapped units.
Using this policy, wrapping happens for each block containing Multi-Head Attention followed by a couple of MLP layers.
Remaining layers, including the shared embeddings, are conveniently wrapped in the same outermost FSDP unit.
Therefore, use this for transformer based models.
- For size based auto wrap policy, please add `fsdp_min_num_params` in the config file.
- For size based auto wrap policy, please add `min_num_params` in the config file.
It specifies FSDP's minimum number of parameters for auto wrapping (a config sketch follows this list).
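
For the XLA path, a hedged sketch of a possible config `dict`, restricted to the keys mentioned above; the layer class and output directory are illustrative assumptions:
```python
from transformers import TrainingArguments

# Illustrative PyTorch/XLA FSDP config; it could equally be saved as JSON
# and passed as `--fsdp_config <path_to_fsdp_config.json>`.
xla_fsdp_config = {
    "xla": True,                                   # enable the PyTorch/XLA FSDP code path
    "xla_fsdp_grad_ckpt": True,                    # gradient checkpointing per wrapped layer
    "transformer_layer_cls_to_wrap": ["T5Block"],  # auto wrap policy needed for xla_fsdp_grad_ckpt
}

training_args = TrainingArguments(
    output_dir="outputs",
    fsdp="full_shard",
    fsdp_config=xla_fsdp_config,
)
```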


@ -96,7 +96,7 @@ if stale_egg_info.exists():
# 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
_deps = [
"Pillow>=10.0.1,<=15.0",
"accelerate>=0.20.3",
"accelerate>=0.21.0",
"av==9.2.0", # Latest version of PyAV (10.0.0) has issues with audio stream.
"beautifulsoup4",
"codecarbon==1.2.0",


@ -3,7 +3,7 @@
# 2. run `make deps_table_update``
deps = {
"Pillow": "Pillow>=10.0.1,<=15.0",
"accelerate": "accelerate>=0.20.3",
"accelerate": "accelerate>=0.21.0",
"av": "av==9.2.0",
"beautifulsoup4": "beautifulsoup4",
"codecarbon": "codecarbon==1.2.0",


@ -132,8 +132,12 @@ def is_fsdp_enabled():
)
def is_fsdp_enabled_and_dist_rank_0():
return is_fsdp_enabled() and int(os.environ.get("LOCAL_RANK", -1)) == 0
def is_local_dist_rank_0():
return (
torch.distributed.is_available()
and torch.distributed.is_initialized()
and int(os.environ.get("LOCAL_RANK", -1)) == 0
)
if is_sagemaker_mp_enabled():
@ -474,13 +478,12 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
return safe_load_file(checkpoint_file)
try:
if (
(is_deepspeed_zero3_enabled() or is_fsdp_enabled())
and torch.distributed.is_initialized()
and torch.distributed.get_rank() > 0
):
is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0
) or (is_fsdp_enabled() and not is_local_dist_rank_0()):
map_location = "meta"
else:
map_location = "cpu"
return torch.load(checkpoint_file, map_location=map_location)
except Exception as e:
try:
@ -3904,7 +3907,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
ignore_mismatched_sizes,
)
if low_cpu_mem_usage:
if not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0():
if is_fsdp_enabled() and not is_local_dist_rank_0():
for key, param in model_to_load.state_dict().items():
if param.device == torch.device("meta"):
if not (is_quantized):
set_module_tensor_to_device(
model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
)
else:
set_module_quantized_tensor_to_device(
model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
)
else:
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
@ -3922,17 +3936,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
keep_in_fp32_modules=keep_in_fp32_modules,
)
error_msgs += new_error_msgs
else:
for key, param in model_to_load.state_dict().items():
if param.device == torch.device("meta"):
if not (is_quantized):
set_module_tensor_to_device(
model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
)
else:
set_module_quantized_tensor_to_device(
model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
)
else:
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)


@ -99,7 +99,6 @@ from .trainer_utils import (
BestRun,
EvalLoopOutput,
EvalPrediction,
FSDPOption,
HPSearchBackend,
HubStrategy,
IntervalStrategy,
@ -193,15 +192,15 @@ if is_peft_available():
if is_accelerate_available():
from accelerate import Accelerator, skip_first_batches
from accelerate import __version__ as accelerate_version
from accelerate.utils import DistributedDataParallelKwargs, GradientAccumulationPlugin
from accelerate.utils import (
DistributedDataParallelKwargs,
GradientAccumulationPlugin,
load_fsdp_model,
load_fsdp_optimizer,
save_fsdp_model,
save_fsdp_optimizer,
)
if version.parse(accelerate_version) > version.parse("0.20.3"):
from accelerate.utils import (
load_fsdp_model,
load_fsdp_optimizer,
save_fsdp_model,
save_fsdp_optimizer,
)
DATA_SAMPLERS = [RandomSampler]
if version.parse(accelerate_version) > version.parse("0.23.0"):
from accelerate.data_loader import SeedableRandomSampler
@ -226,6 +225,7 @@ OPTIMIZER_NAME = "optimizer.pt"
OPTIMIZER_NAME_BIN = "optimizer.bin"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"
FSDP_MODEL_NAME = "pytorch_model_fsdp"
class Trainer:
@ -415,7 +415,7 @@ class Trainer:
" model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
)
self.fsdp = None
self.is_fsdp_xla_enabled = args.fsdp_config["xla"]
if len(args.fsdp) > 0:
if self.is_deepspeed_enabled:
raise ValueError(
@ -424,32 +424,6 @@ class Trainer:
if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED:
raise ValueError("Using fsdp only works in distributed training.")
# dep_version_check("torch>=1.12.0")
# Would have to update setup.py with torch>=1.12.0
# which isn't ideally given that it will force people not using FSDP to also use torch>=1.12.0
# below is the current alternative.
if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"):
raise ValueError("FSDP requires PyTorch >= 1.12.0")
from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy
if FSDPOption.FULL_SHARD in args.fsdp:
self.fsdp = ShardingStrategy.FULL_SHARD
elif FSDPOption.SHARD_GRAD_OP in args.fsdp:
self.fsdp = ShardingStrategy.SHARD_GRAD_OP
elif FSDPOption.NO_SHARD in args.fsdp:
self.fsdp = ShardingStrategy.NO_SHARD
self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
if "backward_prefetch" in self.args.fsdp_config and "backward_post" in self.args.fsdp_config.get(
"backward_prefetch", []
):
self.backward_prefetch = BackwardPrefetch.BACKWARD_POST
self.limit_all_gathers = False
if self.args.fsdp_config.get("limit_all_gathers", False):
self.limit_all_gathers = True
# one place to sort out whether to place the model on device or not
# postpone switching model to cuda when:
# 1. MP - since we are trying to fit a much bigger than 1 gpu model
@ -462,7 +436,7 @@ class Trainer:
self.is_model_parallel
or self.is_deepspeed_enabled
or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
or (self.fsdp is not None)
or self.is_fsdp_xla_enabled
or self.is_fsdp_enabled
):
self.place_model_on_device = False
@ -513,7 +487,7 @@ class Trainer:
" `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
" `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
)
if (self.is_deepspeed_enabled or (self.fsdp is not None)) and (
if (self.is_deepspeed_enabled or self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and (
self.optimizer is not None or self.lr_scheduler is not None
):
raise RuntimeError(
@ -1367,7 +1341,7 @@ class Trainer:
# Distributed training (should be after apex fp16 initialization)
# Distributed training using PyTorch FSDP
if self.fsdp is not None and self.args.fsdp_config["xla"]:
if self.is_fsdp_xla_enabled:
try:
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp import checkpoint_module
@ -1626,7 +1600,7 @@ class Trainer:
else:
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.fsdp is not None or self.is_fsdp_enabled
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
# We need to reset the scheduler, as its parameters may be different on subsequent calls
if self._created_lr_scheduler:
@ -1676,8 +1650,6 @@ class Trainer:
use_accelerator_prepare = True if model is self.model else False
if delay_optimizer_creation:
if use_accelerator_prepare:
self.model = self.accelerator.prepare(self.model)
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
# prepare using `accelerator` prepare
@ -1895,9 +1867,7 @@ class Trainer:
):
# the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered
# in accelerate. So, explicitly enable sync gradients to True in that case.
if is_last_step_and_steps_less_than_grad_acc or (
version.parse(accelerate_version) <= version.parse("0.20.3")
):
if is_last_step_and_steps_less_than_grad_acc:
self.accelerator.gradient_state._set_sync_gradients(True)
# Gradient clipping
@ -2051,7 +2021,7 @@ class Trainer:
safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and any(
WEIGHTS_NAME.split(".")[0] in folder_name
FSDP_MODEL_NAME in folder_name
for folder_name in os.listdir(resume_from_checkpoint)
if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
)
@ -2360,56 +2330,12 @@ class Trainer:
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
self.save_model(output_dir, _internal_call=True)
if self.is_deepspeed_enabled:
# under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
# config `stage3_gather_16bit_weights_on_model_save` is True
self.model_wrapped.save_checkpoint(output_dir)
# Save optimizer and scheduler
if self.fsdp or self.is_fsdp_enabled:
if self.is_fsdp_enabled:
save_fsdp_optimizer(
self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
)
else:
# FSDP has a different interface for saving optimizer states.
# Needs to be called on all ranks to gather all states.
# full_optim_state_dict will be deprecated after Pytorch 2.2!
full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)
torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
if is_torch_tpu_available():
xm.rendezvous("saving_optimizer_states")
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
with warnings.catch_warnings(record=True) as caught_warnings:
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
elif is_sagemaker_mp_enabled():
opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
smp.barrier()
if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
smp.save(
opt_state_dict,
os.path.join(output_dir, OPTIMIZER_NAME),
partial=True,
v3=smp.state.cfg.shard_optimizer_state,
)
elif self.args.should_save and not self.is_deepspeed_enabled and not (self.fsdp or self.is_fsdp_enabled):
# deepspeed.save_checkpoint above saves model/optim/sched
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
# Save SCHEDULER & SCALER
is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance(
self.lr_scheduler, DeepSpeedSchedulerWrapper
)
if (
self.args.should_save
and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler)
and not is_torch_tpu_available()
):
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
if not self.args.save_only_model:
# Save optimizer and scheduler
self._save_optimizer_and_scheduler(output_dir)
# Save RNG state
self._save_rng_state(output_dir)
# Determine the new best metric / best model checkpoint
if metrics is not None and self.args.metric_for_best_model is not None:
@ -2431,6 +2357,14 @@ class Trainer:
if self.args.should_save:
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
if self.args.push_to_hub:
self._push_from_checkpoint(output_dir)
# Maybe delete some older checkpoints.
if self.args.should_save:
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
def _save_rng_state(self, output_dir):
# Save RNG state in non-distributed training
rng_states = {
"python": random.getstate(),
@ -2462,12 +2396,49 @@ class Trainer:
else:
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
if self.args.push_to_hub:
self._push_from_checkpoint(output_dir)
def _save_optimizer_and_scheduler(self, output_dir):
if is_torch_tpu_available():
xm.rendezvous("saving_optimizer_states")
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
with warnings.catch_warnings(record=True) as caught_warnings:
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
elif is_sagemaker_mp_enabled():
opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
smp.barrier()
if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
smp.save(
opt_state_dict,
os.path.join(output_dir, OPTIMIZER_NAME),
partial=True,
v3=smp.state.cfg.shard_optimizer_state,
)
elif self.is_deepspeed_enabled:
# under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
# config `stage3_gather_16bit_weights_on_model_save` is True
self.model_wrapped.save_checkpoint(output_dir)
elif self.is_fsdp_enabled:
# save fsdp specific ckpt for resuming from ckpt
save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
save_fsdp_optimizer(
self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
)
elif self.args.should_save:
# deepspeed.save_checkpoint above saves model/optim/sched
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
# Maybe delete some older checkpoints.
if self.args.should_save:
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
# Save SCHEDULER & SCALER
is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance(
self.lr_scheduler, DeepSpeedSchedulerWrapper
)
if (
self.args.should_save
and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler)
and not is_torch_tpu_available()
):
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
def _load_optimizer_and_scheduler(self, checkpoint):
"""If optimizer and scheduler states exist, load them."""
@ -2535,23 +2506,14 @@ class Trainer:
# In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more
# likely to get OOM on CPU (since we load num_gpu times the optimizer state
map_location = self.args.device if self.args.world_size > 1 else "cpu"
if self.fsdp or self.is_fsdp_enabled:
if self.is_fsdp_enabled:
load_fsdp_optimizer(
self.accelerator.state.fsdp_plugin,
self.accelerator,
self.optimizer,
self.model,
checkpoint,
)
else:
full_osd = None
# In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it
if self.args.process_index == 0:
full_osd = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME))
# call scatter_full_optim_state_dict on all ranks
sharded_osd = self.model.__class__.scatter_full_optim_state_dict(full_osd, self.model)
self.optimizer.load_state_dict(sharded_osd)
if self.is_fsdp_enabled:
load_fsdp_optimizer(
self.accelerator.state.fsdp_plugin,
self.accelerator,
self.optimizer,
self.model,
checkpoint,
)
else:
self.optimizer.load_state_dict(
torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
@ -2826,19 +2788,14 @@ class Trainer:
if IS_SAGEMAKER_MP_POST_1_10:
# 'user_content.pt' indicates model state_dict saved with smp >= 1.10
Path(os.path.join(output_dir, "user_content.pt")).touch()
elif self.fsdp is not None or self.is_fsdp_enabled:
state_dict = self.model.state_dict() if not self.is_fsdp_enabled else {}
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
if self.is_fsdp_enabled:
# remove the dummy state_dict
remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
elif self.is_fsdp_enabled:
if ("FULL_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type)) and (
version.parse(accelerate_version) > version.parse("0.24.1")
):
state_dict = self.accelerator.get_state_dict(self.model)
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
elif self.is_deepspeed_enabled:
# this takes care of everything as long as we aren't under zero3
if version.parse(accelerate_version) <= version.parse("0.20.3"):
raise ValueError("Install Accelerate from main branch")
try:
state_dict = self.accelerator.get_state_dict(self.deepspeed)
if self.args.should_save:
@ -3247,11 +3204,7 @@ class Trainer:
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
if (
args.eval_accumulation_steps is not None
and (step + 1) % args.eval_accumulation_steps == 0
and (self.accelerator.sync_gradients or version.parse(accelerate_version) > version.parse("0.20.3"))
):
if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
if losses_host is not None:
losses = nested_numpify(losses_host)
all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
@ -3877,8 +3830,7 @@ class Trainer:
def create_accelerator_and_postprocess(self):
grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps}
if version.parse(accelerate_version) > version.parse("0.20.3"):
grad_acc_kwargs["sync_with_dataloader"] = False
grad_acc_kwargs["sync_with_dataloader"] = False
gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
# create accelerator object


@ -727,6 +727,8 @@ class FSDPOption(ExplicitEnum):
FULL_SHARD = "full_shard"
SHARD_GRAD_OP = "shard_grad_op"
NO_SHARD = "no_shard"
HYBRID_SHARD = "hybrid_shard"
HYBRID_SHARD_ZERO2 = "hybrid_shard_zero2"
OFFLOAD = "offload"
AUTO_WRAP = "auto_wrap"


@ -304,6 +304,11 @@ class TrainingArguments:
This should not be activated when the different nodes use the same storage as the files will be saved with
the same names for each node.
save_only_model (`bool`, *optional*, defaults to `False`):
When checkpointing, whether to only save the model, or also the optimizer, scheduler & rng state.
Note that when this is true, you won't be able to resume training from the checkpoint.
This enables you to save storage by not storing the optimizer, scheduler & rng state.
You can only load the model using `from_pretrained` with this option set to `True`.
use_cpu (`bool`, *optional*, defaults to `False`):
Whether or not to use cpu. If set to False, we will use cuda or mps device if available.
seed (`int`, *optional*, defaults to 42):
@ -418,12 +423,14 @@ class TrainingArguments:
- `"full_shard"`: Shard parameters, gradients and optimizer states.
- `"shard_grad_op"`: Shard optimizer states and gradients.
- `"hybrid_shard"`: Apply `FULL_SHARD` within a node, and replicate parameters across nodes.
- `"hybrid_shard_zero2"`: Apply `SHARD_GRAD_OP` within a node, and replicate parameters across nodes.
- `"offload"`: Offload parameters and gradients to CPUs (only compatible with `"full_shard"` and
`"shard_grad_op"`).
- `"auto_wrap"`: Automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`.
fsdp_config (`str` or `dict`, *optional*):
Config to be used with fsdp (PyTorch Fully Sharded Data Parallel training). The value is either a location of
deepspeed json config file (e.g., `ds_config.json`) or an already loaded json file as `dict`.
fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`.
A List of config and its options:
- min_num_params (`int`, *optional*, defaults to `0`):
@ -452,7 +459,7 @@ class TrainingArguments:
FSDP's limit_all_gathers (useful only when `fsdp` field is passed).
If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight
all-gathers.
- use_orig_params (`bool`, *optional*, defaults to `False`)
- use_orig_params (`bool`, *optional*, defaults to `True`)
If `"True"`, allows non-uniform `requires_grad` during init, which means support for interspersed
frozen and trainable parameters. Useful in cases such as parameter-efficient fine-tuning. Please
refer to this
@ -460,6 +467,10 @@ class TrainingArguments:
- sync_module_states (`bool`, *optional*, defaults to `True`)
If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to
ensure they are the same across all ranks after initialization
- activation_checkpointing (`bool`, *optional*, defaults to `False`):
If `"True"`, activation checkpointing is a technique to reduce memory usage by clearing activations of
certain layers and recomputing them during a backward pass. Effectively, this trades extra
computation time for reduced memory usage.
- xla (`bool`, *optional*, defaults to `False`):
Whether to use PyTorch/XLA Fully Sharded Data Parallel Training. This is an experimental feature
and its API may evolve in the future.
@ -472,10 +483,6 @@ class TrainingArguments:
Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be
used when the xla flag is set to true, and an auto wrapping policy is specified through
fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.
- activation_checkpointing (`bool`, *optional*, defaults to `False`):
If True, activation checkpointing is a technique to reduce memory usage by clearing activations of
certain layers and recomputing them during a backward pass. Effectively, this trades extra
computation time for reduced memory usage.
deepspeed (`str` or `dict`, *optional*):
Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
@ -835,6 +842,17 @@ class TrainingArguments:
)
},
)
save_only_model: bool = field(
default=False,
metadata={
"help": (
"When checkpointing, whether to only save the model, or also the optimizer, scheduler & rng state."
"Note that when this is true, you won't be able to resume training from checkpoint."
"This enables you to save storage by not storing the optimizer, scheduler & rng state."
"You can only load the model using from_pretrained with this option set to True."
)
},
)
no_cuda: bool = field(
default=False,
metadata={"help": "This argument is deprecated. It will be removed in version 5.0 of 🤗 Transformers."},
@ -1670,7 +1688,7 @@ class TrainingArguments:
os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper()
os.environ[f"{prefix}FORWARD_PREFETCH"] = self.fsdp_config.get("forward_prefect", "false")
os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "true")
os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "false")
os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true")
if self.tpu_metrics_debug:
warnings.warn(


@ -652,7 +652,7 @@ def is_protobuf_available():
return importlib.util.find_spec("google.protobuf") is not None
def is_accelerate_available(min_version: str = None):
def is_accelerate_available(min_version: str = "0.21.0"):
if min_version is not None:
return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)
return _accelerate_available