Rework TPU checkpointing in Trainer (#10504)

* Rework TPU checkpointing in Trainer

* Wraps the barrier in a dist test

* Address review comments

* Remove line
Sylvain Gugger 2021-03-04 11:46:11 -05:00 committed by GitHub
parent 805c5200dc
commit 6290169eb3
4 changed files with 74 additions and 58 deletions
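For orientation, here is a minimal sketch (not part of the commit) of the saving pattern the reworked API enables on TPU. It assumes torch_xla is installed and the code runs inside an XLA worker process; "output_dir" is a placeholder path.

import torch_xla.core.xla_model as xm
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased").to(xm.xla_device())

# The TPU-specific logic now lives with the caller: save_pretrained just receives
# whether to write config.json, which state dict to use, and which save function.
model.save_pretrained(
    "output_dir",                        # placeholder directory
    save_config=xm.is_master_ordinal(),  # only the main process writes config.json
    save_function=xm.save,               # xm.save gathers and saves from the master ordinal
)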

src/transformers/configuration_utils.py

@@ -75,8 +75,6 @@ class PretrainedConfig(object):
heads to prune in said layer.
For instance ``{1: [0, 2], 2: [2, 3]}`` will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
xla_device (:obj:`bool`, `optional`):
A flag to indicate if TPUs are available or not.
chunk_size_feed_forward (:obj:`int`, `optional`, defaults to :obj:`0`):
The chunk size of all feed forward layers in the residual attention blocks. A chunk size of :obj:`0` means
that the feed forward layer is not chunked. A chunk size of n means that the feed forward layer processes
@@ -248,7 +246,11 @@ class PretrainedConfig(object):
self.task_specific_params = kwargs.pop("task_specific_params", None)
# TPU arguments
self.xla_device = kwargs.pop("xla_device", None)
if kwargs.pop("xla_device", None) is not None:
logger.warn(
"The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can "
"safely remove it from your `config.json` file."
)
# Name or path to the pretrained checkpoint
self._name_or_path = str(kwargs.pop("name_or_path", ""))
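A quick sketch of the resulting behavior (not part of the commit): an `xla_device` entry coming from an old config.json or passed as a keyword argument is popped, logs the deprecation warning, and is no longer stored on the config object.

from transformers import BertConfig

config = BertConfig(xla_device=True)   # logs the deprecation warning shown above
print(hasattr(config, "xla_device"))   # False: the attribute is no longer set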

src/transformers/modeling_utils.py

@@ -37,7 +37,6 @@ from .file_utils import (
cached_path,
hf_bucket_url,
is_remote_url,
is_torch_tpu_available,
replace_return_docstrings,
)
from .generation_utils import GenerationMixin
@@ -781,7 +780,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
self.base_model._prune_heads(heads_to_prune)
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
save_config: bool = True,
state_dict: Optional[dict] = None,
save_function: Callable = torch.save,
):
"""
Save a model and its configuration file to a directory, so that it can be re-loaded using the
:func:`~transformers.PreTrainedModel.from_pretrained` class method.
@@ -789,19 +794,36 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
Arguments:
save_directory (:obj:`str` or :obj:`os.PathLike`):
Directory to which to save. Will be created if it doesn't exist.
save_config (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to save the config of the model. Useful in distributed training (e.g. on TPUs), where
this function needs to be called on all processes. In that case, set :obj:`save_config=True` only on
the main process to avoid race conditions.
state_dict (nested dictionary of :obj:`torch.Tensor`):
The state dictionary of the model to save. Will default to :obj:`self.state_dict()`, but can be used to
only save parts of the model or if special precautions need to be taken when recovering the state
dictionary of a model (like when using model parallelism).
save_function (:obj:`Callable`):
The function to use to save the state dictionary. Useful in distributed training (e.g. on TPUs), where
:obj:`torch.save` needs to be replaced with another method.
"""
if os.path.isfile(save_directory):
logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
os.makedirs(save_directory, exist_ok=True)
# Only save the model itself if we are using distributed training
model_to_save = self.module if hasattr(self, "module") else self
model_to_save = unwrap_model(self)
# Attach architecture to the config
model_to_save.config.architectures = [model_to_save.__class__.__name__]
state_dict = model_to_save.state_dict()
# Save the config
if save_config:
model_to_save.config.save_pretrained(save_directory)
# Save the model
if state_dict is None:
state_dict = model_to_save.state_dict()
# Handle the case where some state_dict keys shouldn't be saved
if self._keys_to_ignore_on_save is not None:
@@ -809,18 +831,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
# If we save using the predefined names, we can load using `from_pretrained`
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
if getattr(self.config, "xla_device", False) and is_torch_tpu_available():
import torch_xla.core.xla_model as xm
if xm.is_master_ordinal():
# Save configuration file
model_to_save.config.save_pretrained(save_directory)
# xm.save takes care of saving only from master
xm.save(state_dict, output_model_file)
else:
model_to_save.config.save_pretrained(save_directory)
torch.save(state_dict, output_model_file)
save_function(state_dict, output_model_file)
logger.info("Model weights saved in {}".format(output_model_file))
@@ -1181,12 +1192,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
}
return model, loading_info
if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available():
import torch_xla.core.xla_model as xm
model = xm.send_cpu_data_to_device(model, xm.xla_device())
model.to(xm.xla_device())
return model
@@ -1634,6 +1639,20 @@ class SequenceSummary(nn.Module):
return output
def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
"""
Recursively unwraps a model from potential containers (as used in distributed training).
Args:
model (:obj:`torch.nn.Module`): The model to unwrap.
"""
# since there could be multiple levels of wrapping, unwrap recursively
if hasattr(model, "module"):
return unwrap_model(model.module)
else:
return model
def prune_linear_layer(layer: torch.nn.Linear, index: torch.LongTensor, dim: int = 0) -> torch.nn.Linear:
"""
Prune a linear layer to keep only entries in index.
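To illustrate the promoted helper and the new keyword arguments, a small sketch (not part of the commit) using a DataParallel wrapper as a stand-in for any container exposing a .module attribute; "output_dir" is a placeholder.

import torch.nn as nn
from transformers import BertModel
from transformers.modeling_utils import unwrap_model

model = BertModel.from_pretrained("bert-base-uncased")
wrapped = nn.DataParallel(model)  # any wrapper exposing `.module` behaves the same way

# unwrap_model strips arbitrarily nested `.module` wrappers
assert unwrap_model(wrapped) is model

# The new keyword arguments default to the old behavior: write config.json,
# use the model's own state dict, and save with torch.save.
unwrap_model(wrapped).save_pretrained("output_dir")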

src/transformers/trainer.py

@@ -60,7 +60,7 @@ from .file_utils import (
is_sagemaker_distributed_available,
is_torch_tpu_available,
)
from .modeling_utils import PreTrainedModel
from .modeling_utils import PreTrainedModel, unwrap_model
from .optimization import Adafactor, AdamW, get_scheduler
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
@@ -154,14 +154,6 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
def _model_unwrap(model: nn.Module) -> nn.Module:
# since there could be multiple levels of wrapping, unwrap recursively
if hasattr(model, "module"):
return _model_unwrap(model.module)
else:
return model
class Trainer:
"""
Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
@@ -359,10 +351,6 @@ class Trainer:
# Create output directory if needed
if self.is_world_process_zero():
os.makedirs(self.args.output_dir, exist_ok=True)
if is_torch_tpu_available() and isinstance(self.model, PreTrainedModel):
# Set an xla_device flag on the model's config.
# We'll find a more elegant way and not need to do this in the future.
self.model.config.xla_device = True
if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")
@@ -1194,7 +1182,7 @@ class Trainer:
def _save_checkpoint(self, model, trial, metrics=None):
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
# want to save except FullyShardedDDP.
# assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model"
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
# Save model checkpoint
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
@@ -1499,13 +1487,15 @@ class Trainer:
"""
Will save the model, so you can reload it using :obj:`from_pretrained()`.
Will only save from the world_master process (unless in TPUs).
Will only save from the main process.
"""
if is_torch_tpu_available():
self._save_tpu(output_dir)
elif self.is_world_process_zero():
self._save(output_dir)
else:
if self.is_world_process_zero():
self._save(output_dir)
if self.args.local_rank != -1:
dist.barrier()
def _save_tpu(self, output_dir: Optional[str] = None):
output_dir = output_dir if output_dir is not None else self.args.output_dir
@@ -1519,34 +1509,39 @@ class Trainer:
# They can then be reloaded using `from_pretrained()`
xm.rendezvous("saving_checkpoint")
if not isinstance(self.model, PreTrainedModel):
if isinstance(_model_unwrap(self.model), PreTrainedModel):
if xm.is_master_ordinal():
_model_unwrap(self.model).config.save_pretrained(output_dir)
if isinstance(unwrap_model(self.model), PreTrainedModel):
unwrap_model(self.model).save_pretrained(
output_dir,
save_config=self.is_world_process_zero(),
state_dict=self.model.state_dict(),
save_function=xm.save,
)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = self.model.state_dict()
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
state_dict = self.model.state_dict()
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(output_dir)
self.model.save_pretrained(output_dir, save_config=self.is_world_process_zero(), save_function=xm.save)
if self.tokenizer is not None and self.is_world_process_zero():
self.tokenizer.save_pretrained(output_dir)
def _save(self, output_dir: Optional[str] = None):
# If we are executing this function, we are process zero, so we don't check for that.
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info("Saving model checkpoint to %s", output_dir)
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel):
if isinstance(_model_unwrap(self.model), PreTrainedModel):
_model_unwrap(self.model).config.save_pretrained(output_dir)
if isinstance(unwrap_model(self.model), PreTrainedModel):
unwrap_model(self.model).save_pretrained(output_dir, state_dict=self.model.state_dict())
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = self.model.state_dict()
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
state_dict = self.model.state_dict()
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(output_dir)
if self.tokenizer is not None and self.is_world_process_zero():
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
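The practical consequence for callers, sketched below (not part of the commit): because save_model now synchronizes internally (xm.rendezvous on TPU, dist.barrier otherwise), it should be called on every process rather than only on rank zero, and the method itself decides what each rank writes.

from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
args = TrainingArguments(output_dir="output_dir")  # placeholder directory
trainer = Trainer(model=model, args=args)

# Call save_model on *all* processes: the rendezvous/barrier inside keeps ranks in
# sync, only the main process writes config.json and the tokenizer, and on TPU every
# rank participates in the xm.save weight-saving path.
trainer.save_model()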

tests/test_trainer.py

@@ -57,7 +57,7 @@ if is_torch_available():
Trainer,
TrainerState,
)
from transformers.trainer import _model_unwrap
from transformers.modeling_utils import unwrap_model
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
@@ -882,8 +882,8 @@ class TrainerIntegrationTest(unittest.TestCase):
trainer = get_regression_trainer(learning_rate=0.1)
def assert_flos_extraction(trainer, wrapped_model_to_check):
self.assertEqual(trainer.model, _model_unwrap(wrapped_model_to_check))
self.assertGreaterEqual(getattr(_model_unwrap(wrapped_model_to_check).config, "total_flos", 0), 0)
self.assertEqual(trainer.model, unwrap_model(wrapped_model_to_check))
self.assertGreaterEqual(getattr(unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0)
# with plain model
assert_flos_extraction(trainer, trainer.model)