Safetensors serialization by default (#27064)
* Safetensors serialization by default
* First pass on the tests
* Second pass on the tests
* Third pass on the tests
* Fix TF weight loading from TF-format safetensors
* Specific encoder-decoder fixes for weight crossloading
* Add VisionEncoderDecoder fixes for TF too
* Change filename test for pt-to-tf
* One missing fix for TFVisionEncoderDecoder
* Fix the other crossload test
* Support for flax + updated tests
* Apply suggestions from code review
  Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Sanchit's comments
* Sanchit's comments 2
* Nico's comments
* Fix tests
* cleanup
* Apply suggestions from code review
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---------
Co-authored-by: Matt <rocketknight1@gmail.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
25e6e9418c
commit
113ebf80ac
|
@ -27,9 +27,15 @@ from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
|
from . import is_safetensors_available
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
if is_safetensors_available():
|
||||||
|
from safetensors import safe_open
|
||||||
|
from safetensors.flax import load_file as safe_load_file
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -56,7 +62,13 @@ def load_pytorch_checkpoint_in_flax_state_dict(
|
||||||
pt_path = os.path.abspath(pytorch_checkpoint_path)
|
pt_path = os.path.abspath(pytorch_checkpoint_path)
|
||||||
logger.info(f"Loading PyTorch weights from {pt_path}")
|
logger.info(f"Loading PyTorch weights from {pt_path}")
|
||||||
|
|
||||||
pt_state_dict = torch.load(pt_path, map_location="cpu")
|
if pt_path.endswith(".safetensors"):
|
||||||
|
pt_state_dict = {}
|
||||||
|
with safe_open(pt_path, framework="pt") as f:
|
||||||
|
for k in f.keys():
|
||||||
|
pt_state_dict[k] = f.get_tensor(k)
|
||||||
|
else:
|
||||||
|
pt_state_dict = torch.load(pt_path, map_location="cpu")
|
||||||
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
|
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
|
||||||
|
|
||||||
flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
|
flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
|
||||||
|
@ -319,11 +331,15 @@ def load_flax_checkpoint_in_pytorch_model(model, flax_checkpoint_path):
|
||||||
flax_cls = getattr(transformers, "Flax" + model.__class__.__name__)
|
flax_cls = getattr(transformers, "Flax" + model.__class__.__name__)
|
||||||
|
|
||||||
# load flax weight dict
|
# load flax weight dict
|
||||||
with open(flax_checkpoint_path, "rb") as state_f:
|
if flax_checkpoint_path.endswith(".safetensors"):
|
||||||
try:
|
flax_state_dict = safe_load_file(flax_checkpoint_path)
|
||||||
flax_state_dict = from_bytes(flax_cls, state_f.read())
|
flax_state_dict = unflatten_dict(flax_state_dict, sep=".")
|
||||||
except UnpicklingError:
|
else:
|
||||||
raise EnvironmentError(f"Unable to convert {flax_checkpoint_path} to Flax deserializable object. ")
|
with open(flax_checkpoint_path, "rb") as state_f:
|
||||||
|
try:
|
||||||
|
flax_state_dict = from_bytes(flax_cls, state_f.read())
|
||||||
|
except UnpicklingError:
|
||||||
|
raise EnvironmentError(f"Unable to convert {flax_checkpoint_path} to Flax deserializable object. ")
|
||||||
|
|
||||||
return load_flax_weights_in_pytorch_model(model, flax_state_dict)
|
return load_flax_weights_in_pytorch_model(model, flax_state_dict)
|
||||||
|
|
||||||
|
|
|
@ -39,6 +39,8 @@ from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_d
|
||||||
from .utils import (
|
from .utils import (
|
||||||
FLAX_WEIGHTS_INDEX_NAME,
|
FLAX_WEIGHTS_INDEX_NAME,
|
||||||
FLAX_WEIGHTS_NAME,
|
FLAX_WEIGHTS_NAME,
|
||||||
|
SAFE_WEIGHTS_INDEX_NAME,
|
||||||
|
SAFE_WEIGHTS_NAME,
|
||||||
WEIGHTS_INDEX_NAME,
|
WEIGHTS_INDEX_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
PushToHubMixin,
|
PushToHubMixin,
|
||||||
|
@ -54,8 +56,14 @@ from .utils import (
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
|
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
|
||||||
|
from .utils.import_utils import is_safetensors_available
|
||||||
|
|
||||||
|
|
||||||
|
if is_safetensors_available():
|
||||||
|
from safetensors import safe_open
|
||||||
|
from safetensors.flax import load_file as safe_load_file
|
||||||
|
from safetensors.flax import save_file as safe_save_file
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -422,6 +430,31 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||||
```"""
|
```"""
|
||||||
return self._cast_floating_to(params, jnp.float16, mask)
|
return self._cast_floating_to(params, jnp.float16, mask)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_flax_weights(cls, resolved_archive_file):
|
||||||
|
try:
|
||||||
|
if resolved_archive_file.endswith(".safetensors"):
|
||||||
|
state = safe_load_file(resolved_archive_file)
|
||||||
|
state = unflatten_dict(state, sep=".")
|
||||||
|
else:
|
||||||
|
with open(resolved_archive_file, "rb") as state_f:
|
||||||
|
state = from_bytes(cls, state_f.read())
|
||||||
|
except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
|
||||||
|
try:
|
||||||
|
with open(resolved_archive_file) as f:
|
||||||
|
if f.read().startswith("version"):
|
||||||
|
raise OSError(
|
||||||
|
"You seem to have cloned a repository without having git-lfs installed. Please"
|
||||||
|
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
|
||||||
|
" folder you cloned."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError from e
|
||||||
|
except (UnicodeDecodeError, ValueError):
|
||||||
|
raise EnvironmentError(f"Unable to convert {resolved_archive_file} to Flax deserializable object. ")
|
||||||
|
|
||||||
|
return state
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_flax_sharded_weights(cls, shard_files):
|
def load_flax_sharded_weights(cls, shard_files):
|
||||||
"""
|
"""
|
||||||
|
@ -688,7 +721,12 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
|
if is_safetensors_available() and os.path.isfile(
|
||||||
|
os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
|
||||||
|
):
|
||||||
|
# Load from a safetensors checkpoint
|
||||||
|
archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
|
||||||
|
elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
|
||||||
# Load from a PyTorch checkpoint
|
# Load from a PyTorch checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
|
||||||
elif from_pt and os.path.isfile(
|
elif from_pt and os.path.isfile(
|
||||||
|
@ -705,6 +743,13 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)
|
||||||
is_sharded = True
|
is_sharded = True
|
||||||
# At this stage we don't have a weight file so we will raise an error.
|
# At this stage we don't have a weight file so we will raise an error.
|
||||||
|
elif is_safetensors_available() and os.path.isfile(
|
||||||
|
os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
|
||||||
|
):
|
||||||
|
# Load from a sharded safetensors checkpoint
|
||||||
|
archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
|
||||||
|
is_sharded = True
|
||||||
|
raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!")
|
||||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
|
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
|
f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
|
||||||
|
@ -723,7 +768,13 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||||
filename = pretrained_model_name_or_path
|
filename = pretrained_model_name_or_path
|
||||||
resolved_archive_file = download_url(pretrained_model_name_or_path)
|
resolved_archive_file = download_url(pretrained_model_name_or_path)
|
||||||
else:
|
else:
|
||||||
filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
|
if from_pt:
|
||||||
|
filename = WEIGHTS_NAME
|
||||||
|
elif is_safetensors_available():
|
||||||
|
filename = SAFE_WEIGHTS_NAME
|
||||||
|
else:
|
||||||
|
filename = FLAX_WEIGHTS_NAME
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Load from URL or cache if already cached
|
# Load from URL or cache if already cached
|
||||||
cached_file_kwargs = {
|
cached_file_kwargs = {
|
||||||
|
@ -741,8 +792,15 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||||
}
|
}
|
||||||
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
|
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
|
||||||
|
|
||||||
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an expection but a None
|
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
|
||||||
# result when internet is up, the repo and revision exist, but the file does not.
|
# result when internet is up, the repo and revision exist, but the file does not.
|
||||||
|
if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME:
|
||||||
|
# Did not find the safetensors file, let's fallback to Flax.
|
||||||
|
# No support for sharded safetensors yet, so we'll raise an error if that's all we find.
|
||||||
|
filename = FLAX_WEIGHTS_NAME
|
||||||
|
resolved_archive_file = cached_file(
|
||||||
|
pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **cached_file_kwargs
|
||||||
|
)
|
||||||
if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME:
|
if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME:
|
||||||
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
||||||
resolved_archive_file = cached_file(
|
resolved_archive_file = cached_file(
|
||||||
|
@ -751,21 +809,26 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||||
if resolved_archive_file is not None:
|
if resolved_archive_file is not None:
|
||||||
is_sharded = True
|
is_sharded = True
|
||||||
# Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case.
|
# Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case.
|
||||||
elif resolved_archive_file is None and from_pt:
|
if resolved_archive_file is None and from_pt:
|
||||||
resolved_archive_file = cached_file(
|
resolved_archive_file = cached_file(
|
||||||
pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
|
pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
|
||||||
)
|
)
|
||||||
if resolved_archive_file is not None:
|
if resolved_archive_file is not None:
|
||||||
is_sharded = True
|
is_sharded = True
|
||||||
if resolved_archive_file is None:
|
if resolved_archive_file is None:
|
||||||
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
|
# Otherwise, maybe there is a TF or Torch model file. We try those to give a helpful error
|
||||||
# message.
|
# message.
|
||||||
has_file_kwargs = {
|
has_file_kwargs = {
|
||||||
"revision": revision,
|
"revision": revision,
|
||||||
"proxies": proxies,
|
"proxies": proxies,
|
||||||
"token": token,
|
"token": token,
|
||||||
}
|
}
|
||||||
if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
|
if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs):
|
||||||
|
is_sharded = True
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Support for sharded checkpoints using safetensors is coming soon!"
|
||||||
|
)
|
||||||
|
elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||||
f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
|
f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
|
||||||
|
@ -798,6 +861,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||||
if is_local:
|
if is_local:
|
||||||
logger.info(f"loading weights file {archive_file}")
|
logger.info(f"loading weights file {archive_file}")
|
||||||
resolved_archive_file = archive_file
|
resolved_archive_file = archive_file
|
||||||
|
filename = resolved_archive_file.split(os.path.sep)[-1]
|
||||||
else:
|
else:
|
||||||
logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
|
logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
|
||||||
else:
|
else:
|
||||||
|
@ -821,31 +885,27 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||||
_commit_hash=commit_hash,
|
_commit_hash=commit_hash,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
safetensors_from_pt = False
|
||||||
|
if filename == SAFE_WEIGHTS_NAME:
|
||||||
|
with safe_open(resolved_archive_file, framework="flax") as f:
|
||||||
|
safetensors_metadata = f.metadata()
|
||||||
|
if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]:
|
||||||
|
raise OSError(
|
||||||
|
f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata."
|
||||||
|
" Make sure you save your model with the `save_pretrained` method."
|
||||||
|
)
|
||||||
|
safetensors_from_pt = safetensors_metadata.get("format") == "pt"
|
||||||
|
|
||||||
# init random models
|
# init random models
|
||||||
model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
|
model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
|
||||||
|
|
||||||
if from_pt:
|
if from_pt or safetensors_from_pt:
|
||||||
state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded)
|
state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded)
|
||||||
else:
|
else:
|
||||||
if is_sharded:
|
if is_sharded:
|
||||||
state = cls.load_flax_sharded_weights(resolved_archive_file)
|
state = cls.load_flax_sharded_weights(resolved_archive_file)
|
||||||
else:
|
else:
|
||||||
try:
|
state = cls.load_flax_weights(resolved_archive_file)
|
||||||
with open(resolved_archive_file, "rb") as state_f:
|
|
||||||
state = from_bytes(cls, state_f.read())
|
|
||||||
except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
|
|
||||||
try:
|
|
||||||
with open(resolved_archive_file) as f:
|
|
||||||
if f.read().startswith("version"):
|
|
||||||
raise OSError(
|
|
||||||
"You seem to have cloned a repository without having git-lfs installed. Please"
|
|
||||||
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
|
|
||||||
" folder you cloned."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError from e
|
|
||||||
except (UnicodeDecodeError, ValueError):
|
|
||||||
raise EnvironmentError(f"Unable to convert {archive_file} to Flax deserializable object. ")
|
|
||||||
# make sure all arrays are stored as jnp.arrays
|
# make sure all arrays are stored as jnp.arrays
|
||||||
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
|
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
|
||||||
# https://github.com/google/flax/issues/1261
|
# https://github.com/google/flax/issues/1261
|
||||||
|
@ -1030,6 +1090,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||||
push_to_hub=False,
|
push_to_hub=False,
|
||||||
max_shard_size="10GB",
|
max_shard_size="10GB",
|
||||||
token: Optional[Union[str, bool]] = None,
|
token: Optional[Union[str, bool]] = None,
|
||||||
|
safe_serialization: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -1059,6 +1120,8 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||||
the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
|
the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
|
||||||
kwargs (`Dict[str, Any]`, *optional*):
|
kwargs (`Dict[str, Any]`, *optional*):
|
||||||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||||
|
safe_serialization (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to save the model using `safetensors` or through msgpack.
|
||||||
"""
|
"""
|
||||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||||
|
|
||||||
|
@ -1103,24 +1166,31 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||||
self.generation_config.save_pretrained(save_directory)
|
self.generation_config.save_pretrained(save_directory)
|
||||||
|
|
||||||
# save model
|
# save model
|
||||||
output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
|
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else FLAX_WEIGHTS_NAME
|
||||||
|
output_model_file = os.path.join(save_directory, weights_name)
|
||||||
|
|
||||||
shards, index = flax_shard_checkpoint(params if params is not None else self.params, max_shard_size)
|
shards, index = flax_shard_checkpoint(params if params is not None else self.params, max_shard_size)
|
||||||
# Clean the folder from a previous save
|
# Clean the folder from a previous save
|
||||||
for filename in os.listdir(save_directory):
|
for filename in os.listdir(save_directory):
|
||||||
full_filename = os.path.join(save_directory, filename)
|
full_filename = os.path.join(save_directory, filename)
|
||||||
|
weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
|
||||||
if (
|
if (
|
||||||
filename.startswith(FLAX_WEIGHTS_NAME[:-4])
|
filename.startswith(weights_no_suffix)
|
||||||
and os.path.isfile(full_filename)
|
and os.path.isfile(full_filename)
|
||||||
and filename not in shards.keys()
|
and filename not in shards.keys()
|
||||||
):
|
):
|
||||||
os.remove(full_filename)
|
os.remove(full_filename)
|
||||||
|
|
||||||
if index is None:
|
if index is None:
|
||||||
with open(output_model_file, "wb") as f:
|
if safe_serialization:
|
||||||
params = params if params is not None else self.params
|
params = params if params is not None else self.params
|
||||||
model_bytes = to_bytes(params)
|
flat_dict = flatten_dict(params, sep=".")
|
||||||
f.write(model_bytes)
|
safe_save_file(flat_dict, output_model_file, metadata={"format": "flax"})
|
||||||
|
else:
|
||||||
|
with open(output_model_file, "wb") as f:
|
||||||
|
params = params if params is not None else self.params
|
||||||
|
model_bytes = to_bytes(params)
|
||||||
|
f.write(model_bytes)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
save_index_file = os.path.join(save_directory, FLAX_WEIGHTS_INDEX_NAME)
|
save_index_file = os.path.join(save_directory, FLAX_WEIGHTS_INDEX_NAME)
|
||||||
|
|
|
@ -626,11 +626,13 @@ def dtype_byte_size(dtype):
|
||||||
return bit_size // 8
|
return bit_size // 8
|
||||||
|
|
||||||
|
|
||||||
def format_weight_name(name, _prefix=None):
|
def strip_model_name_and_prefix(name, _prefix=None):
|
||||||
|
if _prefix is not None and name.startswith(_prefix):
|
||||||
|
name = name[len(_prefix) :]
|
||||||
|
if name.startswith("/"):
|
||||||
|
name = name[1:]
|
||||||
if "model." not in name and len(name.split("/")) > 1:
|
if "model." not in name and len(name.split("/")) > 1:
|
||||||
name = "/".join(name.split("/")[1:])
|
name = "/".join(name.split("/")[1:])
|
||||||
if _prefix is not None:
|
|
||||||
name = _prefix + "/" + name
|
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
@ -986,7 +988,7 @@ def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismat
|
||||||
# Read the safetensors file
|
# Read the safetensors file
|
||||||
with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
|
with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
|
||||||
mismatched_layers = []
|
mismatched_layers = []
|
||||||
weight_names = [format_weight_name(w.name, _prefix=_prefix) for w in model.weights]
|
weight_names = [strip_model_name_and_prefix(w.name, _prefix=_prefix) for w in model.weights]
|
||||||
loaded_weight_names = list(safetensors_archive.keys())
|
loaded_weight_names = list(safetensors_archive.keys())
|
||||||
# Find the missing layers from the high level list of layers
|
# Find the missing layers from the high level list of layers
|
||||||
missing_layers = list(set(weight_names) - set(loaded_weight_names))
|
missing_layers = list(set(weight_names) - set(loaded_weight_names))
|
||||||
|
@ -994,7 +996,7 @@ def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismat
|
||||||
unexpected_layers = list(set(loaded_weight_names) - set(weight_names))
|
unexpected_layers = list(set(loaded_weight_names) - set(weight_names))
|
||||||
|
|
||||||
for weight in model.weights:
|
for weight in model.weights:
|
||||||
weight_name = format_weight_name(weight.name, _prefix=_prefix)
|
weight_name = strip_model_name_and_prefix(weight.name, _prefix=_prefix)
|
||||||
if weight_name in loaded_weight_names:
|
if weight_name in loaded_weight_names:
|
||||||
weight_value = safetensors_archive.get_tensor(weight_name)
|
weight_value = safetensors_archive.get_tensor(weight_name)
|
||||||
# Check if the shape of the current weight and the one from the H5 file are different
|
# Check if the shape of the current weight and the one from the H5 file are different
|
||||||
|
@ -1003,7 +1005,7 @@ def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismat
|
||||||
# If the two shapes are not compatible we raise an issue
|
# If the two shapes are not compatible we raise an issue
|
||||||
try:
|
try:
|
||||||
weight_value = tf.reshape(weight_value, K.int_shape(weight))
|
weight_value = tf.reshape(weight_value, K.int_shape(weight))
|
||||||
except ValueError as e:
|
except (ValueError, tf.errors.InvalidArgumentError) as e:
|
||||||
if ignore_mismatched_sizes:
|
if ignore_mismatched_sizes:
|
||||||
mismatched_layers.append((weight_name, weight_value.shape, K.int_shape(weight)))
|
mismatched_layers.append((weight_name, weight_value.shape, K.int_shape(weight)))
|
||||||
continue
|
continue
|
||||||
|
@ -2367,7 +2369,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||||
create_pr (`bool`, *optional*, defaults to `False`):
|
create_pr (`bool`, *optional*, defaults to `False`):
|
||||||
Whether or not to create a PR with the uploaded files or directly commit.
|
Whether or not to create a PR with the uploaded files or directly commit.
|
||||||
safe_serialization (`bool`, *optional*, defaults to `False`):
|
safe_serialization (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
Whether to save the model using `safetensors` or the traditional TensorFlow way (that uses `h5`).
|
||||||
token (`str` or `bool`, *optional*):
|
token (`str` or `bool`, *optional*):
|
||||||
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
|
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
|
||||||
the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
|
the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
|
||||||
|
@ -2457,7 +2459,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||||
|
|
||||||
if index is None:
|
if index is None:
|
||||||
if safe_serialization:
|
if safe_serialization:
|
||||||
state_dict = {format_weight_name(w.name): w.value() for w in self.weights}
|
state_dict = {strip_model_name_and_prefix(w.name): w.value() for w in self.weights}
|
||||||
safe_save_file(state_dict, output_model_file, metadata={"format": "tf"})
|
safe_save_file(state_dict, output_model_file, metadata={"format": "tf"})
|
||||||
else:
|
else:
|
||||||
self.save_weights(output_model_file)
|
self.save_weights(output_model_file)
|
||||||
|
@ -2718,13 +2720,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||||
):
|
):
|
||||||
# Load from a safetensors checkpoint
|
# Load from a safetensors checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
|
||||||
elif is_safetensors_available() and os.path.isfile(
|
|
||||||
os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
|
|
||||||
):
|
|
||||||
# Load from a sharded safetensors checkpoint
|
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
|
|
||||||
is_sharded = True
|
|
||||||
raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!")
|
|
||||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
|
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
|
||||||
# Load from a TF 2.0 checkpoint
|
# Load from a TF 2.0 checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
|
||||||
|
@ -2732,6 +2727,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||||
# Load from a sharded TF 2.0 checkpoint
|
# Load from a sharded TF 2.0 checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
|
||||||
is_sharded = True
|
is_sharded = True
|
||||||
|
elif is_safetensors_available() and os.path.isfile(
|
||||||
|
os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
|
||||||
|
):
|
||||||
|
# Load from a sharded safetensors checkpoint
|
||||||
|
archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
|
||||||
|
is_sharded = True
|
||||||
|
raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!")
|
||||||
# At this stage we don't have a weight file so we will raise an error.
|
# At this stage we don't have a weight file so we will raise an error.
|
||||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)) or os.path.isfile(
|
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)) or os.path.isfile(
|
||||||
os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
|
os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
|
||||||
|
@ -2784,21 +2786,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||||
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
|
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
|
||||||
# result when internet is up, the repo and revision exist, but the file does not.
|
# result when internet is up, the repo and revision exist, but the file does not.
|
||||||
if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME:
|
if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME:
|
||||||
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
# Did not find the safetensors file, let's fallback to TF.
|
||||||
|
# No support for sharded safetensors yet, so we'll raise an error if that's all we find.
|
||||||
|
filename = TF2_WEIGHTS_NAME
|
||||||
resolved_archive_file = cached_file(
|
resolved_archive_file = cached_file(
|
||||||
pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **cached_file_kwargs
|
pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **cached_file_kwargs
|
||||||
)
|
)
|
||||||
if resolved_archive_file is not None:
|
|
||||||
is_sharded = True
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Support for sharded checkpoints using safetensors is coming soon!"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# This repo has no safetensors file of any kind, we switch to TensorFlow.
|
|
||||||
filename = TF2_WEIGHTS_NAME
|
|
||||||
resolved_archive_file = cached_file(
|
|
||||||
pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **cached_file_kwargs
|
|
||||||
)
|
|
||||||
if resolved_archive_file is None and filename == TF2_WEIGHTS_NAME:
|
if resolved_archive_file is None and filename == TF2_WEIGHTS_NAME:
|
||||||
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
||||||
resolved_archive_file = cached_file(
|
resolved_archive_file = cached_file(
|
||||||
|
@ -2821,7 +2814,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||||
"proxies": proxies,
|
"proxies": proxies,
|
||||||
"token": token,
|
"token": token,
|
||||||
}
|
}
|
||||||
if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
|
if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs):
|
||||||
|
is_sharded = True
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Support for sharded checkpoints using safetensors is coming soon!"
|
||||||
|
)
|
||||||
|
elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||||
f" {TF2_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
|
f" {TF2_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
|
||||||
|
@ -2928,6 +2926,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||||
output_loading_info=output_loading_info,
|
output_loading_info=output_loading_info,
|
||||||
_prefix=load_weight_prefix,
|
_prefix=load_weight_prefix,
|
||||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||||
|
tf_to_pt_weight_rename=tf_to_pt_weight_rename,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 'by_name' allow us to do transfer learning by skipping/adding layers
|
# 'by_name' allow us to do transfer learning by skipping/adding layers
|
||||||
|
|
|
@ -470,10 +470,6 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
|
||||||
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
|
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
|
||||||
"you save your model with the `save_pretrained` method."
|
"you save your model with the `save_pretrained` method."
|
||||||
)
|
)
|
||||||
elif metadata["format"] != "pt":
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
|
|
||||||
)
|
|
||||||
return safe_load_file(checkpoint_file)
|
return safe_load_file(checkpoint_file)
|
||||||
try:
|
try:
|
||||||
if (
|
if (
|
||||||
|
@ -1934,7 +1930,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
save_function: Callable = torch.save,
|
save_function: Callable = torch.save,
|
||||||
push_to_hub: bool = False,
|
push_to_hub: bool = False,
|
||||||
max_shard_size: Union[int, str] = "5GB",
|
max_shard_size: Union[int, str] = "5GB",
|
||||||
safe_serialization: bool = False,
|
safe_serialization: bool = True,
|
||||||
variant: Optional[str] = None,
|
variant: Optional[str] = None,
|
||||||
token: Optional[Union[str, bool]] = None,
|
token: Optional[Union[str, bool]] = None,
|
||||||
save_peft_format: bool = True,
|
save_peft_format: bool = True,
|
||||||
|
@ -1975,7 +1971,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
|
|
||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
safe_serialization (`bool`, *optional*, defaults to `False`):
|
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||||
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
||||||
variant (`str`, *optional*):
|
variant (`str`, *optional*):
|
||||||
If specified, weights are saved in the format pytorch_model.<variant>.bin.
|
If specified, weights are saved in the format pytorch_model.<variant>.bin.
|
||||||
|
@ -2736,8 +2732,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
" sure the weights are in PyTorch format."
|
" sure the weights are in PyTorch format."
|
||||||
)
|
)
|
||||||
|
|
||||||
from_pt = not (from_tf | from_flax)
|
|
||||||
|
|
||||||
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
|
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
|
||||||
if from_pipeline is not None:
|
if from_pipeline is not None:
|
||||||
user_agent["using_pipeline"] = from_pipeline
|
user_agent["using_pipeline"] = from_pipeline
|
||||||
|
@ -3103,6 +3097,29 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
_commit_hash=commit_hash,
|
_commit_hash=commit_hash,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
is_safetensors_available()
|
||||||
|
and isinstance(resolved_archive_file, str)
|
||||||
|
and resolved_archive_file.endswith(".safetensors")
|
||||||
|
):
|
||||||
|
with safe_open(resolved_archive_file, framework="pt") as f:
|
||||||
|
metadata = f.metadata()
|
||||||
|
|
||||||
|
if metadata.get("format") == "pt":
|
||||||
|
pass
|
||||||
|
elif metadata.get("format") == "tf":
|
||||||
|
from_tf = True
|
||||||
|
logger.info("A TensorFlow safetensors file is being loaded in a PyTorch model.")
|
||||||
|
elif metadata.get("format") == "flax":
|
||||||
|
from_flax = True
|
||||||
|
logger.info("A Flax safetensors file is being loaded in a PyTorch model.")
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax'] but {metadata.get('format')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
from_pt = not (from_tf | from_flax)
|
||||||
|
|
||||||
# load pt weights early so that we know which dtype to init the model under
|
# load pt weights early so that we know which dtype to init the model under
|
||||||
if from_pt:
|
if from_pt:
|
||||||
if not is_sharded and state_dict is None:
|
if not is_sharded and state_dict is None:
|
||||||
|
@ -3391,7 +3408,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||||
# restore default dtype
|
# restore default dtype
|
||||||
if dtype_orig is not None:
|
if dtype_orig is not None:
|
||||||
torch.set_default_dtype(dtype_orig)
|
torch.set_default_dtype(dtype_orig)
|
||||||
|
|
||||||
(
|
(
|
||||||
model,
|
model,
|
||||||
missing_keys,
|
missing_keys,
|
||||||
|
|
|
@ -366,8 +366,8 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||||
model.config = config
|
model.config = config
|
||||||
|
|
||||||
if hasattr(model, "enc_to_dec_proj"):
|
if hasattr(model, "enc_to_dec_proj"):
|
||||||
model.enc_to_dec_proj.weight.data = enc_to_dec_proj_weight
|
model.enc_to_dec_proj.weight.data = enc_to_dec_proj_weight.contiguous()
|
||||||
model.enc_to_dec_proj.bias.data = enc_to_dec_proj_bias
|
model.enc_to_dec_proj.bias.data = enc_to_dec_proj_bias.contiguous()
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
|
@ -306,17 +306,21 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||||
# here that the config model_type is the same as the name of the MainLayer. I don't know of anywhere that's
|
# here that the config model_type is the same as the name of the MainLayer. I don't know of anywhere that's
|
||||||
# not the case, and I wasn't sure how else to go from the config to the correct MainLayer name!
|
# not the case, and I wasn't sure how else to go from the config to the correct MainLayer name!
|
||||||
|
|
||||||
if kwargs.get("from_pt", False):
|
# This override is only needed in the case where we're crossloading weights from PT. However, since weights are
|
||||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
|
# often safetensors now, we don't know if we're going to be crossloading until we sniff the weights file.
|
||||||
encoder_model_type = config.encoder.model_type
|
# Therefore, we specify tf_to_pt_weight_rename anyway, and let the super method figure out if it needs it
|
||||||
|
# or not.
|
||||||
|
|
||||||
def tf_to_pt_weight_rename(tf_weight):
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
|
||||||
if "encoder" in tf_weight and "decoder" not in tf_weight:
|
encoder_model_type = config.encoder.model_type
|
||||||
return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight)
|
|
||||||
else:
|
|
||||||
return tf_weight
|
|
||||||
|
|
||||||
kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename
|
def tf_to_pt_weight_rename(tf_weight):
|
||||||
|
if "encoder" in tf_weight and "decoder" not in tf_weight:
|
||||||
|
return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight)
|
||||||
|
else:
|
||||||
|
return tf_weight
|
||||||
|
|
||||||
|
kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename
|
||||||
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -322,17 +322,21 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
|
||||||
# here that the config model_type is the same as the name of the MainLayer. I don't know of anywhere that's
|
# here that the config model_type is the same as the name of the MainLayer. I don't know of anywhere that's
|
||||||
# not the case, and I wasn't sure how else to go from the config to the correct MainLayer name!
|
# not the case, and I wasn't sure how else to go from the config to the correct MainLayer name!
|
||||||
|
|
||||||
if kwargs.get("from_pt", False):
|
# This override is only needed in the case where we're crossloading weights from PT. However, since weights are
|
||||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
|
# often safetensors now, we don't know if we're going to be crossloading until we sniff the weights file.
|
||||||
encoder_model_type = config.encoder.model_type
|
# Therefore, we specify tf_to_pt_weight_rename anyway, and let the super method figure out if it needs it
|
||||||
|
# or not.
|
||||||
|
|
||||||
def tf_to_pt_weight_rename(tf_weight):
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
|
||||||
if "encoder" in tf_weight and "decoder" not in tf_weight:
|
encoder_model_type = config.encoder.model_type
|
||||||
return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight)
|
|
||||||
else:
|
|
||||||
return tf_weight
|
|
||||||
|
|
||||||
kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename
|
def tf_to_pt_weight_rename(tf_weight):
|
||||||
|
if "encoder" in tf_weight and "decoder" not in tf_weight:
|
||||||
|
return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight)
|
||||||
|
else:
|
||||||
|
return tf_weight
|
||||||
|
|
||||||
|
kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename
|
||||||
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -342,8 +342,8 @@ class VisionEncoderDecoderModel(PreTrainedModel):
|
||||||
model.config = config
|
model.config = config
|
||||||
|
|
||||||
if hasattr(model, "enc_to_dec_proj"):
|
if hasattr(model, "enc_to_dec_proj"):
|
||||||
model.enc_to_dec_proj.weight.data = enc_to_dec_proj_weight
|
model.enc_to_dec_proj.weight.data = enc_to_dec_proj_weight.contiguous()
|
||||||
model.enc_to_dec_proj.bias.data = enc_to_dec_proj_bias
|
model.enc_to_dec_proj.bias.data = enc_to_dec_proj_bias.contiguous()
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
|
@ -836,7 +836,7 @@ class Pipeline(_ScikitCompat):
|
||||||
# then we should keep working
|
# then we should keep working
|
||||||
self.image_processor = self.feature_extractor
|
self.image_processor = self.feature_extractor
|
||||||
|
|
||||||
def save_pretrained(self, save_directory: str, safe_serialization: bool = False):
|
def save_pretrained(self, save_directory: str, safe_serialization: bool = True):
|
||||||
"""
|
"""
|
||||||
Save the pipeline's model and tokenizer.
|
Save the pipeline's model and tokenizer.
|
||||||
|
|
||||||
|
@ -844,7 +844,7 @@ class Pipeline(_ScikitCompat):
|
||||||
save_directory (`str`):
|
save_directory (`str`):
|
||||||
A path to the directory where to saved. It will be created if it doesn't exist.
|
A path to the directory where to saved. It will be created if it doesn't exist.
|
||||||
safe_serialization (`str`):
|
safe_serialization (`str`):
|
||||||
Whether to save the model using `safetensors` or the traditional way for PyTorch or Tensorflow
|
Whether to save the model using `safetensors` or the traditional way for PyTorch or Tensorflow.
|
||||||
"""
|
"""
|
||||||
if os.path.isfile(save_directory):
|
if os.path.isfile(save_directory):
|
||||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||||
|
|
|
@ -293,7 +293,7 @@ class TrainingArguments:
|
||||||
`save_total_limit=5` and `load_best_model_at_end`, the four last checkpoints will always be retained
|
`save_total_limit=5` and `load_best_model_at_end`, the four last checkpoints will always be retained
|
||||||
alongside the best model. When `save_total_limit=1` and `load_best_model_at_end`, it is possible that two
|
alongside the best model. When `save_total_limit=1` and `load_best_model_at_end`, it is possible that two
|
||||||
checkpoints are saved: the last one and the best one (if they are different).
|
checkpoints are saved: the last one and the best one (if they are different).
|
||||||
save_safetensors (`bool`, *optional*, defaults to `False`):
|
save_safetensors (`bool`, *optional*, defaults to `True`):
|
||||||
Use [safetensors](https://huggingface.co/docs/safetensors) saving and loading for state dicts instead of
|
Use [safetensors](https://huggingface.co/docs/safetensors) saving and loading for state dicts instead of
|
||||||
default `torch.load` and `torch.save`.
|
default `torch.load` and `torch.save`.
|
||||||
save_on_each_node (`bool`, *optional*, defaults to `False`):
|
save_on_each_node (`bool`, *optional*, defaults to `False`):
|
||||||
|
@ -797,7 +797,7 @@ class TrainingArguments:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
save_safetensors: Optional[bool] = field(
|
save_safetensors: Optional[bool] = field(
|
||||||
default=False,
|
default=True,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Use safetensors saving and loading for state dicts instead of default torch.load and torch.save."
|
"help": "Use safetensors saving and loading for state dicts instead of default torch.load and torch.save."
|
||||||
},
|
},
|
||||||
|
|
|
@ -797,7 +797,7 @@ class PushToHubMixin:
|
||||||
token: Optional[Union[bool, str]] = None,
|
token: Optional[Union[bool, str]] = None,
|
||||||
max_shard_size: Optional[Union[int, str]] = "5GB",
|
max_shard_size: Optional[Union[int, str]] = "5GB",
|
||||||
create_pr: bool = False,
|
create_pr: bool = False,
|
||||||
safe_serialization: bool = False,
|
safe_serialization: bool = True,
|
||||||
revision: str = None,
|
revision: str = None,
|
||||||
commit_description: str = None,
|
commit_description: str = None,
|
||||||
**deprecated_kwargs,
|
**deprecated_kwargs,
|
||||||
|
@ -827,7 +827,7 @@ class PushToHubMixin:
|
||||||
Google Colab instances without any CPU OOM issues.
|
Google Colab instances without any CPU OOM issues.
|
||||||
create_pr (`bool`, *optional*, defaults to `False`):
|
create_pr (`bool`, *optional*, defaults to `False`):
|
||||||
Whether or not to create a PR with the uploaded files or directly commit.
|
Whether or not to create a PR with the uploaded files or directly commit.
|
||||||
safe_serialization (`bool`, *optional*, defaults to `False`):
|
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||||
Whether or not to convert the model weights in safetensors format for safer serialization.
|
Whether or not to convert the model weights in safetensors format for safer serialization.
|
||||||
revision (`str`, *optional*):
|
revision (`str`, *optional*):
|
||||||
Branch to push the uploaded files to.
|
Branch to push the uploaded files to.
|
||||||
|
|
|
@ -211,6 +211,8 @@ class TFAutoModelTest(unittest.TestCase):
|
||||||
config = copy.deepcopy(model.config)
|
config = copy.deepcopy(model.config)
|
||||||
config.architectures = ["FunnelBaseModel"]
|
config.architectures = ["FunnelBaseModel"]
|
||||||
model = TFAutoModel.from_config(config)
|
model = TFAutoModel.from_config(config)
|
||||||
|
model.build()
|
||||||
|
|
||||||
self.assertIsInstance(model, TFFunnelBaseModel)
|
self.assertIsInstance(model, TFFunnelBaseModel)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
@ -245,7 +247,10 @@ class TFAutoModelTest(unittest.TestCase):
|
||||||
# Now that the config is registered, it can be used as any other config with the auto-API
|
# Now that the config is registered, it can be used as any other config with the auto-API
|
||||||
tiny_config = BertModelTester(self).get_config()
|
tiny_config = BertModelTester(self).get_config()
|
||||||
config = NewModelConfig(**tiny_config.to_dict())
|
config = NewModelConfig(**tiny_config.to_dict())
|
||||||
|
|
||||||
model = auto_class.from_config(config)
|
model = auto_class.from_config(config)
|
||||||
|
model.build()
|
||||||
|
|
||||||
self.assertIsInstance(model, TFNewModel)
|
self.assertIsInstance(model, TFNewModel)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
|
|
@ -525,7 +525,7 @@ class TFEncoderDecoderMixin:
|
||||||
# PT -> TF
|
# PT -> TF
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
pt_model.save_pretrained(tmpdirname)
|
pt_model.save_pretrained(tmpdirname)
|
||||||
tf_model = TFEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
|
tf_model = TFEncoderDecoderModel.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
|
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
|
||||||
|
|
||||||
|
@ -542,7 +542,7 @@ class TFEncoderDecoderMixin:
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
pt_model.save_pretrained(tmpdirname)
|
pt_model.save_pretrained(tmpdirname)
|
||||||
tf_model = TFEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
|
tf_model = TFEncoderDecoderModel.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
|
self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
|
||||||
|
|
||||||
|
@ -560,7 +560,8 @@ class TFEncoderDecoderMixin:
|
||||||
tf_model(**tf_inputs_dict)
|
tf_model(**tf_inputs_dict)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
tf_model.save_pretrained(tmpdirname)
|
# TODO Matt: PT doesn't support loading TF safetensors - remove the arg and from_tf=True when it does
|
||||||
|
tf_model.save_pretrained(tmpdirname, safe_serialization=False)
|
||||||
pt_model = EncoderDecoderModel.from_pretrained(tmpdirname, from_tf=True)
|
pt_model = EncoderDecoderModel.from_pretrained(tmpdirname, from_tf=True)
|
||||||
|
|
||||||
self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
|
self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
|
||||||
|
@ -1129,9 +1130,7 @@ class TFEncoderDecoderModelSaveLoadTests(unittest.TestCase):
|
||||||
with tempfile.TemporaryDirectory() as tmp_dirname_1, tempfile.TemporaryDirectory() as tmp_dirname_2:
|
with tempfile.TemporaryDirectory() as tmp_dirname_1, tempfile.TemporaryDirectory() as tmp_dirname_2:
|
||||||
encoder_decoder_pt.encoder.save_pretrained(tmp_dirname_1)
|
encoder_decoder_pt.encoder.save_pretrained(tmp_dirname_1)
|
||||||
encoder_decoder_pt.decoder.save_pretrained(tmp_dirname_2)
|
encoder_decoder_pt.decoder.save_pretrained(tmp_dirname_2)
|
||||||
encoder_decoder_tf = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
|
encoder_decoder_tf = TFEncoderDecoderModel.from_encoder_decoder_pretrained(tmp_dirname_1, tmp_dirname_2)
|
||||||
tmp_dirname_1, tmp_dirname_2, encoder_from_pt=True, decoder_from_pt=True
|
|
||||||
)
|
|
||||||
|
|
||||||
logits_tf = encoder_decoder_tf(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
|
logits_tf = encoder_decoder_tf(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
|
||||||
|
|
||||||
|
@ -1150,7 +1149,7 @@ class TFEncoderDecoderModelSaveLoadTests(unittest.TestCase):
|
||||||
|
|
||||||
# TensorFlow => PyTorch
|
# TensorFlow => PyTorch
|
||||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||||
encoder_decoder_tf.save_pretrained(tmp_dirname)
|
encoder_decoder_tf.save_pretrained(tmp_dirname, safe_serialization=False)
|
||||||
encoder_decoder_pt = EncoderDecoderModel.from_pretrained(tmp_dirname, from_tf=True)
|
encoder_decoder_pt = EncoderDecoderModel.from_pretrained(tmp_dirname, from_tf=True)
|
||||||
|
|
||||||
max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy()))
|
max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy()))
|
||||||
|
|
|
@ -458,7 +458,7 @@ class TFVisionEncoderDecoderMixin:
|
||||||
# PT -> TF
|
# PT -> TF
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
pt_model.save_pretrained(tmpdirname)
|
pt_model.save_pretrained(tmpdirname)
|
||||||
tf_model = TFVisionEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
|
tf_model = TFVisionEncoderDecoderModel.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
|
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
|
||||||
|
|
||||||
|
@ -473,7 +473,7 @@ class TFVisionEncoderDecoderMixin:
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
pt_model.save_pretrained(tmpdirname)
|
pt_model.save_pretrained(tmpdirname)
|
||||||
tf_model = TFVisionEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
|
tf_model = TFVisionEncoderDecoderModel.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
|
self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
|
||||||
|
|
||||||
|
@ -489,7 +489,7 @@ class TFVisionEncoderDecoderMixin:
|
||||||
tf_model(**tf_inputs_dict)
|
tf_model(**tf_inputs_dict)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
tf_model.save_pretrained(tmpdirname)
|
tf_model.save_pretrained(tmpdirname, safe_serialization=False)
|
||||||
pt_model = VisionEncoderDecoderModel.from_pretrained(tmpdirname, from_tf=True)
|
pt_model = VisionEncoderDecoderModel.from_pretrained(tmpdirname, from_tf=True)
|
||||||
|
|
||||||
self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
|
self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)
|
||||||
|
@ -803,7 +803,7 @@ class TFVisionEncoderDecoderModelSaveLoadTests(unittest.TestCase):
|
||||||
encoder_decoder_pt.encoder.save_pretrained(tmp_dirname_1)
|
encoder_decoder_pt.encoder.save_pretrained(tmp_dirname_1)
|
||||||
encoder_decoder_pt.decoder.save_pretrained(tmp_dirname_2)
|
encoder_decoder_pt.decoder.save_pretrained(tmp_dirname_2)
|
||||||
encoder_decoder_tf = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
|
encoder_decoder_tf = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||||
tmp_dirname_1, tmp_dirname_2, encoder_from_pt=True, decoder_from_pt=True
|
tmp_dirname_1, tmp_dirname_2
|
||||||
)
|
)
|
||||||
|
|
||||||
logits_tf = encoder_decoder_tf(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids).logits
|
logits_tf = encoder_decoder_tf(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids).logits
|
||||||
|
@ -814,7 +814,7 @@ class TFVisionEncoderDecoderModelSaveLoadTests(unittest.TestCase):
|
||||||
# Make sure `from_pretrained` following `save_pretrained` work and give the same result
|
# Make sure `from_pretrained` following `save_pretrained` work and give the same result
|
||||||
# (See https://github.com/huggingface/transformers/pull/14016)
|
# (See https://github.com/huggingface/transformers/pull/14016)
|
||||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||||
encoder_decoder_tf.save_pretrained(tmp_dirname)
|
encoder_decoder_tf.save_pretrained(tmp_dirname, safe_serialization=False)
|
||||||
encoder_decoder_tf = TFVisionEncoderDecoderModel.from_pretrained(tmp_dirname)
|
encoder_decoder_tf = TFVisionEncoderDecoderModel.from_pretrained(tmp_dirname)
|
||||||
|
|
||||||
logits_tf_2 = encoder_decoder_tf(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids).logits
|
logits_tf_2 = encoder_decoder_tf(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids).logits
|
||||||
|
|
|
@ -91,6 +91,7 @@ if is_accelerate_available():
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
from safetensors.torch import save_file as safe_save_file
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from transformers import MODEL_MAPPING, AdaptiveEmbedding
|
from transformers import MODEL_MAPPING, AdaptiveEmbedding
|
||||||
|
@ -1751,8 +1752,8 @@ class ModelTesterMixin:
|
||||||
|
|
||||||
# We are nuking ALL weights on file, so every parameter should
|
# We are nuking ALL weights on file, so every parameter should
|
||||||
# yell on load. We're going to detect if we yell too much, or too little.
|
# yell on load. We're going to detect if we yell too much, or too little.
|
||||||
with open(os.path.join(tmp_dir, "pytorch_model.bin"), "wb") as f:
|
placeholder_dict = {"tensor": torch.tensor([1, 2])}
|
||||||
torch.save({}, f)
|
safe_save_file(placeholder_dict, os.path.join(tmp_dir, "model.safetensors"), metadata={"format": "pt"})
|
||||||
model_reloaded, infos = model_class.from_pretrained(tmp_dir, output_loading_info=True)
|
model_reloaded, infos = model_class.from_pretrained(tmp_dir, output_loading_info=True)
|
||||||
|
|
||||||
prefix = f"{model_reloaded.base_model_prefix}."
|
prefix = f"{model_reloaded.base_model_prefix}."
|
||||||
|
|
|
@ -16,11 +16,12 @@ import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from huggingface_hub import HfFolder, delete_repo
|
from huggingface_hub import HfFolder, delete_repo, snapshot_download
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
|
|
||||||
from transformers import BertConfig, is_flax_available
|
from transformers import BertConfig, BertModel, is_flax_available
|
||||||
from transformers.testing_utils import TOKEN, USER, is_staging_test, require_flax
|
from transformers.testing_utils import TOKEN, USER, is_staging_test, require_flax, require_safetensors, require_torch
|
||||||
|
from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
||||||
|
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
|
@ -184,3 +185,88 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
||||||
model = FlaxBertModel.from_pretrained(model_id, subfolder=subfolder)
|
model = FlaxBertModel.from_pretrained(model_id, subfolder=subfolder)
|
||||||
|
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_save_and_load(self):
|
||||||
|
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||||
|
|
||||||
|
# No msgpack file, only a model.safetensors
|
||||||
|
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
||||||
|
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, FLAX_WEIGHTS_NAME)))
|
||||||
|
|
||||||
|
new_model = FlaxBertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
self.assertTrue(check_models_equal(model, new_model))
|
||||||
|
|
||||||
|
@require_flax
|
||||||
|
@require_torch
|
||||||
|
def test_safetensors_save_and_load_pt_to_flax(self):
|
||||||
|
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True)
|
||||||
|
pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
pt_model.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Check we have a model.safetensors file
|
||||||
|
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
||||||
|
|
||||||
|
new_model = FlaxBertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Check models are equal
|
||||||
|
self.assertTrue(check_models_equal(model, new_model))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_load_from_hub(self):
|
||||||
|
flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||||
|
|
||||||
|
# Can load from the Flax-formatted checkpoint
|
||||||
|
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-only")
|
||||||
|
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_load_from_hub_flax_and_pt(self):
|
||||||
|
flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||||
|
|
||||||
|
# Can load from the PyTorch-formatted checkpoint
|
||||||
|
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only", from_pt=True)
|
||||||
|
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_flax_from_flax(self):
|
||||||
|
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||||
|
new_model = FlaxBertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
self.assertTrue(check_models_equal(model, new_model))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
@require_torch
|
||||||
|
def test_safetensors_flax_from_torch(self):
|
||||||
|
hub_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||||
|
new_model = FlaxBertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
self.assertTrue(check_models_equal(hub_model, new_model))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_flax_from_sharded_msgpack_with_sharded_safetensors_local(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
path = snapshot_download(
|
||||||
|
"hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded", cache_dir=tmp_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
# This should not raise even if there are two types of sharded weights
|
||||||
|
FlaxBertModel.from_pretrained(path)
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_flax_from_sharded_msgpack_with_sharded_safetensors_hub(self):
|
||||||
|
# This should not raise even if there are two types of sharded weights
|
||||||
|
# This should discard the safetensors weights in favor of the msgpack sharded weights
|
||||||
|
FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded")
|
||||||
|
|
|
@ -24,7 +24,7 @@ import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
import unittest.mock as mock
|
import unittest.mock as mock
|
||||||
|
|
||||||
from huggingface_hub import HfFolder, Repository, delete_repo
|
from huggingface_hub import HfFolder, Repository, delete_repo, snapshot_download
|
||||||
from huggingface_hub.file_download import http_get
|
from huggingface_hub.file_download import http_get
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
|
|
||||||
|
@ -39,6 +39,7 @@ from transformers.testing_utils import ( # noqa: F401
|
||||||
is_staging_test,
|
is_staging_test,
|
||||||
require_safetensors,
|
require_safetensors,
|
||||||
require_tf,
|
require_tf,
|
||||||
|
require_torch,
|
||||||
slow,
|
slow,
|
||||||
)
|
)
|
||||||
from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
|
from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
|
||||||
|
@ -496,6 +497,44 @@ class TFModelUtilsTest(unittest.TestCase):
|
||||||
for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
|
for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
|
||||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_tf_from_tf(self):
|
||||||
|
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||||
|
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
for p1, p2 in zip(model.weights, new_model.weights):
|
||||||
|
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
@is_pt_tf_cross_test
|
||||||
|
def test_safetensors_tf_from_torch(self):
|
||||||
|
hub_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||||
|
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
for p1, p2 in zip(hub_model.weights, new_model.weights):
|
||||||
|
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_tf_from_sharded_h5_with_sharded_safetensors_local(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
path = snapshot_download("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", cache_dir=tmp_dir)
|
||||||
|
|
||||||
|
# This should not raise even if there are two types of sharded weights
|
||||||
|
TFBertModel.from_pretrained(path)
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_tf_from_sharded_h5_with_sharded_safetensors_hub(self):
|
||||||
|
# This should not raise even if there are two types of sharded weights
|
||||||
|
# This should discard the safetensors weights in favor of the .h5 sharded weights
|
||||||
|
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded")
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import copy
|
||||||
import glob
|
import glob
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
@ -42,7 +42,9 @@ from transformers.testing_utils import (
|
||||||
TestCasePlus,
|
TestCasePlus,
|
||||||
is_staging_test,
|
is_staging_test,
|
||||||
require_accelerate,
|
require_accelerate,
|
||||||
|
require_flax,
|
||||||
require_safetensors,
|
require_safetensors,
|
||||||
|
require_tf,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_accelerator,
|
require_torch_accelerator,
|
||||||
require_torch_multi_accelerator,
|
require_torch_multi_accelerator,
|
||||||
|
@ -56,7 +58,7 @@ from transformers.utils import (
|
||||||
WEIGHTS_INDEX_NAME,
|
WEIGHTS_INDEX_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
)
|
)
|
||||||
from transformers.utils.import_utils import is_torchdynamo_available
|
from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torchdynamo_available
|
||||||
|
|
||||||
|
|
||||||
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
||||||
|
@ -66,6 +68,7 @@ from test_module.custom_configuration import CustomConfig, NoSuperInitConfig #
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
from safetensors.torch import save_file as safe_save_file
|
||||||
from test_module.custom_modeling import CustomModel, NoSuperInitModel
|
from test_module.custom_modeling import CustomModel, NoSuperInitModel
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
@ -146,6 +149,13 @@ if is_torch_available():
|
||||||
self.decoder.weight = self.base.linear.weight
|
self.decoder.weight = self.base.linear.weight
|
||||||
|
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
from transformers import FlaxBertModel
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
from transformers import TFBertModel
|
||||||
|
|
||||||
|
|
||||||
TINY_T5 = "patrickvonplaten/t5-tiny-random"
|
TINY_T5 = "patrickvonplaten/t5-tiny-random"
|
||||||
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
|
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
|
||||||
|
|
||||||
|
@ -420,13 +430,13 @@ class ModelUtilsTest(TestCasePlus):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_checkpoint_sharding_local(self):
|
def test_checkpoint_sharding_local_bin(self):
|
||||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
|
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
|
||||||
for max_size in ["50kB", "50kiB", "100kB", "100kiB", "200kB", "200kiB"]:
|
for max_size in ["50kB", "50kiB", "100kB", "100kiB", "200kB", "200kiB"]:
|
||||||
model.save_pretrained(tmp_dir, max_shard_size=max_size)
|
model.save_pretrained(tmp_dir, max_shard_size=max_size, safe_serialization=False)
|
||||||
|
|
||||||
# Get each shard file and its size
|
# Get each shard file and its size
|
||||||
shard_to_size = {}
|
shard_to_size = {}
|
||||||
|
@ -472,11 +482,11 @@ class ModelUtilsTest(TestCasePlus):
|
||||||
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
|
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
|
||||||
self.assertTrue(torch.allclose(p1, p2))
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
def test_checkpoint_variant_local(self):
|
def test_checkpoint_variant_local_bin(self):
|
||||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
model.save_pretrained(tmp_dir, variant="v2")
|
model.save_pretrained(tmp_dir, variant="v2", safe_serialization=False)
|
||||||
|
|
||||||
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
|
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
|
||||||
|
|
||||||
|
@ -492,11 +502,11 @@ class ModelUtilsTest(TestCasePlus):
|
||||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
self.assertTrue(torch.allclose(p1, p2))
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
def test_checkpoint_variant_local_sharded(self):
|
def test_checkpoint_variant_local_sharded_bin(self):
|
||||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
model.save_pretrained(tmp_dir, variant="v2", max_shard_size="50kB")
|
model.save_pretrained(tmp_dir, variant="v2", max_shard_size="50kB", safe_serialization=False)
|
||||||
|
|
||||||
weights_index_name = ".".join(WEIGHTS_INDEX_NAME.split(".")[:-1] + ["v2"] + ["json"])
|
weights_index_name = ".".join(WEIGHTS_INDEX_NAME.split(".")[:-1] + ["v2"] + ["json"])
|
||||||
weights_index_file = os.path.join(tmp_dir, weights_index_name)
|
weights_index_file = os.path.join(tmp_dir, weights_index_name)
|
||||||
|
@ -604,18 +614,18 @@ class ModelUtilsTest(TestCasePlus):
|
||||||
)
|
)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
def test_checkpoint_variant_save_load(self):
|
def test_checkpoint_variant_save_load_bin(self):
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
model = BertModel.from_pretrained(
|
model = BertModel.from_pretrained(
|
||||||
"hf-internal-testing/tiny-random-bert-variant", cache_dir=tmp_dir, variant="v2"
|
"hf-internal-testing/tiny-random-bert-variant", cache_dir=tmp_dir, variant="v2"
|
||||||
)
|
)
|
||||||
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
|
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
|
||||||
|
|
||||||
model.save_pretrained(tmp_dir, variant="v2")
|
model.save_pretrained(tmp_dir, variant="v2", safe_serialization=False)
|
||||||
# saving will create a variant checkpoint
|
# saving will create a variant checkpoint
|
||||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name)))
|
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name)))
|
||||||
|
|
||||||
model.save_pretrained(tmp_dir)
|
model.save_pretrained(tmp_dir, safe_serialization=False)
|
||||||
# saving shouldn't delete variant checkpoints
|
# saving shouldn't delete variant checkpoints
|
||||||
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
|
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
|
||||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name)))
|
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name)))
|
||||||
|
@ -874,7 +884,7 @@ class ModelUtilsTest(TestCasePlus):
|
||||||
def test_base_model_to_head_model_load(self):
|
def test_base_model_to_head_model_load(self):
|
||||||
base_model = BaseModel(PretrainedConfig())
|
base_model = BaseModel(PretrainedConfig())
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
base_model.save_pretrained(tmp_dir)
|
base_model.save_pretrained(tmp_dir, safe_serialization=False)
|
||||||
|
|
||||||
# Can load a base model in a model with head
|
# Can load a base model in a model with head
|
||||||
model = ModelWithHead.from_pretrained(tmp_dir)
|
model = ModelWithHead.from_pretrained(tmp_dir)
|
||||||
|
@ -886,7 +896,7 @@ class ModelUtilsTest(TestCasePlus):
|
||||||
head_state_dict = model.state_dict()
|
head_state_dict = model.state_dict()
|
||||||
base_state_dict["linear2.weight"] = head_state_dict["linear2.weight"]
|
base_state_dict["linear2.weight"] = head_state_dict["linear2.weight"]
|
||||||
base_state_dict["linear2.bias"] = head_state_dict["linear2.bias"]
|
base_state_dict["linear2.bias"] = head_state_dict["linear2.bias"]
|
||||||
torch.save(base_state_dict, os.path.join(tmp_dir, WEIGHTS_NAME))
|
safe_save_file(base_state_dict, os.path.join(tmp_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, "The state dictionary of the model you are trying to load is corrupted."
|
ValueError, "The state dictionary of the model you are trying to load is corrupted."
|
||||||
|
@ -934,8 +944,8 @@ class ModelUtilsTest(TestCasePlus):
|
||||||
|
|
||||||
# Loading the model with the same class, we do get a warning for unexpected weights
|
# Loading the model with the same class, we do get a warning for unexpected weights
|
||||||
state_dict = model.state_dict()
|
state_dict = model.state_dict()
|
||||||
state_dict["added_key"] = state_dict["linear.weight"]
|
state_dict["added_key"] = copy.deepcopy(state_dict["linear.weight"])
|
||||||
torch.save(state_dict, os.path.join(tmp_dir, WEIGHTS_NAME))
|
safe_save_file(state_dict, os.path.join(tmp_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
|
||||||
with CaptureLogger(logger) as cl:
|
with CaptureLogger(logger) as cl:
|
||||||
_, loading_info = ModelWithHead.from_pretrained(tmp_dir, output_loading_info=True)
|
_, loading_info = ModelWithHead.from_pretrained(tmp_dir, output_loading_info=True)
|
||||||
self.assertIn("were not used when initializing ModelWithHead: ['added_key']", cl.out)
|
self.assertIn("were not used when initializing ModelWithHead: ['added_key']", cl.out)
|
||||||
|
@ -1072,6 +1082,54 @@ class ModelUtilsTest(TestCasePlus):
|
||||||
)
|
)
|
||||||
self.assertEqual(model.generation_config.transformers_version, "foo")
|
self.assertEqual(model.generation_config.transformers_version, "foo")
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_torch_from_torch(self):
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||||
|
new_model = BertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
@require_flax
|
||||||
|
def test_safetensors_torch_from_flax(self):
|
||||||
|
hub_model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
||||||
|
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||||
|
new_model = BertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
for p1, p2 in zip(hub_model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_torch_from_tf(self):
|
||||||
|
hub_model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
||||||
|
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||||
|
new_model = BertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
for p1, p2 in zip(hub_model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_torch_from_torch_sharded(self):
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="100kB")
|
||||||
|
new_model = BertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
|
|
|
@ -403,7 +403,7 @@ if is_torch_available():
|
||||||
|
|
||||||
|
|
||||||
class TrainerIntegrationCommon:
|
class TrainerIntegrationCommon:
|
||||||
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, safe_weights=False):
|
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, safe_weights=True):
|
||||||
weights_file = WEIGHTS_NAME if not safe_weights else SAFE_WEIGHTS_NAME
|
weights_file = WEIGHTS_NAME if not safe_weights else SAFE_WEIGHTS_NAME
|
||||||
file_list = [weights_file, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
|
file_list = [weights_file, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
|
||||||
if is_pretrained:
|
if is_pretrained:
|
||||||
|
@ -415,7 +415,7 @@ class TrainerIntegrationCommon:
|
||||||
self.assertTrue(os.path.isfile(os.path.join(checkpoint, filename)))
|
self.assertTrue(os.path.isfile(os.path.join(checkpoint, filename)))
|
||||||
|
|
||||||
def check_best_model_has_been_loaded(
|
def check_best_model_has_been_loaded(
|
||||||
self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True, safe_weights=False
|
self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True, safe_weights=True
|
||||||
):
|
):
|
||||||
checkpoint = os.path.join(output_dir, f"checkpoint-{(total // freq) * freq}")
|
checkpoint = os.path.join(output_dir, f"checkpoint-{(total // freq) * freq}")
|
||||||
log_history = TrainerState.load_from_json(os.path.join(checkpoint, "trainer_state.json")).log_history
|
log_history = TrainerState.load_from_json(os.path.join(checkpoint, "trainer_state.json")).log_history
|
||||||
|
@ -456,7 +456,7 @@ class TrainerIntegrationCommon:
|
||||||
_ = log1.pop(key, None)
|
_ = log1.pop(key, None)
|
||||||
self.assertEqual(log, log1)
|
self.assertEqual(log, log1)
|
||||||
|
|
||||||
def convert_to_sharded_checkpoint(self, folder, save_safe=False, load_safe=False):
|
def convert_to_sharded_checkpoint(self, folder, save_safe=True, load_safe=True):
|
||||||
# Converts a checkpoint of a regression model to a sharded checkpoint.
|
# Converts a checkpoint of a regression model to a sharded checkpoint.
|
||||||
if load_safe:
|
if load_safe:
|
||||||
loader = safetensors.torch.load_file
|
loader = safetensors.torch.load_file
|
||||||
|
|
|
@ -43,7 +43,6 @@ class CLITest(unittest.TestCase):
|
||||||
shutil.rmtree("/tmp/hf-internal-testing/tiny-random-gptj", ignore_errors=True) # cleans potential past runs
|
shutil.rmtree("/tmp/hf-internal-testing/tiny-random-gptj", ignore_errors=True) # cleans potential past runs
|
||||||
transformers.commands.transformers_cli.main()
|
transformers.commands.transformers_cli.main()
|
||||||
|
|
||||||
# The original repo has no TF weights -- if they exist, they were created by the CLI
|
|
||||||
self.assertTrue(os.path.exists("/tmp/hf-internal-testing/tiny-random-gptj/tf_model.h5"))
|
self.assertTrue(os.path.exists("/tmp/hf-internal-testing/tiny-random-gptj/tf_model.h5"))
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|
Loading…
Reference in New Issue