Use new huggingface_hub tools for downloading models (#18438)
* Draft new cached_file * Initial draft for config and model * Small fixes * Fix first batch of tests * Look in cache when internet is down * Fix last tests * Bad black, not fixing all quality errors * Make diff less * Implement change for TF and Flax models * Add tokenizer and feature extractor * For compatibility with main * Add utils to move the cache and auto-do it at first use. * Quality * Deal with empty commit shas * Deal with empty etag * Address review comments
This commit is contained in:
parent
70fa1a8d26
commit
5cd4032368
|
@ -25,25 +25,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||
|
||||
from packaging import version
|
||||
|
||||
from requests import HTTPError
|
||||
|
||||
from . import __version__
|
||||
from .dynamic_module_utils import custom_object_save
|
||||
from .utils import (
|
||||
CONFIG_NAME,
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||
EntryNotFoundError,
|
||||
PushToHubMixin,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
cached_path,
|
||||
copy_func,
|
||||
hf_bucket_url,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
from .utils import CONFIG_NAME, PushToHubMixin, cached_file, copy_func, is_torch_available, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
@ -591,77 +575,43 @@ class PretrainedConfig(PushToHubMixin):
|
|||
if from_pipeline is not None:
|
||||
user_agent["using_pipeline"] = from_pipeline
|
||||
|
||||
if is_offline_mode() and not local_files_only:
|
||||
logger.info("Offline mode: forcing local_files_only=True")
|
||||
local_files_only = True
|
||||
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)) or is_remote_url(
|
||||
pretrained_model_name_or_path
|
||||
):
|
||||
config_file = pretrained_model_name_or_path
|
||||
|
||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||
if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
|
||||
# Special case when pretrained_model_name_or_path is a local file
|
||||
resolved_config_file = pretrained_model_name_or_path
|
||||
is_local = True
|
||||
else:
|
||||
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
|
||||
|
||||
if os.path.isdir(os.path.join(pretrained_model_name_or_path, subfolder)):
|
||||
config_file = os.path.join(pretrained_model_name_or_path, subfolder, configuration_file)
|
||||
else:
|
||||
config_file = hf_bucket_url(
|
||||
try:
|
||||
# Load from local folder or from cache or download from model Hub and cache
|
||||
resolved_config_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
filename=configuration_file,
|
||||
configuration_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
subfolder=subfolder if len(subfolder) > 0 else None,
|
||||
mirror=None,
|
||||
subfolder=subfolder,
|
||||
)
|
||||
except EnvironmentError:
|
||||
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
|
||||
# the original exception.
|
||||
raise
|
||||
except Exception:
|
||||
# For any other exception, we throw a generic error.
|
||||
raise EnvironmentError(
|
||||
f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
|
||||
" from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
|
||||
f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
|
||||
f" containing a {configuration_file} file"
|
||||
)
|
||||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_config_file = cached_path(
|
||||
config_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
|
||||
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
|
||||
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
|
||||
"`use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
|
||||
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
|
||||
"available revisions."
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}."
|
||||
)
|
||||
except HTTPError as err:
|
||||
raise EnvironmentError(
|
||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
|
||||
)
|
||||
except ValueError:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
|
||||
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
|
||||
f" containing a {configuration_file} file.\nCheckout your internet connection or see how to run the"
|
||||
" library in offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError:
|
||||
raise EnvironmentError(
|
||||
f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a {configuration_file} file"
|
||||
)
|
||||
|
||||
try:
|
||||
# Load config dict
|
||||
|
@ -671,10 +621,10 @@ class PretrainedConfig(PushToHubMixin):
|
|||
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
|
||||
)
|
||||
|
||||
if resolved_config_file == config_file:
|
||||
logger.info(f"loading configuration file {config_file}")
|
||||
if is_local:
|
||||
logger.info(f"loading configuration file {resolved_config_file}")
|
||||
else:
|
||||
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
|
||||
logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
|
||||
|
||||
return config_dict, kwargs
|
||||
|
||||
|
|
|
@ -24,23 +24,15 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
|||
|
||||
import numpy as np
|
||||
|
||||
from requests import HTTPError
|
||||
|
||||
from .dynamic_module_utils import custom_object_save
|
||||
from .utils import (
|
||||
FEATURE_EXTRACTOR_NAME,
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||
EntryNotFoundError,
|
||||
PushToHubMixin,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
TensorType,
|
||||
cached_path,
|
||||
cached_file,
|
||||
copy_func,
|
||||
hf_bucket_url,
|
||||
is_flax_available,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
logging,
|
||||
|
@ -388,64 +380,40 @@ class FeatureExtractionMixin(PushToHubMixin):
|
|||
local_files_only = True
|
||||
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
|
||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
feature_extractor_file = pretrained_model_name_or_path
|
||||
if os.path.isfile(pretrained_model_name_or_path):
|
||||
resolved_feature_extractor_file = pretrained_model_name_or_path
|
||||
is_local = True
|
||||
else:
|
||||
feature_extractor_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path, filename=FEATURE_EXTRACTOR_NAME, revision=revision, mirror=None
|
||||
)
|
||||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_feature_extractor_file = cached_path(
|
||||
feature_extractor_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
|
||||
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
|
||||
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
|
||||
"`use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
|
||||
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
|
||||
"available revisions."
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {FEATURE_EXTRACTOR_NAME}."
|
||||
)
|
||||
except HTTPError as err:
|
||||
raise EnvironmentError(
|
||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
|
||||
)
|
||||
except ValueError:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
|
||||
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
|
||||
f" containing a {FEATURE_EXTRACTOR_NAME} file.\nCheckout your internet connection or see how to run"
|
||||
" the library in offline mode at"
|
||||
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError:
|
||||
raise EnvironmentError(
|
||||
f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load it "
|
||||
"from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a {FEATURE_EXTRACTOR_NAME} file"
|
||||
)
|
||||
feature_extractor_file = FEATURE_EXTRACTOR_NAME
|
||||
try:
|
||||
# Load from local folder or from cache or download from model Hub and cache
|
||||
resolved_feature_extractor_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
feature_extractor_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
)
|
||||
except EnvironmentError:
|
||||
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
|
||||
# the original exception.
|
||||
raise
|
||||
except Exception:
|
||||
# For any other exception, we throw a generic error.
|
||||
raise EnvironmentError(
|
||||
f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load"
|
||||
" it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
|
||||
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
|
||||
f" directory containing a {FEATURE_EXTRACTOR_NAME} file"
|
||||
)
|
||||
|
||||
try:
|
||||
# Load feature_extractor dict
|
||||
|
@ -458,12 +426,11 @@ class FeatureExtractionMixin(PushToHubMixin):
|
|||
f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file."
|
||||
)
|
||||
|
||||
if resolved_feature_extractor_file == feature_extractor_file:
|
||||
logger.info(f"loading feature extractor configuration file {feature_extractor_file}")
|
||||
if is_local:
|
||||
logger.info(f"loading configuration file {resolved_feature_extractor_file}")
|
||||
else:
|
||||
logger.info(
|
||||
f"loading feature extractor configuration file {feature_extractor_file} from cache at"
|
||||
f" {resolved_feature_extractor_file}"
|
||||
f"loading configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}"
|
||||
)
|
||||
|
||||
return feature_extractor_dict, kwargs
|
||||
|
|
|
@ -32,7 +32,6 @@ from flax.core.frozen_dict import FrozenDict, unfreeze
|
|||
from flax.serialization import from_bytes, to_bytes
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax.random import PRNGKey
|
||||
from requests import HTTPError
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .dynamic_module_utils import custom_object_save
|
||||
|
@ -41,20 +40,14 @@ from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_d
|
|||
from .utils import (
|
||||
FLAX_WEIGHTS_INDEX_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||
WEIGHTS_NAME,
|
||||
EntryNotFoundError,
|
||||
PushToHubMixin,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
cached_path,
|
||||
cached_file,
|
||||
copy_func,
|
||||
has_file,
|
||||
hf_bucket_url,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
|
@ -557,6 +550,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||
identifier allowed by git.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
|
||||
specify the folder name here.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
|
||||
|
@ -598,6 +594,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||
from_auto_class = kwargs.pop("_from_auto", False)
|
||||
_do_init = kwargs.pop("_do_init", True)
|
||||
subfolder = kwargs.pop("subfolder", "")
|
||||
|
||||
if trust_remote_code is True:
|
||||
logger.warning(
|
||||
|
@ -642,6 +639,8 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||
|
||||
# Load model
|
||||
if pretrained_model_name_or_path is not None:
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||
# Load from a PyTorch checkpoint
|
||||
|
@ -665,65 +664,44 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
|
||||
f"{pretrained_model_name_or_path}."
|
||||
)
|
||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
elif os.path.isfile(pretrained_model_name_or_path):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
is_local = True
|
||||
else:
|
||||
filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
|
||||
archive_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path,
|
||||
filename=filename,
|
||||
revision=revision,
|
||||
)
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
cached_file_kwargs = dict(
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
)
|
||||
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
|
||||
|
||||
# redirect to the cache, if necessary
|
||||
|
||||
try:
|
||||
resolved_archive_file = cached_path(
|
||||
archive_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
||||
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
||||
"login` and pass `use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
||||
"this model name. Check the model page at "
|
||||
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
if filename == FLAX_WEIGHTS_NAME:
|
||||
try:
|
||||
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
|
||||
# result when internet is up, the repo and revision exist, but the file does not.
|
||||
if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME:
|
||||
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
||||
archive_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path,
|
||||
filename=FLAX_WEIGHTS_INDEX_NAME,
|
||||
revision=revision,
|
||||
resolved_archive_file = cached_file(
|
||||
pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME, **cached_file_kwargs
|
||||
)
|
||||
resolved_archive_file = cached_path(
|
||||
archive_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
is_sharded = True
|
||||
except EntryNotFoundError:
|
||||
has_file_kwargs = {"revision": revision, "proxies": proxies, "use_auth_token": use_auth_token}
|
||||
if resolved_archive_file is not None:
|
||||
is_sharded = True
|
||||
if resolved_archive_file is None:
|
||||
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
|
||||
# message.
|
||||
has_file_kwargs = {
|
||||
"revision": revision,
|
||||
"proxies": proxies,
|
||||
"use_auth_token": use_auth_token,
|
||||
}
|
||||
if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||
|
@ -735,35 +713,24 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||
f" {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
||||
)
|
||||
else:
|
||||
except EnvironmentError:
|
||||
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
|
||||
# to the original exception.
|
||||
raise
|
||||
except Exception:
|
||||
# For any other exception, we throw a generic error.
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
|
||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
|
||||
" from 'https://huggingface.co/models', make sure you don't have a local directory with the"
|
||||
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
|
||||
f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
||||
)
|
||||
except HTTPError as err:
|
||||
raise EnvironmentError(
|
||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
|
||||
f"{err}"
|
||||
)
|
||||
except ValueError:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
||||
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
||||
f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your"
|
||||
" internet connection or see how to run the library in offline mode at"
|
||||
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError:
|
||||
raise EnvironmentError(
|
||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
||||
)
|
||||
|
||||
if resolved_archive_file == archive_file:
|
||||
if is_local:
|
||||
logger.info(f"loading weights file {archive_file}")
|
||||
resolved_archive_file = archive_file
|
||||
else:
|
||||
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
|
||||
logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
|
||||
else:
|
||||
resolved_archive_file = None
|
||||
|
||||
|
|
|
@ -37,7 +37,6 @@ from tensorflow.python.keras.saving import hdf5_format
|
|||
|
||||
from huggingface_hub import Repository, list_repo_files
|
||||
from keras.saving.hdf5_format import save_attributes_to_hdf5_group
|
||||
from requests import HTTPError
|
||||
from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
|
||||
|
||||
from . import DataCollatorWithPadding, DefaultDataCollator
|
||||
|
@ -48,22 +47,16 @@ from .generation_tf_utils import TFGenerationMixin
|
|||
from .tf_utils import shape_list
|
||||
from .utils import (
|
||||
DUMMY_INPUTS,
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||
TF2_WEIGHTS_INDEX_NAME,
|
||||
TF2_WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
EntryNotFoundError,
|
||||
ModelOutput,
|
||||
PushToHubMixin,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
cached_path,
|
||||
cached_file,
|
||||
find_labels,
|
||||
has_file,
|
||||
hf_bucket_url,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
logging,
|
||||
requires_backends,
|
||||
working_or_temp_dir,
|
||||
|
@ -2112,6 +2105,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
||||
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
||||
Please refer to the mirror site for more information.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
|
||||
specify the folder name here.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
|
||||
|
@ -2164,6 +2160,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||
load_weight_prefix = kwargs.pop("load_weight_prefix", None)
|
||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||
from_auto_class = kwargs.pop("_from_auto", False)
|
||||
subfolder = kwargs.pop("subfolder", "")
|
||||
|
||||
if trust_remote_code is True:
|
||||
logger.warning(
|
||||
|
@ -2202,9 +2199,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
|
||||
# index of the files.
|
||||
is_sharded = False
|
||||
sharded_metadata = None
|
||||
# Load model
|
||||
if pretrained_model_name_or_path is not None:
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||
# Load from a PyTorch checkpoint in priority if from_pt
|
||||
|
@ -2232,68 +2230,43 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||
f"Error no file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
|
||||
f"{pretrained_model_name_or_path}."
|
||||
)
|
||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
elif os.path.isfile(pretrained_model_name_or_path):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
is_local = True
|
||||
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
||||
archive_file = pretrained_model_name_or_path + ".index"
|
||||
is_local = True
|
||||
else:
|
||||
# set correct filename
|
||||
filename = WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME
|
||||
archive_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path,
|
||||
filename=filename,
|
||||
revision=revision,
|
||||
mirror=mirror,
|
||||
)
|
||||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_archive_file = cached_path(
|
||||
archive_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
cached_file_kwargs = dict(
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
)
|
||||
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
||||
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
||||
"login` and pass `use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
||||
"this model name. Check the model page at "
|
||||
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
if filename == TF2_WEIGHTS_NAME:
|
||||
try:
|
||||
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
|
||||
# result when internet is up, the repo and revision exist, but the file does not.
|
||||
if resolved_archive_file is None and filename == TF2_WEIGHTS_NAME:
|
||||
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
||||
archive_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path,
|
||||
filename=TF2_WEIGHTS_INDEX_NAME,
|
||||
revision=revision,
|
||||
mirror=mirror,
|
||||
resolved_archive_file = cached_file(
|
||||
pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME, **cached_file_kwargs
|
||||
)
|
||||
resolved_archive_file = cached_path(
|
||||
archive_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
is_sharded = True
|
||||
except EntryNotFoundError:
|
||||
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
|
||||
if resolved_archive_file is not None:
|
||||
is_sharded = True
|
||||
if resolved_archive_file is None:
|
||||
# Otherwise, maybe there is a PyTorch or Flax model file. We try those to give a helpful error
|
||||
# message.
|
||||
has_file_kwargs = {
|
||||
"revision": revision,
|
||||
|
@ -2312,42 +2285,32 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||
f" {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
||||
)
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
|
||||
)
|
||||
except HTTPError as err:
|
||||
raise EnvironmentError(
|
||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
|
||||
f"{err}"
|
||||
)
|
||||
except ValueError:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
||||
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
||||
f" directory containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your internet"
|
||||
" connection or see how to run the library in offline mode at"
|
||||
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError:
|
||||
raise EnvironmentError(
|
||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
||||
)
|
||||
|
||||
if resolved_archive_file == archive_file:
|
||||
except EnvironmentError:
|
||||
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
|
||||
# to the original exception.
|
||||
raise
|
||||
except Exception:
|
||||
# For any other exception, we throw a generic error.
|
||||
|
||||
raise EnvironmentError(
|
||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
|
||||
" from 'https://huggingface.co/models', make sure you don't have a local directory with the"
|
||||
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
|
||||
f" directory containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
||||
)
|
||||
if is_local:
|
||||
logger.info(f"loading weights file {archive_file}")
|
||||
resolved_archive_file = archive_file
|
||||
else:
|
||||
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
|
||||
logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
|
||||
else:
|
||||
resolved_archive_file = None
|
||||
|
||||
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
|
||||
if is_sharded:
|
||||
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
|
||||
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
|
||||
resolved_archive_file, _ = get_checkpoint_shard_files(
|
||||
pretrained_model_name_or_path,
|
||||
resolved_archive_file,
|
||||
cache_dir=cache_dir,
|
||||
|
|
|
@ -31,7 +31,6 @@ from packaging import version
|
|||
from torch import Tensor, device, nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from requests import HTTPError
|
||||
from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
|
||||
from transformers.utils.import_utils import is_sagemaker_mp_enabled
|
||||
|
||||
|
@ -51,24 +50,18 @@ from .pytorch_utils import ( # noqa: F401
|
|||
from .utils import (
|
||||
DUMMY_INPUTS,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||
TF2_WEIGHTS_NAME,
|
||||
TF_WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
ContextManagers,
|
||||
EntryNotFoundError,
|
||||
ModelOutput,
|
||||
PushToHubMixin,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
cached_path,
|
||||
cached_file,
|
||||
copy_func,
|
||||
has_file,
|
||||
hf_bucket_url,
|
||||
is_accelerate_available,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
|
@ -1868,7 +1861,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
|
||||
if pretrained_model_name_or_path is not None:
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||
if is_local:
|
||||
if from_tf and os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
|
||||
):
|
||||
|
@ -1911,10 +1905,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
f"Error no file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or "
|
||||
f"{FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
|
||||
)
|
||||
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)) or is_remote_url(
|
||||
pretrained_model_name_or_path
|
||||
):
|
||||
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
is_local = True
|
||||
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")):
|
||||
if not from_tf:
|
||||
raise ValueError(
|
||||
|
@ -1922,6 +1915,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
"from_tf to True to load from this checkpoint."
|
||||
)
|
||||
archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index")
|
||||
is_local = True
|
||||
else:
|
||||
# set correct filename
|
||||
if from_tf:
|
||||
|
@ -1931,63 +1925,32 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
else:
|
||||
filename = WEIGHTS_NAME
|
||||
|
||||
archive_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path,
|
||||
filename=filename,
|
||||
revision=revision,
|
||||
mirror=mirror,
|
||||
subfolder=subfolder if len(subfolder) > 0 else None,
|
||||
)
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
cached_file_kwargs = dict(
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
)
|
||||
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
|
||||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_archive_file = cached_path(
|
||||
archive_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
||||
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
||||
"login` and pass `use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
||||
"this model name. Check the model page at "
|
||||
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
if filename == WEIGHTS_NAME:
|
||||
try:
|
||||
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
|
||||
# result when internet is up, the repo and revision exist, but the file does not.
|
||||
if resolved_archive_file is None and filename == WEIGHTS_NAME:
|
||||
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
||||
archive_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path,
|
||||
filename=WEIGHTS_INDEX_NAME,
|
||||
revision=revision,
|
||||
mirror=mirror,
|
||||
subfolder=subfolder if len(subfolder) > 0 else None,
|
||||
resolved_archive_file = cached_file(
|
||||
pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
|
||||
)
|
||||
resolved_archive_file = cached_path(
|
||||
archive_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
is_sharded = True
|
||||
except EntryNotFoundError:
|
||||
if resolved_archive_file is not None:
|
||||
is_sharded = True
|
||||
if resolved_archive_file is None:
|
||||
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
|
||||
# message.
|
||||
has_file_kwargs = {
|
||||
|
@ -2013,42 +1976,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME},"
|
||||
f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
|
||||
)
|
||||
else:
|
||||
except EnvironmentError:
|
||||
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
|
||||
# to the original exception.
|
||||
raise
|
||||
except Exception:
|
||||
# For any other exception, we throw a generic error.
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
|
||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
|
||||
" from 'https://huggingface.co/models', make sure you don't have a local directory with the"
|
||||
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
|
||||
f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or"
|
||||
f" {FLAX_WEIGHTS_NAME}."
|
||||
)
|
||||
except HTTPError as err:
|
||||
raise EnvironmentError(
|
||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
|
||||
f"{err}"
|
||||
)
|
||||
except ValueError:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
||||
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
||||
f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or"
|
||||
f" {FLAX_WEIGHTS_NAME}.\nCheckout your internet connection or see how to run the library in"
|
||||
" offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError:
|
||||
raise EnvironmentError(
|
||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or "
|
||||
f"{FLAX_WEIGHTS_NAME}."
|
||||
)
|
||||
|
||||
if resolved_archive_file == archive_file:
|
||||
if is_local:
|
||||
logger.info(f"loading weights file {archive_file}")
|
||||
resolved_archive_file = archive_file
|
||||
else:
|
||||
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
|
||||
logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
|
||||
else:
|
||||
resolved_archive_file = None
|
||||
|
||||
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
|
||||
if is_sharded:
|
||||
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
|
||||
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
|
||||
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
|
||||
pretrained_model_name_or_path,
|
||||
resolved_archive_file,
|
||||
|
|
|
@ -35,21 +35,16 @@ from packaging import version
|
|||
from . import __version__
|
||||
from .dynamic_module_utils import custom_object_save
|
||||
from .utils import (
|
||||
EntryNotFoundError,
|
||||
ExplicitEnum,
|
||||
PaddingStrategy,
|
||||
PushToHubMixin,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
TensorType,
|
||||
add_end_docstrings,
|
||||
cached_path,
|
||||
cached_file,
|
||||
copy_func,
|
||||
get_file_from_repo,
|
||||
hf_bucket_url,
|
||||
is_flax_available,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
is_tf_available,
|
||||
is_tokenizers_available,
|
||||
is_torch_available,
|
||||
|
@ -1669,7 +1664,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||
vocab_files = {}
|
||||
init_configuration = {}
|
||||
|
||||
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||
if os.path.isfile(pretrained_model_name_or_path):
|
||||
if len(cls.vocab_files_names) > 1:
|
||||
raise ValueError(
|
||||
f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not "
|
||||
|
@ -1689,9 +1685,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
|
||||
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
|
||||
}
|
||||
vocab_files_target = {**cls.vocab_files_names, **additional_files_names}
|
||||
vocab_files = {**cls.vocab_files_names, **additional_files_names}
|
||||
|
||||
if "tokenizer_file" in vocab_files_target:
|
||||
if "tokenizer_file" in vocab_files:
|
||||
# Try to get the tokenizer config to see if there are versioned tokenizer files.
|
||||
fast_tokenizer_file = FULL_TOKENIZER_FILE
|
||||
resolved_config_file = get_file_from_repo(
|
||||
|
@ -1704,80 +1700,38 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
subfolder=subfolder,
|
||||
)
|
||||
if resolved_config_file is not None:
|
||||
with open(resolved_config_file, encoding="utf-8") as reader:
|
||||
tokenizer_config = json.load(reader)
|
||||
if "fast_tokenizer_files" in tokenizer_config:
|
||||
fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
|
||||
vocab_files_target["tokenizer_file"] = fast_tokenizer_file
|
||||
|
||||
# Look for the tokenizer files
|
||||
for file_id, file_name in vocab_files_target.items():
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
if subfolder is not None:
|
||||
full_file_name = os.path.join(pretrained_model_name_or_path, subfolder, file_name)
|
||||
else:
|
||||
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
|
||||
if not os.path.exists(full_file_name):
|
||||
logger.info(f"Didn't find file {full_file_name}. We won't load it.")
|
||||
full_file_name = None
|
||||
else:
|
||||
full_file_name = hf_bucket_url(
|
||||
pretrained_model_name_or_path,
|
||||
filename=file_name,
|
||||
subfolder=subfolder,
|
||||
revision=revision,
|
||||
mirror=None,
|
||||
)
|
||||
|
||||
vocab_files[file_id] = full_file_name
|
||||
vocab_files["tokenizer_file"] = fast_tokenizer_file
|
||||
|
||||
# Get files from url, cache, or disk depending on the case
|
||||
resolved_vocab_files = {}
|
||||
unresolved_files = []
|
||||
for file_id, file_path in vocab_files.items():
|
||||
print(file_id, file_path)
|
||||
if file_path is None:
|
||||
resolved_vocab_files[file_id] = None
|
||||
else:
|
||||
try:
|
||||
resolved_vocab_files[file_id] = cached_path(
|
||||
file_path,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
except FileNotFoundError as error:
|
||||
if local_files_only:
|
||||
unresolved_files.append(file_id)
|
||||
else:
|
||||
raise error
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
|
||||
"pass a token having permission to this repo with `use_auth_token` or log in with "
|
||||
"`huggingface-cli login` and pass `use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
|
||||
"for this model name. Check the model page at "
|
||||
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
logger.debug(f"{pretrained_model_name_or_path} does not contain a file named {file_path}.")
|
||||
resolved_vocab_files[file_id] = None
|
||||
|
||||
except ValueError:
|
||||
logger.debug(f"Connection problem to access {file_path} and it wasn't found in the cache.")
|
||||
resolved_vocab_files[file_id] = None
|
||||
resolved_vocab_files[file_id] = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
file_path,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
)
|
||||
|
||||
if len(unresolved_files) > 0:
|
||||
logger.info(
|
||||
|
@ -1797,7 +1751,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||
if file_id not in resolved_vocab_files:
|
||||
continue
|
||||
|
||||
if file_path == resolved_vocab_files[file_id]:
|
||||
if is_local:
|
||||
logger.info(f"loading file {file_path}")
|
||||
else:
|
||||
logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}")
|
||||
|
|
|
@ -60,6 +60,7 @@ from .hub import (
|
|||
PushToHubMixin,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
cached_file,
|
||||
cached_path,
|
||||
default_cache_path,
|
||||
define_sagemaker_information,
|
||||
|
@ -76,6 +77,7 @@ from .hub import (
|
|||
is_local_clone,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
move_cache,
|
||||
send_example_telemetry,
|
||||
url_to_filename,
|
||||
)
|
||||
|
|
|
@ -19,11 +19,13 @@ import fnmatch
|
|||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tarfile
|
||||
import tempfile
|
||||
import traceback
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
|
@ -34,9 +36,20 @@ from urllib.parse import urlparse
|
|||
from uuid import uuid4
|
||||
from zipfile import ZipFile, is_zipfile
|
||||
|
||||
import huggingface_hub
|
||||
import requests
|
||||
from filelock import FileLock
|
||||
from huggingface_hub import CommitOperationAdd, HfFolder, create_commit, create_repo, list_repo_files, whoami
|
||||
from huggingface_hub import (
|
||||
CommitOperationAdd,
|
||||
HfFolder,
|
||||
create_commit,
|
||||
create_repo,
|
||||
hf_hub_download,
|
||||
list_repo_files,
|
||||
whoami,
|
||||
)
|
||||
from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
|
||||
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
||||
from requests.exceptions import HTTPError
|
||||
from requests.models import Response
|
||||
from transformers.utils.logging import tqdm
|
||||
|
@ -385,21 +398,6 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
|||
return ua
|
||||
|
||||
|
||||
class RepositoryNotFoundError(HTTPError):
|
||||
"""
|
||||
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
|
||||
not have access to.
|
||||
"""
|
||||
|
||||
|
||||
class EntryNotFoundError(HTTPError):
|
||||
"""Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename."""
|
||||
|
||||
|
||||
class RevisionNotFoundError(HTTPError):
|
||||
"""Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
|
||||
|
||||
|
||||
def _raise_for_status(response: Response):
|
||||
"""
|
||||
Internal version of `request.raise_for_status()` that will refine a potential HTTPError.
|
||||
|
@ -628,6 +626,213 @@ def get_from_cache(
|
|||
return cache_path
|
||||
|
||||
|
||||
def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None):
    """
    Explores the cache to return the latest cached file for a given revision.

    Args:
        cache_dir (`str` or `os.PathLike`):
            The folder holding the per-repo cache folders (`models--<namespace>--<name>`).
        repo_id (`str`):
            The id of the repo; any `/` is replaced by `--` to build the cache folder name.
        filename (`str`):
            The path of the file inside the snapshot folder (may include a subfolder).
        revision (`str`, *optional*, defaults to `"main"`):
            Either the name of a file in the `refs` folder (whose content is then used as the snapshot name) or
            directly the name of a folder in `snapshots`.

    Returns:
        `Optional[str]`: The path to the cached file, or `None` when the repo, the revision or the file is not in
        the cache.
    """
    if revision is None:
        revision = "main"

    model_id = repo_id.replace("/", "--")
    model_cache = os.path.join(cache_dir, f"models--{model_id}")
    refs_dir = os.path.join(model_cache, "refs")
    snapshots_dir = os.path.join(model_cache, "snapshots")
    if not os.path.isdir(snapshots_dir):
        # No cache for this model, or a partially-created cache folder without any snapshot.
        return None

    # Resolve refs (for instance to convert main to the associated commit sha). The refs folder may be missing in
    # an incomplete cache, in which case `revision` is used as is.
    if os.path.isdir(refs_dir) and revision in os.listdir(refs_dir):
        ref_path = os.path.join(refs_dir, revision)
        if os.path.isfile(ref_path):
            with open(ref_path, encoding="utf-8") as f:
                # Strip so that a trailing newline in the ref file does not break the snapshot lookup.
                revision = f.read().strip()

    if revision not in os.listdir(snapshots_dir):
        # No cache for this revision and we won't try to return a random revision
        return None

    resolved_file = os.path.join(snapshots_dir, revision, filename)
    return resolved_file if os.path.isfile(resolved_file) else None
|
||||
|
||||
|
||||
# If huggingface_hub changes the class of error for this to FileNotFoundError, we will be able to avoid that in the
|
||||
# future.
|
||||
LOCAL_FILES_ONLY_HF_ERROR = (
|
||||
"Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable hf.co "
|
||||
"look-ups and downloads online, set 'local_files_only' to False."
|
||||
)
|
||||
|
||||
|
||||
# In the future, this ugly contextmanager can be removed when huggingface_hub has a released version where we can
|
||||
# activate/deactivate progress bars.
|
||||
@contextmanager
def _patch_hf_hub_tqdm():
    """
    A context manager to make huggingface hub use the tqdm version of Transformers (which is controlled by some utils)
    in logging.

    The original `huggingface_hub.file_download.tqdm` is restored on exit, including when the body of the `with`
    block raises.
    """
    old_tqdm = huggingface_hub.file_download.tqdm
    huggingface_hub.file_download.tqdm = tqdm
    try:
        yield
    finally:
        # Without the `finally`, an exception raised inside the `with` block would leave the patch in place.
        huggingface_hub.file_download.tqdm = old_tqdm
|
||||
|
||||
|
||||
def cached_file(
    path_or_repo_id: Union[str, os.PathLike],
    filename: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
    user_agent: Optional[Union[str, Dict[str, str]]] = None,
    _raise_exceptions_for_missing_entries=True,
    _raise_exceptions_for_connection_errors=True,
):
    """
    Tries to locate a file in a local folder and repo, downloads and cache it if necessary.

    Args:
        path_or_repo_id (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a model repo on huggingface.co.
            - a path to a *directory* potentially containing the file.
        filename (`str`):
            The name of the file to locate in `path_or_repo_id`.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `transformers-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.
        subfolder (`str`, *optional*, defaults to `""`):
            In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
            specify the folder name here.
        user_agent (`str` or `Dict[str, str]`, *optional*):
            Extra information passed through `http_user_agent` before being sent with the download request.
        _raise_exceptions_for_missing_entries (`bool`, *optional*, defaults to `True`):
            Private flag. When `False`, `None` is returned instead of raising when the file does not exist in the
            folder/repo (or is missing from the cache with `local_files_only=True`).
        _raise_exceptions_for_connection_errors (`bool`, *optional*, defaults to `True`):
            Private flag. When `False`, `None` is returned instead of raising when the download fails and no cached
            copy is found.

    <Tip>

    Passing `use_auth_token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `Optional[str]`: Returns the resolved file (to the cache folder if downloaded from a repo).

    Raises:
        `EnvironmentError`: when the repo, the revision or the file cannot be found, or when the download fails and
        there is no cached copy (subject to the two private flags above).

    Examples:

    ```python
    # Download a model weight from the Hub and cache it.
    model_weights_file = cached_file("bert-base-uncased", "pytorch_model.bin")
    ```"""
    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True
    if subfolder is None:
        subfolder = ""

    path_or_repo_id = str(path_or_repo_id)
    full_filename = os.path.join(subfolder, filename)
    if os.path.isdir(path_or_repo_id):
        # Local folder: no download and no cache involved, only check that the file is there.
        resolved_file = os.path.join(os.path.join(path_or_repo_id, subfolder), filename)
        if not os.path.isfile(resolved_file):
            if _raise_exceptions_for_missing_entries:
                raise EnvironmentError(f"Could not locate {full_filename} inside {path_or_repo_id}.")
            else:
                return None
        return resolved_file

    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    user_agent = http_user_agent(user_agent)
    try:
        # Load from URL or cache if already cached
        with _patch_hf_hub_tqdm():
            resolved_file = hf_hub_download(
                path_or_repo_id,
                filename,
                subfolder=None if len(subfolder) == 0 else subfolder,
                revision=revision,
                cache_dir=cache_dir,
                user_agent=user_agent,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                use_auth_token=use_auth_token,
                local_files_only=local_files_only,
            )

    except RepositoryNotFoundError:
        raise EnvironmentError(
            f"{path_or_repo_id} is not a local folder and is not a valid model identifier "
            "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
            "pass a token having permission to this repo with `use_auth_token` or log in with "
            "`huggingface-cli login` and pass `use_auth_token=True`."
        )
    except RevisionNotFoundError:
        raise EnvironmentError(
            f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
            "for this model name. Check the model page at "
            f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
        )
    except EntryNotFoundError:
        if not _raise_exceptions_for_missing_entries:
            return None
        if revision is None:
            revision = "main"
        raise EnvironmentError(
            f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
            f"'https://huggingface.co/{path_or_repo_id}/{revision}' for available files."
        )
    except HTTPError as err:
        # First we try to see if we have a cached version (not up to date):
        resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, revision=revision)
        if resolved_file is not None:
            return resolved_file
        if not _raise_exceptions_for_connection_errors:
            return None

        raise EnvironmentError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}")
    except ValueError as err:
        # HuggingFace Hub returns a ValueError for a missing file when local_files_only=True we need to catch it here
        # This could be caught above along in `EntryNotFoundError` if hf_hub sent a different error message here
        # `str(err)` rather than `err.args[0]`: a ValueError raised without arguments (or with a non-string
        # argument) must not turn into an IndexError/TypeError inside this handler.
        if LOCAL_FILES_ONLY_HF_ERROR in str(err) and local_files_only and not _raise_exceptions_for_missing_entries:
            return None

        # Otherwise we try to see if we have a cached version (not up to date):
        resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, revision=revision)
        if resolved_file is not None:
            return resolved_file
        if not _raise_exceptions_for_connection_errors:
            return None
        raise EnvironmentError(
            f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
            f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
            f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at"
            " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
        )

    return resolved_file
|
||||
|
||||
|
||||
def get_file_from_repo(
|
||||
path_or_repo: Union[str, os.PathLike],
|
||||
filename: str,
|
||||
|
@ -638,6 +843,7 @@ def get_file_from_repo(
|
|||
use_auth_token: Optional[Union[bool, str]] = None,
|
||||
revision: Optional[str] = None,
|
||||
local_files_only: bool = False,
|
||||
subfolder: str = "",
|
||||
):
|
||||
"""
|
||||
Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
|
||||
|
@ -670,6 +876,9 @@ def get_file_from_repo(
|
|||
identifier allowed by git.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, will only try to load the tokenizer configuration from local files.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
|
||||
specify the folder name here.
|
||||
|
||||
<Tip>
|
||||
|
||||
|
@ -689,47 +898,20 @@ def get_file_from_repo(
|
|||
# This model does not have a tokenizer config so the result will be None.
|
||||
tokenizer_config = get_file_from_repo("xlm-roberta-base", "tokenizer_config.json")
|
||||
```"""
|
||||
if is_offline_mode() and not local_files_only:
|
||||
logger.info("Offline mode: forcing local_files_only=True")
|
||||
local_files_only = True
|
||||
|
||||
path_or_repo = str(path_or_repo)
|
||||
if os.path.isdir(path_or_repo):
|
||||
resolved_file = os.path.join(path_or_repo, filename)
|
||||
return resolved_file if os.path.isfile(resolved_file) else None
|
||||
else:
|
||||
resolved_file = hf_bucket_url(path_or_repo, filename=filename, revision=revision, mirror=None)
|
||||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_file = cached_path(
|
||||
resolved_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{path_or_repo} is not a local folder and is not a valid model identifier "
|
||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
|
||||
"pass a token having permission to this repo with `use_auth_token` or log in with "
|
||||
"`huggingface-cli login` and pass `use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
|
||||
"for this model name. Check the model page at "
|
||||
f"'https://huggingface.co/{path_or_repo}' for available revisions."
|
||||
)
|
||||
except EnvironmentError:
|
||||
# The repo and revision exist, but the file does not or there was a connection error fetching it.
|
||||
return None
|
||||
|
||||
return resolved_file
|
||||
return cached_file(
|
||||
path_or_repo_id=path_or_repo,
|
||||
filename=filename,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
subfolder=subfolder,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
)
|
||||
|
||||
|
||||
def has_file(
|
||||
|
@ -766,7 +948,7 @@ def has_file(
|
|||
|
||||
r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=10)
|
||||
try:
|
||||
_raise_for_status(r)
|
||||
huggingface_hub.utils._errors._raise_for_status(r)
|
||||
return True
|
||||
except RepositoryNotFoundError as e:
|
||||
logger.error(e)
|
||||
|
@ -1196,3 +1378,183 @@ def get_checkpoint_shard_files(
|
|||
cached_filenames.append(cached_filename)
|
||||
|
||||
return cached_filenames, sharded_metadata
|
||||
|
||||
|
||||
# Everything below is for conversion between the old cache format and the new cache format.
|
||||
|
||||
|
||||
def get_all_cached_files(cache_dir=None):
    """
    Returns a list for all files cached with appropriate metadata.

    A file counts as "cached" when a sibling `<file>.json` metadata file sits next to it in `cache_dir`.

    Args:
        cache_dir (`str` or `os.PathLike`, *optional*):
            Folder to scan. Defaults to `TRANSFORMERS_CACHE`.

    Returns:
        `List[Dict[str, str]]`: One dictionary per cached file with the keys `"file"` (entry name inside
        `cache_dir`), `"url"` and `"etag"` (both read from the metadata file, double quotes stripped from the
        etag). The list is empty when `cache_dir` does not exist.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    else:
        cache_dir = str(cache_dir)

    if not os.path.isdir(cache_dir):
        # Nothing has been cached yet (e.g. fresh install): there is nothing to list and `os.listdir` would raise.
        return []

    cached_files = []
    for file in os.listdir(cache_dir):
        meta_path = os.path.join(cache_dir, f"{file}.json")
        if not os.path.isfile(meta_path):
            continue

        with open(meta_path, encoding="utf-8") as meta_file:
            metadata = json.load(meta_file)
        url = metadata["url"]
        # The etag is stored the way the server sent it, surrounding double quotes included.
        etag = metadata["etag"].replace('"', "")
        cached_files.append({"file": file, "url": url, "etag": etag})

    return cached_files
|
||||
|
||||
|
||||
def get_hub_metadata(url, token=None):
    """
    Returns the etag and associated commit hash for a given url.

    Args:
        url (`str`):
            The url of the file to query with a HEAD request (redirects are not followed).
        token (`str`, *optional*):
            Token sent as a bearer authorization header. When not provided, the token returned by
            `HfFolder.get_token()` is used; if that is `None` too, the request is sent without authorization.

    Returns:
        `Tuple[Optional[str], Optional[str]]`: The pair `(etag, commit_hash)` read from the response headers.
        Either value is `None` when the corresponding header is absent.
    """
    if token is None:
        token = HfFolder.get_token()
    headers = {"user-agent": http_user_agent()}
    if token is not None:
        # Only authenticate when a token is actually available, otherwise we would send "Bearer None".
        headers["authorization"] = f"Bearer {token}"

    r = huggingface_hub.file_download._request_with_retry(
        method="HEAD", url=url, headers=headers, allow_redirects=False
    )
    huggingface_hub.file_download._raise_for_status(r)
    commit_hash = r.headers.get(HUGGINGFACE_HEADER_X_REPO_COMMIT)
    # Prefer the linked etag header when present, fall back on the regular one.
    etag = r.headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get("ETag")
    if etag is not None:
        etag = huggingface_hub.file_download._normalize_etag(etag)
    return etag, commit_hash
|
||||
|
||||
|
||||
def extract_info_from_url(url):
    """
    Split a `https://huggingface.co/<repo>/resolve/<revision>/<filename>` url into its components.

    Returns:
        `Optional[Dict[str, str]]`: `None` when `url` does not follow that pattern, otherwise a dictionary with the
        keys `"repo"` (the repo id turned into a cache folder name: `models--` prefix and `/` replaced by `--`),
        `"revision"` and `"filename"`.
    """
    match = re.match(r"^https://huggingface\.co/(.*)/resolve/([^/]*)/(.*)$", url)
    if match is None:
        return None

    repo_id, revision, filename = match.group(1, 2, 3)
    folder_name = "--".join(["models", *repo_id.split("/")])
    return {"repo": folder_name, "revision": revision, "filename": filename}
|
||||
|
||||
|
||||
def clean_files_for(file):
    """
    Delete `file` together with its `.json` and `.lock` companions, silently skipping any of the three that is not
    an existing regular file.
    """
    for suffix in ("", ".json", ".lock"):
        candidate = f"{file}{suffix}"
        if os.path.isfile(candidate):
            os.remove(candidate)
|
||||
|
||||
|
||||
def move_to_new_cache(file, repo, filename, revision, etag, commit_hash):
    """
    Move file to repo following the new huggingface hub cache organization.

    Args:
        file (`str`):
            Path to the file in the old cache. Its `.json` and `.lock` companions are removed once the move is done.
        repo (`str`):
            Path to the repo folder in the new cache; created if needed.
        filename (`str`):
            Name of the file inside the repo. It may contain sub-folders.
        revision (`str`):
            The revision the file was downloaded from. When it differs from `commit_hash`, it is recorded in
            `refs/<revision>`.
        etag (`str`):
            The etag of the file, used as the blob name.
        commit_hash (`str`):
            The commit hash the revision points to, used as the snapshot folder name.
    """
    os.makedirs(repo, exist_ok=True)

    # refs
    os.makedirs(os.path.join(repo, "refs"), exist_ok=True)
    if revision != commit_hash:
        ref_path = os.path.join(repo, "refs", revision)
        with open(ref_path, "w") as f:
            f.write(commit_hash)

    # blobs
    os.makedirs(os.path.join(repo, "blobs"), exist_ok=True)
    blob_path = os.path.join(repo, "blobs", etag)
    shutil.move(file, blob_path)

    # snapshots
    pointer_path = os.path.join(repo, "snapshots", commit_hash, filename)
    # `filename` may live in a sub-folder of the repo, so create every parent of the pointer and not only
    # `snapshots/<commit_hash>` (the symlink helper is not expected to create missing folders).
    os.makedirs(os.path.dirname(pointer_path), exist_ok=True)
    huggingface_hub.file_download._create_relative_symlink(blob_path, pointer_path)
    # `file` itself has been moved at this point, this gets rid of the leftover `.json` and `.lock` files.
    clean_files_for(file)
|
||||
|
||||
|
||||
def move_cache(cache_dir=None, token=None):
    """
    Migrate the files of an old-format cache to the new huggingface hub cache organization.

    For every cached file coming from huggingface.co, the etag and commit hash are fetched from the Hub (once per
    url). Files whose etag still matches are moved to the new layout, files that are out of date are deleted, and
    files for which the metadata cannot be retrieved are left untouched.

    Args:
        cache_dir (`str`, *optional*):
            The cache folder to migrate (in place). Defaults to `TRANSFORMERS_CACHE`.
        token (`str`, *optional*):
            Token used to query the Hub. Defaults to the result of `HfFolder.get_token()`.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if token is None:
        token = HfFolder.get_token()
    cached_files = get_all_cached_files(cache_dir=cache_dir)
    print(f"Moving {len(cached_files)} files to the new cache system")

    hub_metadata = {}
    for file_info in tqdm(cached_files):
        url = file_info.pop("url")

        # Check the url before any request is made: this avoids a useless round-trip and, more importantly, makes
        # sure the user token is never sent to a host that is not huggingface.co.
        url_info = extract_info_from_url(url)
        if url_info is None:
            # Not a file from huggingface.co
            continue

        if url not in hub_metadata:
            try:
                hub_metadata[url] = get_hub_metadata(url, token=token)
            except requests.HTTPError:
                continue

        etag, commit_hash = hub_metadata[url]
        if etag is None or commit_hash is None:
            continue

        if file_info["etag"] != etag:
            # Cached file is not up to date, we just throw it as a new version will be downloaded anyway.
            clean_files_for(os.path.join(cache_dir, file_info["file"]))
            continue

        repo = os.path.join(cache_dir, url_info["repo"])
        move_to_new_cache(
            file=os.path.join(cache_dir, file_info["file"]),
            repo=repo,
            filename=url_info["filename"],
            revision=url_info["revision"],
            etag=etag,
            commit_hash=commit_hash,
        )
|
||||
|
||||
|
||||
# The layout version of the cache is stored as a plain integer in `version.txt` at the root of the cache folder.
# No file means the cache predates versioning (or does not exist yet).
cache_version_file = os.path.join(TRANSFORMERS_CACHE, "version.txt")
if not os.path.isfile(cache_version_file):
    cache_version = 0
else:
    with open(cache_version_file) as f:
        try:
            cache_version = int(f.read())
        except ValueError:
            # Empty or corrupted version file: treat the cache as un-migrated rather than failing at import time.
            cache_version = 0


# Only try to migrate when there is a cache folder to migrate in the first place.
if cache_version < 1 and os.path.isdir(TRANSFORMERS_CACHE):
    if is_offline_mode():
        logger.warning(
            "You are offline and the cache for model files in Transformers v4.22.0 has been updated while your local "
            "cache seems to be the one of a previous version. It is very likely that all your calls to any "
            "`from_pretrained()` method will fail. Remove the offline mode and enable internet connection to have "
            "your cache be updated automatically, then you can go back to offline mode."
        )
    else:
        logger.warning(
            "The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a "
            "one-time only operation. You can interrupt this and resume the migration later on by calling "
            "`transformers.utils.move_cache()`."
        )
    try:
        move_cache()
    except Exception as e:
        # The migration is best-effort: importing the library must never fail because of it.
        trace = "\n".join(traceback.format_tb(e.__traceback__))
        logger.error(
            f"There was a problem when trying to move your cache:\n\n{trace}\n\nPlease file an issue at "
            "https://github.com/huggingface/transformers/issues/new/choose and copy paste this whole message and we "
            "will do our best to help."
        )

if cache_version < 1:
    # Record the new version so that the migration is not attempted again at the next import.
    try:
        os.makedirs(TRANSFORMERS_CACHE, exist_ok=True)
        with open(cache_version_file, "w") as f:
            f.write("1")
    except Exception:
        logger.warning(
            f"There was a problem when trying to write in your cache folder ({TRANSFORMERS_CACHE}). You should set "
            "the environment variable TRANSFORMERS_CACHE to a writable directory."
        )
|
||||
|
|
|
@ -345,14 +345,14 @@ class ConfigTestUtils(unittest.TestCase):
|
|||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
response_mock.status_code = 500
|
||||
response_mock.headers = []
|
||||
response_mock.headers = {}
|
||||
response_mock.raise_for_status.side_effect = HTTPError
|
||||
|
||||
# Download this model to make sure it's in the cache.
|
||||
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
# Under the mock environment we get a 500 error when trying to reach the model.
|
||||
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
|
||||
with mock.patch("requests.request", return_value=response_mock) as mock_head:
|
||||
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
# This check we did call the fake head request
|
||||
mock_head.assert_called()
|
||||
|
|
|
@ -170,13 +170,13 @@ class FeatureExtractorUtilTester(unittest.TestCase):
|
|||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
response_mock.status_code = 500
|
||||
response_mock.headers = []
|
||||
response_mock.headers = {}
|
||||
response_mock.raise_for_status.side_effect = HTTPError
|
||||
|
||||
# Download this model to make sure it's in the cache.
|
||||
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
|
||||
# Under the mock environment we get a 500 error when trying to reach the model.
|
||||
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
|
||||
with mock.patch("requests.request", return_value=response_mock) as mock_head:
|
||||
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
|
||||
# This check we did call the fake head request
|
||||
mock_head.assert_called()
|
||||
|
|
|
@ -2925,14 +2925,14 @@ class ModelUtilsTest(TestCasePlus):
|
|||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
response_mock.status_code = 500
|
||||
response_mock.headers = []
|
||||
response_mock.headers = {}
|
||||
response_mock.raise_for_status.side_effect = HTTPError
|
||||
|
||||
# Download this model to make sure it's in the cache.
|
||||
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
# Under the mock environment we get a 500 error when trying to reach the model.
|
||||
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
|
||||
with mock.patch("requests.request", return_value=response_mock) as mock_head:
|
||||
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
# This check we did call the fake head request
|
||||
mock_head.assert_called()
|
||||
|
|
|
@ -1922,14 +1922,14 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
response_mock.status_code = 500
|
||||
response_mock.headers = []
|
||||
response_mock.headers = {}
|
||||
response_mock.raise_for_status.side_effect = HTTPError
|
||||
|
||||
# Download this model to make sure it's in the cache.
|
||||
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
# Under the mock environment we get a 500 error when trying to reach the model.
|
||||
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
|
||||
with mock.patch("requests.request", return_value=response_mock) as mock_head:
|
||||
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
# This check we did call the fake head request
|
||||
mock_head.assert_called()
|
||||
|
|
|
@ -3829,14 +3829,14 @@ class TokenizerUtilTester(unittest.TestCase):
|
|||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
response_mock.status_code = 500
|
||||
response_mock.headers = []
|
||||
response_mock.headers = {}
|
||||
response_mock.raise_for_status.side_effect = HTTPError
|
||||
|
||||
# Download this model to make sure it's in the cache.
|
||||
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
# Under the mock environment we get a 500 error when trying to reach the model.
|
||||
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
|
||||
with mock.patch("requests.request", return_value=response_mock) as mock_head:
|
||||
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
# This check we did call the fake head request
|
||||
mock_head.assert_called()
|
||||
|
|
Loading…
Reference in New Issue