Safetensors serialization by default (#27064)

* Safetensors serialization by default

* First pass on the tests

* Second pass on the tests

* Third pass on the tests

* Fix TF weight loading from TF-format safetensors

* Specific encoder-decoder fixes for weight crossloading

* Add VisionEncoderDecoder fixes for TF too

* Change filename test for pt-to-tf

* One missing fix for TFVisionEncoderDecoder

* Fix the other crossload test

* Support for flax + updated tests

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Sanchit's comments

* Sanchit's comments 2

* Nico's comments

* Fix tests

* cleanup

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: Matt <rocketknight1@gmail.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Commit 113ebf80ac (parent 25e6e9418c)
Lysandre Debut, 2023-10-31 19:16:49 +01:00, committed by GitHub
20 changed files with 433 additions and 137 deletions
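In practice, the headline change is that `save_pretrained` now defaults to `safe_serialization=True` for PyTorch models. A minimal sketch of the user-facing behaviour (the tiny checkpoint name and output directories are only examples):

from transformers import AutoModel

model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")

# This now writes model.safetensors instead of pytorch_model.bin.
model.save_pretrained("./my-model")

# The legacy pickle-based .bin format is still available on request.
model.save_pretrained("./my-model-legacy", safe_serialization=False)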

@@ -27,9 +27,15 @@ from flax.traverse_util import flatten_dict, unflatten_dict

 import transformers

+from . import is_safetensors_available
 from .utils import logging


+if is_safetensors_available():
+    from safetensors import safe_open
+    from safetensors.flax import load_file as safe_load_file
+
+
 logger = logging.get_logger(__name__)

@@ -56,7 +62,13 @@ def load_pytorch_checkpoint_in_flax_state_dict(
     pt_path = os.path.abspath(pytorch_checkpoint_path)
     logger.info(f"Loading PyTorch weights from {pt_path}")

-    pt_state_dict = torch.load(pt_path, map_location="cpu")
+    if pt_path.endswith(".safetensors"):
+        pt_state_dict = {}
+        with safe_open(pt_path, framework="pt") as f:
+            for k in f.keys():
+                pt_state_dict[k] = f.get_tensor(k)
+    else:
+        pt_state_dict = torch.load(pt_path, map_location="cpu")
     logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")

     flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)

@@ -319,11 +331,15 @@ def load_flax_checkpoint_in_pytorch_model(model, flax_checkpoint_path):
     flax_cls = getattr(transformers, "Flax" + model.__class__.__name__)

     # load flax weight dict
-    with open(flax_checkpoint_path, "rb") as state_f:
-        try:
-            flax_state_dict = from_bytes(flax_cls, state_f.read())
-        except UnpicklingError:
-            raise EnvironmentError(f"Unable to convert {flax_checkpoint_path} to Flax deserializable object. ")
+    if flax_checkpoint_path.endswith(".safetensors"):
+        flax_state_dict = safe_load_file(flax_checkpoint_path)
+        flax_state_dict = unflatten_dict(flax_state_dict, sep=".")
+    else:
+        with open(flax_checkpoint_path, "rb") as state_f:
+            try:
+                flax_state_dict = from_bytes(flax_cls, state_f.read())
+            except UnpicklingError:
+                raise EnvironmentError(f"Unable to convert {flax_checkpoint_path} to Flax deserializable object. ")

     return load_flax_weights_in_pytorch_model(model, flax_state_dict)
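The loader above now also accepts PyTorch-format safetensors files, reading them tensor by tensor. A hedged sketch of that access pattern, with a placeholder path:

from safetensors import safe_open

pt_state_dict = {}
with safe_open("model.safetensors", framework="pt") as f:  # "model.safetensors" is illustrative
    for key in f.keys():
        pt_state_dict[key] = f.get_tensor(key)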

@@ -39,6 +39,8 @@ from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_d
 from .utils import (
     FLAX_WEIGHTS_INDEX_NAME,
     FLAX_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     PushToHubMixin,

@@ -54,8 +56,14 @@ from .utils import (
     replace_return_docstrings,
 )
 from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
+from .utils.import_utils import is_safetensors_available
+
+
+if is_safetensors_available():
+    from safetensors import safe_open
+    from safetensors.flax import load_file as safe_load_file
+    from safetensors.flax import save_file as safe_save_file


 logger = logging.get_logger(__name__)

@@ -422,6 +430,31 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         ```"""
         return self._cast_floating_to(params, jnp.float16, mask)

+    @classmethod
+    def load_flax_weights(cls, resolved_archive_file):
+        try:
+            if resolved_archive_file.endswith(".safetensors"):
+                state = safe_load_file(resolved_archive_file)
+                state = unflatten_dict(state, sep=".")
+            else:
+                with open(resolved_archive_file, "rb") as state_f:
+                    state = from_bytes(cls, state_f.read())
+        except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
+            try:
+                with open(resolved_archive_file) as f:
+                    if f.read().startswith("version"):
+                        raise OSError(
+                            "You seem to have cloned a repository without having git-lfs installed. Please"
+                            " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
+                            " folder you cloned."
+                        )
+                    else:
+                        raise ValueError from e
+            except (UnicodeDecodeError, ValueError):
+                raise EnvironmentError(f"Unable to convert {resolved_archive_file} to Flax deserializable object. ")
+
+        return state
+
     @classmethod
     def load_flax_sharded_weights(cls, shard_files):
         """

@@ -688,7 +721,12 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
         is_local = os.path.isdir(pretrained_model_name_or_path)
         if os.path.isdir(pretrained_model_name_or_path):
-            if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
+            if is_safetensors_available() and os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
+            ):
+                # Load from a safetensors checkpoint
+                archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
+            elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
                 # Load from a PyTorch checkpoint
                 archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
             elif from_pt and os.path.isfile(

@@ -705,6 +743,13 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                 archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)
                 is_sharded = True
             # At this stage we don't have a weight file so we will raise an error.
+            elif is_safetensors_available() and os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
+            ):
+                # Load from a sharded safetensors checkpoint
+                archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
+                is_sharded = True
+                raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!")
             elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
                 raise EnvironmentError(
                     f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "

@@ -723,7 +768,13 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
             filename = pretrained_model_name_or_path
             resolved_archive_file = download_url(pretrained_model_name_or_path)
         else:
-            filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
+            if from_pt:
+                filename = WEIGHTS_NAME
+            elif is_safetensors_available():
+                filename = SAFE_WEIGHTS_NAME
+            else:
+                filename = FLAX_WEIGHTS_NAME
+
             try:
                 # Load from URL or cache if already cached
                 cached_file_kwargs = {

@@ -741,8 +792,15 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                 }
                 resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

-                # Since we set _raise_exceptions_for_missing_entries=False, we don't get an expection but a None
+                # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
                 # result when internet is up, the repo and revision exist, but the file does not.
+                if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME:
+                    # Did not find the safetensors file, let's fallback to Flax.
+                    # No support for sharded safetensors yet, so we'll raise an error if that's all we find.
+                    filename = FLAX_WEIGHTS_NAME
+                    resolved_archive_file = cached_file(
+                        pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **cached_file_kwargs
+                    )
                 if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME:
                     # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                     resolved_archive_file = cached_file(

@@ -751,21 +809,26 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                     if resolved_archive_file is not None:
                         is_sharded = True
                 # Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case.
-                elif resolved_archive_file is None and from_pt:
+                if resolved_archive_file is None and from_pt:
                     resolved_archive_file = cached_file(
                         pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
                     )
                     if resolved_archive_file is not None:
                         is_sharded = True
                 if resolved_archive_file is None:
-                    # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
+                    # Otherwise, maybe there is a TF or Torch model file. We try those to give a helpful error
                     # message.
                     has_file_kwargs = {
                         "revision": revision,
                         "proxies": proxies,
                         "token": token,
                     }
-                    if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
+                    if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs):
+                        is_sharded = True
+                        raise NotImplementedError(
+                            "Support for sharded checkpoints using safetensors is coming soon!"
+                        )
+                    elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
                         raise EnvironmentError(
                             f"{pretrained_model_name_or_path} does not appear to have a file named"
                             f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"

@@ -798,6 +861,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         if is_local:
             logger.info(f"loading weights file {archive_file}")
             resolved_archive_file = archive_file
+            filename = resolved_archive_file.split(os.path.sep)[-1]
         else:
             logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
     else:

@@ -821,31 +885,27 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                 _commit_hash=commit_hash,
             )

+        safetensors_from_pt = False
+        if filename == SAFE_WEIGHTS_NAME:
+            with safe_open(resolved_archive_file, framework="flax") as f:
+                safetensors_metadata = f.metadata()
+            if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]:
+                raise OSError(
+                    f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata."
+                    " Make sure you save your model with the `save_pretrained` method."
+                )
+            safetensors_from_pt = safetensors_metadata.get("format") == "pt"
+
         # init random models
         model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)

-        if from_pt:
+        if from_pt or safetensors_from_pt:
             state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded)
         else:
             if is_sharded:
                 state = cls.load_flax_sharded_weights(resolved_archive_file)
             else:
-                try:
-                    with open(resolved_archive_file, "rb") as state_f:
-                        state = from_bytes(cls, state_f.read())
-                except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
-                    try:
-                        with open(resolved_archive_file) as f:
-                            if f.read().startswith("version"):
-                                raise OSError(
-                                    "You seem to have cloned a repository without having git-lfs installed. Please"
-                                    " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
-                                    " folder you cloned."
-                                )
-                            else:
-                                raise ValueError from e
-                    except (UnicodeDecodeError, ValueError):
-                        raise EnvironmentError(f"Unable to convert {archive_file} to Flax deserializable object. ")
+                state = cls.load_flax_weights(resolved_archive_file)

             # make sure all arrays are stored as jnp.arrays
             # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
             # https://github.com/google/flax/issues/1261

@@ -1030,6 +1090,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         push_to_hub=False,
         max_shard_size="10GB",
         token: Optional[Union[str, bool]] = None,
+        safe_serialization: bool = False,
         **kwargs,
     ):
         """

@@ -1059,6 +1120,8 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                 the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
             kwargs (`Dict[str, Any]`, *optional*):
                 Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+            safe_serialization (`bool`, *optional*, defaults to `False`):
+                Whether to save the model using `safetensors` or through msgpack.
         """
         use_auth_token = kwargs.pop("use_auth_token", None)

@@ -1103,24 +1166,31 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
             self.generation_config.save_pretrained(save_directory)

         # save model
-        output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
+        weights_name = SAFE_WEIGHTS_NAME if safe_serialization else FLAX_WEIGHTS_NAME
+        output_model_file = os.path.join(save_directory, weights_name)

         shards, index = flax_shard_checkpoint(params if params is not None else self.params, max_shard_size)
         # Clean the folder from a previous save
         for filename in os.listdir(save_directory):
             full_filename = os.path.join(save_directory, filename)
+            weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
             if (
-                filename.startswith(FLAX_WEIGHTS_NAME[:-4])
+                filename.startswith(weights_no_suffix)
                 and os.path.isfile(full_filename)
                 and filename not in shards.keys()
             ):
                 os.remove(full_filename)

         if index is None:
-            with open(output_model_file, "wb") as f:
+            if safe_serialization:
                 params = params if params is not None else self.params
-                model_bytes = to_bytes(params)
-                f.write(model_bytes)
+                flat_dict = flatten_dict(params, sep=".")
+                safe_save_file(flat_dict, output_model_file, metadata={"format": "flax"})
+            else:
+                with open(output_model_file, "wb") as f:
+                    params = params if params is not None else self.params
+                    model_bytes = to_bytes(params)
+                    f.write(model_bytes)
         else:
             save_index_file = os.path.join(save_directory, FLAX_WEIGHTS_INDEX_NAME)
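Flax parameters are nested dictionaries, while safetensors stores a flat name-to-tensor map, which is why the save and load paths above flatten and unflatten with a "." separator. A small sketch under that assumption (the file name and parameters are illustrative):

import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict
from safetensors.flax import load_file, save_file

params = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros((2,))}}

flat = flatten_dict(params, sep=".")  # {"dense.kernel": ..., "dense.bias": ...}
save_file(flat, "flax_model.safetensors", metadata={"format": "flax"})

restored = unflatten_dict(load_file("flax_model.safetensors"), sep=".")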

@@ -626,11 +626,13 @@ def dtype_byte_size(dtype):
     return bit_size // 8


-def format_weight_name(name, _prefix=None):
+def strip_model_name_and_prefix(name, _prefix=None):
+    if _prefix is not None and name.startswith(_prefix):
+        name = name[len(_prefix) :]
+    if name.startswith("/"):
+        name = name[1:]
     if "model." not in name and len(name.split("/")) > 1:
         name = "/".join(name.split("/")[1:])
-    if _prefix is not None:
-        name = _prefix + "/" + name
     return name

@@ -986,7 +988,7 @@ def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismat
     # Read the safetensors file
     with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
         mismatched_layers = []
-        weight_names = [format_weight_name(w.name, _prefix=_prefix) for w in model.weights]
+        weight_names = [strip_model_name_and_prefix(w.name, _prefix=_prefix) for w in model.weights]
         loaded_weight_names = list(safetensors_archive.keys())
         # Find the missing layers from the high level list of layers
         missing_layers = list(set(weight_names) - set(loaded_weight_names))

@@ -994,7 +996,7 @@ def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismat
         unexpected_layers = list(set(loaded_weight_names) - set(weight_names))

         for weight in model.weights:
-            weight_name = format_weight_name(weight.name, _prefix=_prefix)
+            weight_name = strip_model_name_and_prefix(weight.name, _prefix=_prefix)
             if weight_name in loaded_weight_names:
                 weight_value = safetensors_archive.get_tensor(weight_name)
                 # Check if the shape of the current weight and the one from the H5 file are different

@@ -1003,7 +1005,7 @@ def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismat
                 # If the two shapes are not compatible we raise an issue
                 try:
                     weight_value = tf.reshape(weight_value, K.int_shape(weight))
-                except ValueError as e:
+                except (ValueError, tf.errors.InvalidArgumentError) as e:
                     if ignore_mismatched_sizes:
                         mismatched_layers.append((weight_name, weight_value.shape, K.int_shape(weight)))
                         continue

@@ -2367,7 +2369,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             create_pr (`bool`, *optional*, defaults to `False`):
                 Whether or not to create a PR with the uploaded files or directly commit.
             safe_serialization (`bool`, *optional*, defaults to `False`):
-                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+                Whether to save the model using `safetensors` or the traditional TensorFlow way (that uses `h5`).
             token (`str` or `bool`, *optional*):
                 The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
                 the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).

@@ -2457,7 +2459,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         if index is None:
             if safe_serialization:
-                state_dict = {format_weight_name(w.name): w.value() for w in self.weights}
+                state_dict = {strip_model_name_and_prefix(w.name): w.value() for w in self.weights}
                 safe_save_file(state_dict, output_model_file, metadata={"format": "tf"})
             else:
                 self.save_weights(output_model_file)

@@ -2718,13 +2720,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             ):
                 # Load from a safetensors checkpoint
                 archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
-            elif is_safetensors_available() and os.path.isfile(
-                os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
-            ):
-                # Load from a sharded safetensors checkpoint
-                archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
-                is_sharded = True
-                raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!")
             elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                 # Load from a TF 2.0 checkpoint
                 archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)

@@ -2732,6 +2727,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 # Load from a sharded TF 2.0 checkpoint
                 archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
                 is_sharded = True
+            elif is_safetensors_available() and os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
+            ):
+                # Load from a sharded safetensors checkpoint
+                archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
+                is_sharded = True
+                raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!")
             # At this stage we don't have a weight file so we will raise an error.
             elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)) or os.path.isfile(
                 os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)

@@ -2784,21 +2786,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
                 # result when internet is up, the repo and revision exist, but the file does not.
                 if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME:
-                    # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+                    # Did not find the safetensors file, let's fallback to TF.
+                    # No support for sharded safetensors yet, so we'll raise an error if that's all we find.
+                    filename = TF2_WEIGHTS_NAME
                     resolved_archive_file = cached_file(
-                        pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **cached_file_kwargs
+                        pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **cached_file_kwargs
                     )
-                    if resolved_archive_file is not None:
-                        is_sharded = True
-                        raise NotImplementedError(
-                            "Support for sharded checkpoints using safetensors is coming soon!"
-                        )
-                    else:
-                        # This repo has no safetensors file of any kind, we switch to TensorFlow.
-                        filename = TF2_WEIGHTS_NAME
-                        resolved_archive_file = cached_file(
-                            pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **cached_file_kwargs
-                        )
                 if resolved_archive_file is None and filename == TF2_WEIGHTS_NAME:
                     # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                     resolved_archive_file = cached_file(

@@ -2821,7 +2814,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                         "proxies": proxies,
                         "token": token,
                     }
-                    if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
+                    if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs):
+                        is_sharded = True
+                        raise NotImplementedError(
+                            "Support for sharded checkpoints using safetensors is coming soon!"
+                        )
+                    elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
                         raise EnvironmentError(
                             f"{pretrained_model_name_or_path} does not appear to have a file named"
                             f" {TF2_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"

@@ -2928,6 +2926,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             output_loading_info=output_loading_info,
             _prefix=load_weight_prefix,
             ignore_mismatched_sizes=ignore_mismatched_sizes,
+            tf_to_pt_weight_rename=tf_to_pt_weight_rename,
         )

         # 'by_name' allow us to do transfer learning by skipping/adding layers
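To see what the renamed helper does, here is a standalone copy of its logic applied to a hypothetical TensorFlow variable name (the example name is made up for illustration):

def strip_model_name_and_prefix(name, _prefix=None):
    # Drop an optional load prefix and the leading model-name scope so the
    # in-memory weight names line up with the keys stored in safetensors.
    if _prefix is not None and name.startswith(_prefix):
        name = name[len(_prefix) :]
    if name.startswith("/"):
        name = name[1:]
    if "model." not in name and len(name.split("/")) > 1:
        name = "/".join(name.split("/")[1:])
    return name

print(strip_model_name_and_prefix("tf_bert_model/bert/embeddings/weight:0"))
# -> bert/embeddings/weight:0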

@@ -470,10 +470,6 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
                 f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
                 "you save your model with the `save_pretrained` method."
             )
-        elif metadata["format"] != "pt":
-            raise NotImplementedError(
-                f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
-            )
         return safe_load_file(checkpoint_file)
     try:
         if (

@@ -1934,7 +1930,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         save_function: Callable = torch.save,
         push_to_hub: bool = False,
         max_shard_size: Union[int, str] = "5GB",
-        safe_serialization: bool = False,
+        safe_serialization: bool = True,
         variant: Optional[str] = None,
         token: Optional[Union[str, bool]] = None,
         save_peft_format: bool = True,

@@ -1975,7 +1971,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             </Tip>

-            safe_serialization (`bool`, *optional*, defaults to `False`):
+            safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
             variant (`str`, *optional*):
                 If specified, weights are saved in the format pytorch_model.<variant>.bin.

@@ -2736,8 +2732,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 " sure the weights are in PyTorch format."
             )

-        from_pt = not (from_tf | from_flax)
-
         user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
         if from_pipeline is not None:
             user_agent["using_pipeline"] = from_pipeline

@@ -3103,6 +3097,29 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     _commit_hash=commit_hash,
                 )

+        if (
+            is_safetensors_available()
+            and isinstance(resolved_archive_file, str)
+            and resolved_archive_file.endswith(".safetensors")
+        ):
+            with safe_open(resolved_archive_file, framework="pt") as f:
+                metadata = f.metadata()
+
+            if metadata.get("format") == "pt":
+                pass
+            elif metadata.get("format") == "tf":
+                from_tf = True
+                logger.info("A TensorFlow safetensors file is being loaded in a PyTorch model.")
+            elif metadata.get("format") == "flax":
+                from_flax = True
+                logger.info("A Flax safetensors file is being loaded in a PyTorch model.")
+            else:
+                raise ValueError(
+                    f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax'] but {metadata.get('format')}"
+                )
+
+        from_pt = not (from_tf | from_flax)
+
         # load pt weights early so that we know which dtype to init the model under
         if from_pt:
             if not is_sharded and state_dict is None:

@@ -3391,7 +3408,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             # restore default dtype
             if dtype_orig is not None:
                 torch.set_default_dtype(dtype_orig)
-
             (
                 model,
                 missing_keys,
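With the per-format NotImplementedError removed, `from_pretrained` instead inspects the safetensors header to decide whether the weights were written from PyTorch, TensorFlow or Flax. A hedged sketch of that check, with a placeholder path:

from safetensors import safe_open

with safe_open("model.safetensors", framework="pt") as f:  # placeholder path
    metadata = f.metadata()

fmt = metadata.get("format") if metadata else None
if fmt not in ("pt", "tf", "flax"):
    raise OSError("Not a checkpoint produced by `save_pretrained`.")
from_tf, from_flax = fmt == "tf", fmt == "flax"  # crossload when the file came from another framework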

@@ -366,8 +366,8 @@ class EncoderDecoderModel(PreTrainedModel):
         model.config = config

         if hasattr(model, "enc_to_dec_proj"):
-            model.enc_to_dec_proj.weight.data = enc_to_dec_proj_weight
-            model.enc_to_dec_proj.bias.data = enc_to_dec_proj_bias
+            model.enc_to_dec_proj.weight.data = enc_to_dec_proj_weight.contiguous()
+            model.enc_to_dec_proj.bias.data = enc_to_dec_proj_bias.contiguous()

         return model
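The added `.contiguous()` calls matter because safetensors serializes raw tensor storage and rejects non-contiguous tensors such as transposed views. A small illustration (the file name is a placeholder):

import torch
from safetensors.torch import save_file

proj_weight = torch.randn(4, 8).t()  # a transposed view is not contiguous
state = {"enc_to_dec_proj.weight": proj_weight.contiguous()}  # without .contiguous() saving would fail
save_file(state, "proj.safetensors", metadata={"format": "pt"})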

@@ -306,17 +306,21 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
         # here that the config model_type is the same as the name of the MainLayer. I don't know of anywhere that's
         # not the case, and I wasn't sure how else to go from the config to the correct MainLayer name!

-        if kwargs.get("from_pt", False):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
-            encoder_model_type = config.encoder.model_type
-
-            def tf_to_pt_weight_rename(tf_weight):
-                if "encoder" in tf_weight and "decoder" not in tf_weight:
-                    return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight)
-                else:
-                    return tf_weight
-
-            kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename
+        # This override is only needed in the case where we're crossloading weights from PT. However, since weights are
+        # often safetensors now, we don't know if we're going to be crossloading until we sniff the weights file.
+        # Therefore, we specify tf_to_pt_weight_rename anyway, and let the super method figure out if it needs it
+        # or not.
+        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+        encoder_model_type = config.encoder.model_type
+
+        def tf_to_pt_weight_rename(tf_weight):
+            if "encoder" in tf_weight and "decoder" not in tf_weight:
+                return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight)
+            else:
+                return tf_weight
+
+        kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename
         return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

     @classmethod

@@ -322,17 +322,21 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
         # here that the config model_type is the same as the name of the MainLayer. I don't know of anywhere that's
         # not the case, and I wasn't sure how else to go from the config to the correct MainLayer name!

-        if kwargs.get("from_pt", False):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
-            encoder_model_type = config.encoder.model_type
-
-            def tf_to_pt_weight_rename(tf_weight):
-                if "encoder" in tf_weight and "decoder" not in tf_weight:
-                    return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight)
-                else:
-                    return tf_weight
-
-            kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename
+        # This override is only needed in the case where we're crossloading weights from PT. However, since weights are
+        # often safetensors now, we don't know if we're going to be crossloading until we sniff the weights file.
+        # Therefore, we specify tf_to_pt_weight_rename anyway, and let the super method figure out if it needs it
+        # or not.
+        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+        encoder_model_type = config.encoder.model_type
+
+        def tf_to_pt_weight_rename(tf_weight):
+            if "encoder" in tf_weight and "decoder" not in tf_weight:
+                return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight)
+            else:
+                return tf_weight
+
+        kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename
         return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

     @classmethod

@@ -342,8 +342,8 @@ class VisionEncoderDecoderModel(PreTrainedModel):
         model.config = config

         if hasattr(model, "enc_to_dec_proj"):
-            model.enc_to_dec_proj.weight.data = enc_to_dec_proj_weight
-            model.enc_to_dec_proj.bias.data = enc_to_dec_proj_bias
+            model.enc_to_dec_proj.weight.data = enc_to_dec_proj_weight.contiguous()
+            model.enc_to_dec_proj.bias.data = enc_to_dec_proj_bias.contiguous()

         return model

@@ -836,7 +836,7 @@ class Pipeline(_ScikitCompat):
             # then we should keep working
             self.image_processor = self.feature_extractor

-    def save_pretrained(self, save_directory: str, safe_serialization: bool = False):
+    def save_pretrained(self, save_directory: str, safe_serialization: bool = True):
         """
         Save the pipeline's model and tokenizer.

@@ -844,7 +844,7 @@ class Pipeline(_ScikitCompat):
             save_directory (`str`):
                 A path to the directory where to saved. It will be created if it doesn't exist.
             safe_serialization (`str`):
-                Whether to save the model using `safetensors` or the traditional way for PyTorch or Tensorflow
+                Whether to save the model using `safetensors` or the traditional way for PyTorch or Tensorflow.
         """
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")

@@ -293,7 +293,7 @@ class TrainingArguments:
             `save_total_limit=5` and `load_best_model_at_end`, the four last checkpoints will always be retained
             alongside the best model. When `save_total_limit=1` and `load_best_model_at_end`, it is possible that two
             checkpoints are saved: the last one and the best one (if they are different).
-        save_safetensors (`bool`, *optional*, defaults to `False`):
+        save_safetensors (`bool`, *optional*, defaults to `True`):
             Use [safetensors](https://huggingface.co/docs/safetensors) saving and loading for state dicts instead of
             default `torch.load` and `torch.save`.
         save_on_each_node (`bool`, *optional*, defaults to `False`):

@@ -797,7 +797,7 @@ class TrainingArguments:
         },
     )
     save_safetensors: Optional[bool] = field(
-        default=False,
+        default=True,
         metadata={
             "help": "Use safetensors saving and loading for state dicts instead of default torch.load and torch.save."
         },
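Trainer checkpoints follow the same default. A minimal sketch (the output directory is illustrative):

from transformers import TrainingArguments

# save_safetensors now defaults to True, so checkpoints are written as model.safetensors;
# set it to False to keep the old torch.save/.bin behaviour.
args = TrainingArguments(output_dir="out", save_safetensors=True)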

@@ -797,7 +797,7 @@ class PushToHubMixin:
         token: Optional[Union[bool, str]] = None,
         max_shard_size: Optional[Union[int, str]] = "5GB",
         create_pr: bool = False,
-        safe_serialization: bool = False,
+        safe_serialization: bool = True,
         revision: str = None,
         commit_description: str = None,
         **deprecated_kwargs,

@@ -827,7 +827,7 @@ class PushToHubMixin:
                 Google Colab instances without any CPU OOM issues.
             create_pr (`bool`, *optional*, defaults to `False`):
                 Whether or not to create a PR with the uploaded files or directly commit.
-            safe_serialization (`bool`, *optional*, defaults to `False`):
+            safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether or not to convert the model weights in safetensors format for safer serialization.
             revision (`str`, *optional*):
                 Branch to push the uploaded files to.

@@ -211,6 +211,8 @@ class TFAutoModelTest(unittest.TestCase):
             config = copy.deepcopy(model.config)
             config.architectures = ["FunnelBaseModel"]
             model = TFAutoModel.from_config(config)
+            model.build()
+
             self.assertIsInstance(model, TFFunnelBaseModel)

             with tempfile.TemporaryDirectory() as tmp_dir:

@@ -245,7 +247,10 @@ class TFAutoModelTest(unittest.TestCase):
             # Now that the config is registered, it can be used as any other config with the auto-API
             tiny_config = BertModelTester(self).get_config()
             config = NewModelConfig(**tiny_config.to_dict())
+
             model = auto_class.from_config(config)
+            model.build()
+
             self.assertIsInstance(model, TFNewModel)

             with tempfile.TemporaryDirectory() as tmp_dir:
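The tests now call `model.build()` because a model instantiated from a config alone has no variables until it is built, and the safetensors round trips in this PR need real weights to save and compare. A sketch mirroring the updated test (the tiny config values are assumptions):

from transformers import BertConfig, TFAutoModel

config = BertConfig(num_hidden_layers=2, hidden_size=32, num_attention_heads=2, intermediate_size=64)
model = TFAutoModel.from_config(config)  # architecture only; variables are not created yet
model.build()                            # create the variables so they can be saved and compared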

@@ -525,7 +525,7 @@
         # PT -> TF
         with tempfile.TemporaryDirectory() as tmpdirname:
             pt_model.save_pretrained(tmpdirname)
-            tf_model = TFEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
+            tf_model = TFEncoderDecoderModel.from_pretrained(tmpdirname)

         self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)

@@ -542,7 +542,7 @@
         with tempfile.TemporaryDirectory() as tmpdirname:
             pt_model.save_pretrained(tmpdirname)
-            tf_model = TFEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
+            tf_model = TFEncoderDecoderModel.from_pretrained(tmpdirname)

             self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)

@@ -560,7 +560,8 @@
         tf_model(**tf_inputs_dict)

         with tempfile.TemporaryDirectory() as tmpdirname:
-            tf_model.save_pretrained(tmpdirname)
+            # TODO Matt: PT doesn't support loading TF safetensors - remove the arg and from_tf=True when it does
+            tf_model.save_pretrained(tmpdirname, safe_serialization=False)
             pt_model = EncoderDecoderModel.from_pretrained(tmpdirname, from_tf=True)

             self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)

@@ -1129,9 +1130,7 @@ class TFEncoderDecoderModelSaveLoadTests(unittest.TestCase):
         with tempfile.TemporaryDirectory() as tmp_dirname_1, tempfile.TemporaryDirectory() as tmp_dirname_2:
             encoder_decoder_pt.encoder.save_pretrained(tmp_dirname_1)
             encoder_decoder_pt.decoder.save_pretrained(tmp_dirname_2)
-            encoder_decoder_tf = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
-                tmp_dirname_1, tmp_dirname_2, encoder_from_pt=True, decoder_from_pt=True
-            )
+            encoder_decoder_tf = TFEncoderDecoderModel.from_encoder_decoder_pretrained(tmp_dirname_1, tmp_dirname_2)

             logits_tf = encoder_decoder_tf(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits

@@ -1150,7 +1149,7 @@ class TFEncoderDecoderModelSaveLoadTests(unittest.TestCase):
         # TensorFlow => PyTorch
         with tempfile.TemporaryDirectory() as tmp_dirname:
-            encoder_decoder_tf.save_pretrained(tmp_dirname)
+            encoder_decoder_tf.save_pretrained(tmp_dirname, safe_serialization=False)
             encoder_decoder_pt = EncoderDecoderModel.from_pretrained(tmp_dirname, from_tf=True)

         max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy()))
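Because a plain PyTorch `save_pretrained` now produces a safetensors file, the tests above no longer need `from_pt=True`: the TF side sniffs the source framework from the file metadata. A hedged sketch of the same round trip with a tiny model:

import tempfile
from transformers import BertModel, TFBertModel

pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
    pt_model.save_pretrained(tmp_dir)                # writes model.safetensors by default
    tf_model = TFBertModel.from_pretrained(tmp_dir)  # crossloads without from_pt=True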

@@ -458,7 +458,7 @@
         # PT -> TF
         with tempfile.TemporaryDirectory() as tmpdirname:
             pt_model.save_pretrained(tmpdirname)
-            tf_model = TFVisionEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
+            tf_model = TFVisionEncoderDecoderModel.from_pretrained(tmpdirname)

         self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)

@@ -473,7 +473,7 @@
         with tempfile.TemporaryDirectory() as tmpdirname:
             pt_model.save_pretrained(tmpdirname)
-            tf_model = TFVisionEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
+            tf_model = TFVisionEncoderDecoderModel.from_pretrained(tmpdirname)

             self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)

@@ -489,7 +489,7 @@
         tf_model(**tf_inputs_dict)

         with tempfile.TemporaryDirectory() as tmpdirname:
-            tf_model.save_pretrained(tmpdirname)
+            tf_model.save_pretrained(tmpdirname, safe_serialization=False)
             pt_model = VisionEncoderDecoderModel.from_pretrained(tmpdirname, from_tf=True)

             self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)

@@ -803,7 +803,7 @@ class TFVisionEncoderDecoderModelSaveLoadTests(unittest.TestCase):
             encoder_decoder_pt.encoder.save_pretrained(tmp_dirname_1)
             encoder_decoder_pt.decoder.save_pretrained(tmp_dirname_2)
             encoder_decoder_tf = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
-                tmp_dirname_1, tmp_dirname_2, encoder_from_pt=True, decoder_from_pt=True
+                tmp_dirname_1, tmp_dirname_2
             )

             logits_tf = encoder_decoder_tf(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids).logits

@@ -814,7 +814,7 @@ class TFVisionEncoderDecoderModelSaveLoadTests(unittest.TestCase):
         # Make sure `from_pretrained` following `save_pretrained` work and give the same result
         # (See https://github.com/huggingface/transformers/pull/14016)
         with tempfile.TemporaryDirectory() as tmp_dirname:
-            encoder_decoder_tf.save_pretrained(tmp_dirname)
+            encoder_decoder_tf.save_pretrained(tmp_dirname, safe_serialization=False)
             encoder_decoder_tf = TFVisionEncoderDecoderModel.from_pretrained(tmp_dirname)

             logits_tf_2 = encoder_decoder_tf(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids).logits

@@ -91,6 +91,7 @@ if is_accelerate_available():

 if is_torch_available():
     import torch
+    from safetensors.torch import save_file as safe_save_file
     from torch import nn

     from transformers import MODEL_MAPPING, AdaptiveEmbedding

@@ -1751,8 +1752,8 @@ class ModelTesterMixin:

             # We are nuking ALL weights on file, so every parameter should
             # yell on load. We're going to detect if we yell too much, or too little.
-            with open(os.path.join(tmp_dir, "pytorch_model.bin"), "wb") as f:
-                torch.save({}, f)
+            placeholder_dict = {"tensor": torch.tensor([1, 2])}
+            safe_save_file(placeholder_dict, os.path.join(tmp_dir, "model.safetensors"), metadata={"format": "pt"})
             model_reloaded, infos = model_class.from_pretrained(tmp_dir, output_loading_info=True)

             prefix = f"{model_reloaded.base_model_prefix}."
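The common test now nukes the on-disk weights by writing a near-empty safetensors file instead of an empty pickle. A sketch of that pattern (paths are illustrative):

import os
import tempfile
import torch
from safetensors.torch import save_file as safe_save_file

with tempfile.TemporaryDirectory() as tmp_dir:
    # A single placeholder tensor plus the "format" metadata stands in for the
    # old empty pytorch_model.bin that was written with torch.save({}, ...).
    placeholder = {"tensor": torch.tensor([1, 2])}
    safe_save_file(placeholder, os.path.join(tmp_dir, "model.safetensors"), metadata={"format": "pt"})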

@@ -16,11 +16,12 @@
 import tempfile
 import unittest

 import numpy as np
-from huggingface_hub import HfFolder, delete_repo
+from huggingface_hub import HfFolder, delete_repo, snapshot_download
 from requests.exceptions import HTTPError

-from transformers import BertConfig, is_flax_available
-from transformers.testing_utils import TOKEN, USER, is_staging_test, require_flax
+from transformers import BertConfig, BertModel, is_flax_available
+from transformers.testing_utils import TOKEN, USER, is_staging_test, require_flax, require_safetensors, require_torch
+from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME


 if is_flax_available():

@@ -184,3 +185,88 @@ class FlaxModelUtilsTest(unittest.TestCase):
         model = FlaxBertModel.from_pretrained(model_id, subfolder=subfolder)
         self.assertIsNotNone(model)
@require_safetensors
def test_safetensors_save_and_load(self):
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=True)
# No msgpack file, only a model.safetensors
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, FLAX_WEIGHTS_NAME)))
new_model = FlaxBertModel.from_pretrained(tmp_dir)
self.assertTrue(check_models_equal(model, new_model))
@require_flax
@require_torch
def test_safetensors_save_and_load_pt_to_flax(self):
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True)
pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
pt_model.save_pretrained(tmp_dir)
# Check we have a model.safetensors file
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
new_model = FlaxBertModel.from_pretrained(tmp_dir)
# Check models are equal
self.assertTrue(check_models_equal(model, new_model))
@require_safetensors
def test_safetensors_load_from_hub(self):
flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
# Can load from the Flax-formatted checkpoint
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-only")
self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_torch
@require_safetensors
def test_safetensors_load_from_hub_flax_and_pt(self):
flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
# Can load from the PyTorch-formatted checkpoint
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only", from_pt=True)
self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_safetensors
def test_safetensors_flax_from_flax(self):
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=True)
new_model = FlaxBertModel.from_pretrained(tmp_dir)
self.assertTrue(check_models_equal(model, new_model))
@require_safetensors
@require_torch
def test_safetensors_flax_from_torch(self):
hub_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=True)
new_model = FlaxBertModel.from_pretrained(tmp_dir)
self.assertTrue(check_models_equal(hub_model, new_model))
@require_safetensors
def test_safetensors_flax_from_sharded_msgpack_with_sharded_safetensors_local(self):
with tempfile.TemporaryDirectory() as tmp_dir:
path = snapshot_download(
"hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded", cache_dir=tmp_dir
)
# This should not raise even if there are two types of sharded weights
FlaxBertModel.from_pretrained(path)
@require_safetensors
def test_safetensors_flax_from_sharded_msgpack_with_sharded_safetensors_hub(self):
# This should not raise even if there are two types of sharded weights
# This should discard the safetensors weights in favor of the msgpack sharded weights
FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded")

@@ -24,7 +24,7 @@
 import unittest
 import unittest.mock as mock

-from huggingface_hub import HfFolder, Repository, delete_repo
+from huggingface_hub import HfFolder, Repository, delete_repo, snapshot_download
 from huggingface_hub.file_download import http_get
 from requests.exceptions import HTTPError

@@ -39,6 +39,7 @@ from transformers.testing_utils import (  # noqa: F401
     is_staging_test,
     require_safetensors,
     require_tf,
+    require_torch,
     slow,
 )
 from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging

@@ -496,6 +497,44 @@ class TFModelUtilsTest(unittest.TestCase):
         for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
             self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@require_safetensors
def test_safetensors_tf_from_tf(self):
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=True)
new_model = TFBertModel.from_pretrained(tmp_dir)
for p1, p2 in zip(model.weights, new_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@require_safetensors
@is_pt_tf_cross_test
def test_safetensors_tf_from_torch(self):
hub_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=True)
new_model = TFBertModel.from_pretrained(tmp_dir)
for p1, p2 in zip(hub_model.weights, new_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@require_safetensors
def test_safetensors_tf_from_sharded_h5_with_sharded_safetensors_local(self):
with tempfile.TemporaryDirectory() as tmp_dir:
path = snapshot_download("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", cache_dir=tmp_dir)
# This should not raise even if there are two types of sharded weights
TFBertModel.from_pretrained(path)
@require_safetensors
def test_safetensors_tf_from_sharded_h5_with_sharded_safetensors_hub(self):
# This should not raise even if there are two types of sharded weights
# This should discard the safetensors weights in favor of the .h5 sharded weights
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded")
@require_tf
@is_staging_test


@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import copy
import glob
import json
import os
@@ -42,7 +42,9 @@ from transformers.testing_utils import (
TestCasePlus,
is_staging_test,
require_accelerate,
+require_flax,
require_safetensors,
+require_tf,
require_torch,
require_torch_accelerator,
require_torch_multi_accelerator,
@@ -56,7 +58,7 @@ from transformers.utils import (
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
)
-from transformers.utils.import_utils import is_torchdynamo_available
+from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torchdynamo_available
sys.path.append(str(Path(__file__).parent.parent / "utils"))
@@ -66,6 +68,7 @@ from test_module.custom_configuration import CustomConfig, NoSuperInitConfig #
if is_torch_available():
import torch
+from safetensors.torch import save_file as safe_save_file
from test_module.custom_modeling import CustomModel, NoSuperInitModel
from torch import nn
@@ -146,6 +149,13 @@ if is_torch_available():
self.decoder.weight = self.base.linear.weight
+if is_flax_available():
+from transformers import FlaxBertModel
+if is_tf_available():
+from transformers import TFBertModel
TINY_T5 = "patrickvonplaten/t5-tiny-random"
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
@@ -420,13 +430,13 @@ class ModelUtilsTest(TestCasePlus):
},
)
-def test_checkpoint_sharding_local(self):
+def test_checkpoint_sharding_local_bin(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
for max_size in ["50kB", "50kiB", "100kB", "100kiB", "200kB", "200kiB"]:
-model.save_pretrained(tmp_dir, max_shard_size=max_size)
+model.save_pretrained(tmp_dir, max_shard_size=max_size, safe_serialization=False)
# Get each shard file and its size
shard_to_size = {}
@@ -472,11 +482,11 @@ class ModelUtilsTest(TestCasePlus):
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
-def test_checkpoint_variant_local(self):
+def test_checkpoint_variant_local_bin(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
-model.save_pretrained(tmp_dir, variant="v2")
+model.save_pretrained(tmp_dir, variant="v2", safe_serialization=False)
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
@@ -492,11 +502,11 @@ class ModelUtilsTest(TestCasePlus):
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
-def test_checkpoint_variant_local_sharded(self):
+def test_checkpoint_variant_local_sharded_bin(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
-model.save_pretrained(tmp_dir, variant="v2", max_shard_size="50kB")
+model.save_pretrained(tmp_dir, variant="v2", max_shard_size="50kB", safe_serialization=False)
weights_index_name = ".".join(WEIGHTS_INDEX_NAME.split(".")[:-1] + ["v2"] + ["json"])
weights_index_file = os.path.join(tmp_dir, weights_index_name)
@@ -604,18 +614,18 @@
)
self.assertIsNotNone(model)
-def test_checkpoint_variant_save_load(self):
+def test_checkpoint_variant_save_load_bin(self):
with tempfile.TemporaryDirectory() as tmp_dir:
model = BertModel.from_pretrained(
"hf-internal-testing/tiny-random-bert-variant", cache_dir=tmp_dir, variant="v2"
)
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
-model.save_pretrained(tmp_dir, variant="v2")
+model.save_pretrained(tmp_dir, variant="v2", safe_serialization=False)
# saving will create a variant checkpoint
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name)))
-model.save_pretrained(tmp_dir)
+model.save_pretrained(tmp_dir, safe_serialization=False)
# saving shouldn't delete variant checkpoints
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name)))
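Aside (not part of the diff): the weights_name expression in the variant tests simply splices the variant between the file stem and the extension. A minimal illustration, assuming WEIGHTS_NAME has its usual value from transformers.utils:
# Illustration only: how the variant checkpoint name is built in the tests above.
WEIGHTS_NAME = "pytorch_model.bin"  # assumed value, normally imported from transformers.utils
variant = "v2"
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + [variant] + ["bin"])
print(weights_name)  # -> pytorch_model.v2.bin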
@@ -874,7 +884,7 @@ class ModelUtilsTest(TestCasePlus):
def test_base_model_to_head_model_load(self):
base_model = BaseModel(PretrainedConfig())
with tempfile.TemporaryDirectory() as tmp_dir:
-base_model.save_pretrained(tmp_dir)
+base_model.save_pretrained(tmp_dir, safe_serialization=False)
# Can load a base model in a model with head
model = ModelWithHead.from_pretrained(tmp_dir)
@@ -886,7 +896,7 @@ class ModelUtilsTest(TestCasePlus):
head_state_dict = model.state_dict()
base_state_dict["linear2.weight"] = head_state_dict["linear2.weight"]
base_state_dict["linear2.bias"] = head_state_dict["linear2.bias"]
-torch.save(base_state_dict, os.path.join(tmp_dir, WEIGHTS_NAME))
+safe_save_file(base_state_dict, os.path.join(tmp_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
with self.assertRaisesRegex(
ValueError, "The state dictionary of the model you are trying to load is corrupted."
@@ -934,8 +944,8 @@
# Loading the model with the same class, we do get a warning for unexpected weights
state_dict = model.state_dict()
-state_dict["added_key"] = state_dict["linear.weight"]
-torch.save(state_dict, os.path.join(tmp_dir, WEIGHTS_NAME))
+state_dict["added_key"] = copy.deepcopy(state_dict["linear.weight"])
+safe_save_file(state_dict, os.path.join(tmp_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
with CaptureLogger(logger) as cl:
_, loading_info = ModelWithHead.from_pretrained(tmp_dir, output_loading_info=True)
self.assertIn("were not used when initializing ModelWithHead: ['added_key']", cl.out)
@@ -1072,6 +1082,54 @@
)
self.assertEqual(model.generation_config.transformers_version, "foo")
@require_safetensors
def test_safetensors_torch_from_torch(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=True)
new_model = BertModel.from_pretrained(tmp_dir)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
@require_safetensors
@require_flax
def test_safetensors_torch_from_flax(self):
hub_model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=True)
new_model = BertModel.from_pretrained(tmp_dir)
for p1, p2 in zip(hub_model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
@require_tf
@require_safetensors
def test_safetensors_torch_from_tf(self):
hub_model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=True)
new_model = BertModel.from_pretrained(tmp_dir)
for p1, p2 in zip(hub_model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
@require_safetensors
def test_safetensors_torch_from_torch_sharded(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="100kB")
new_model = BertModel.from_pretrained(tmp_dir)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
@require_torch
@is_staging_test
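Usage note (not part of the diff): the headline change in this commit is that safe_serialization now defaults to True, so a plain save_pretrained call writes safetensors. A minimal sketch of the expected behavior; the file names in the comments are the expected defaults, not something asserted here.
# Sketch: save_pretrained now serializes to safetensors unless told otherwise.
import os
import tempfile
from transformers import BertModel
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir)  # no safe_serialization flag needed anymore
    print(sorted(os.listdir(tmp_dir)))  # expected to contain model.safetensors rather than pytorch_model.bin
    model.save_pretrained(tmp_dir, safe_serialization=False)  # opt back into the legacy .bin format explicitly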


@@ -403,7 +403,7 @@ if is_torch_available():
class TrainerIntegrationCommon:
-def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, safe_weights=False):
+def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, safe_weights=True):
weights_file = WEIGHTS_NAME if not safe_weights else SAFE_WEIGHTS_NAME
file_list = [weights_file, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
if is_pretrained:
@@ -415,7 +415,7 @@
self.assertTrue(os.path.isfile(os.path.join(checkpoint, filename)))
def check_best_model_has_been_loaded(
-self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True, safe_weights=False
+self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True, safe_weights=True
):
checkpoint = os.path.join(output_dir, f"checkpoint-{(total // freq) * freq}")
log_history = TrainerState.load_from_json(os.path.join(checkpoint, "trainer_state.json")).log_history
@@ -456,7 +456,7 @@
_ = log1.pop(key, None)
self.assertEqual(log, log1)
-def convert_to_sharded_checkpoint(self, folder, save_safe=False, load_safe=False):
+def convert_to_sharded_checkpoint(self, folder, save_safe=True, load_safe=True):
# Converts a checkpoint of a regression model to a sharded checkpoint.
if load_safe:
loader = safetensors.torch.load_file
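Aside (not part of the diff): with safe_weights now defaulting to True, the trainer test helpers look for model.safetensors in each checkpoint-* folder. A small sketch of the core file list the helper builds, derived from the file_list shown above; SAFE_WEIGHTS_NAME and WEIGHTS_NAME are assumed to resolve to "model.safetensors" and "pytorch_model.bin".
# Sketch of the core per-checkpoint file list the updated helper checks for.
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
def expected_checkpoint_files(safe_weights=True):
    weights_file = SAFE_WEIGHTS_NAME if safe_weights else WEIGHTS_NAME
    return [weights_file, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
print(expected_checkpoint_files())       # safetensors checkpoint (new default)
print(expected_checkpoint_files(False))  # legacy .bin checkpoint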


@@ -43,7 +43,6 @@ class CLITest(unittest.TestCase):
shutil.rmtree("/tmp/hf-internal-testing/tiny-random-gptj", ignore_errors=True) # cleans potential past runs
transformers.commands.transformers_cli.main()
-# The original repo has no TF weights -- if they exist, they were created by the CLI
self.assertTrue(os.path.exists("/tmp/hf-internal-testing/tiny-random-gptj/tf_model.h5"))
@require_torch