@@ -34,6 +34,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 from zipfile import is_zipfile
 
 import torch
+from huggingface_hub import split_torch_state_dict_into_shards
 from packaging import version
 from torch import Tensor, nn
 from torch.nn import CrossEntropyLoss, Identity
@@ -358,6 +359,10 @@ def shard_checkpoint(
         weights_name (`str`, *optional*, defaults to `"pytorch_model.bin"`):
             The name of the model save file.
     """
+    logger.warning(
+        "Note that `shard_checkpoint` is deprecated and will be removed in v4.44. We recommend you use "
+        "`split_torch_state_dict_into_shards` from the `huggingface_hub` library"
+    )
     max_shard_size = convert_file_size_to_int(max_shard_size)
 
     sharded_state_dicts = [{}]
@@ -2585,7 +2590,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         else:
             weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME
 
-        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
+        filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
+        state_dict_split = split_torch_state_dict_into_shards(
+            state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
+        )
+        # Save index if sharded
+        index = None
+        if state_dict_split.is_sharded:
+            index = {
+                "metadata": state_dict_split.metadata,
+                "weight_map": state_dict_split.tensor_to_filename,
+            }
 
         # Clean the folder from a previous save
         for filename in os.listdir(save_directory):
@@ -2601,14 +2616,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             if (
                 filename.startswith(weights_no_suffix)
                 and os.path.isfile(full_filename)
-                and filename not in shards.keys()
+                and filename not in state_dict_split.filename_to_tensors.keys()
                 and is_main_process
                 and reg.fullmatch(filename_no_suffix) is not None
             ):
                 os.remove(full_filename)
 
         # Save the model
-        for shard_file, shard in shards.items():
+        for shard_file, tensors in state_dict_split.filename_to_tensors.items():
+            shard = {tensor: state_dict[tensor] for tensor in tensors}
             if safe_serialization:
                 # At some point we will need to deal better with save_function (used for TPU and other distributed
                 # joyfulness), but for now this enough.
@@ -2628,7 +2644,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                     f.write(content)
             logger.info(
                 f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
-                f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
+                f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameter has been saved in the "
                 f"index located at {save_index_file}."
             )
 
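
For reference, the `huggingface_hub` helper adopted in this diff is also usable on its own. Below is a minimal sketch of the same sharding flow outside `save_pretrained`; the `save_sharded` wrapper, its arguments, and the use of `safetensors.torch.save_file` with the `model.safetensors.index.json` name are illustrative assumptions rather than part of this change, while the `StateDictSplit` attributes (`is_sharded`, `metadata`, `filename_to_tensors`, `tensor_to_filename`) are exactly the ones exercised above.

import json
import os

import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file


def save_sharded(model: torch.nn.Module, save_directory: str, max_shard_size: str = "5GB") -> None:
    """Shard a model's state dict with huggingface_hub and save it with safetensors (illustrative helper)."""
    os.makedirs(save_directory, exist_ok=True)
    state_dict = model.state_dict()
    state_dict_split = split_torch_state_dict_into_shards(
        state_dict, filename_pattern="model{suffix}.safetensors", max_shard_size=max_shard_size
    )
    # Write each shard; `filename_to_tensors` maps a shard file name to the tensor names it owns.
    # (Shared/tied tensors are not handled here; transformers deals with those separately.)
    for shard_file, tensor_names in state_dict_split.filename_to_tensors.items():
        shard = {name: state_dict[name].contiguous() for name in tensor_names}
        save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
    # Write the index only when the state dict was actually split across several files.
    if state_dict_split.is_sharded:
        index = {"metadata": state_dict_split.metadata, "weight_map": state_dict_split.tensor_to_filename}
        with open(os.path.join(save_directory, "model.safetensors.index.json"), "w", encoding="utf-8") as f:
            f.write(json.dumps(index, indent=2, sort_keys=True) + "\n")

As in the diff, each shard dictionary only holds references to tensors already present in `state_dict`, and the index file is written only when the split produced more than one shard file.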