Use new huggingface_hub tools for download models (#18438)

* Draft new cached_file

* Initial draft for config and model

* Small fixes

* Fix first batch of tests

* Look in cache when internet is down

* Fix last tests

* Bad black, not fixing all quality errors

* Make diff less

* Implement change for TF and Flax models

* Add tokenizer and feature extractor

* For compatibility with main

* Add utils to move the cache and auto-do it at first use.

* Quality

* Deal with empty commit shas

* Deal with empty etag

* Address review comments
This commit is contained in:
Sylvain Gugger 2022-08-05 10:12:40 -04:00 committed by GitHub
parent 70fa1a8d26
commit 5cd4032368
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 673 additions and 556 deletions

View File

@ -25,25 +25,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from packaging import version
from requests import HTTPError
from . import __version__
from .dynamic_module_utils import custom_object_save
from .utils import (
CONFIG_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
EntryNotFoundError,
PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
cached_path,
copy_func,
hf_bucket_url,
is_offline_mode,
is_remote_url,
is_torch_available,
logging,
)
from .utils import CONFIG_NAME, PushToHubMixin, cached_file, copy_func, is_torch_available, logging
logger = logging.get_logger(__name__)
@ -591,77 +575,43 @@ class PretrainedConfig(PushToHubMixin):
if from_pipeline is not None:
user_agent["using_pipeline"] = from_pipeline
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)) or is_remote_url(
pretrained_model_name_or_path
):
config_file = pretrained_model_name_or_path
is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
# Soecial case when pretrained_model_name_or_path is a local file
resolved_config_file = pretrained_model_name_or_path
is_local = True
else:
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
if os.path.isdir(os.path.join(pretrained_model_name_or_path, subfolder)):
config_file = os.path.join(pretrained_model_name_or_path, subfolder, configuration_file)
else:
config_file = hf_bucket_url(
try:
# Load from local folder or from cache or download from model Hub and cache
resolved_config_file = cached_file(
pretrained_model_name_or_path,
filename=configuration_file,
configuration_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder if len(subfolder) > 0 else None,
mirror=None,
subfolder=subfolder,
)
except EnvironmentError:
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
# the original exception.
raise
except Exception:
# For any other exception, we throw a generic error.
raise EnvironmentError(
f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
" from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
f" containing a {configuration_file} file"
)
try:
# Load from URL or cache if already cached
resolved_config_file = cached_path(
config_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
)
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
"`use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
"available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
f" containing a {configuration_file} file.\nCheckout your internet connection or see how to run the"
" library in offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a {configuration_file} file"
)
try:
# Load config dict
@ -671,10 +621,10 @@ class PretrainedConfig(PushToHubMixin):
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
)
if resolved_config_file == config_file:
logger.info(f"loading configuration file {config_file}")
if is_local:
logger.info(f"loading configuration file {resolved_config_file}")
else:
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
return config_dict, kwargs

View File

@ -24,23 +24,15 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import numpy as np
from requests import HTTPError
from .dynamic_module_utils import custom_object_save
from .utils import (
FEATURE_EXTRACTOR_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
EntryNotFoundError,
PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
TensorType,
cached_path,
cached_file,
copy_func,
hf_bucket_url,
is_flax_available,
is_offline_mode,
is_remote_url,
is_tf_available,
is_torch_available,
logging,
@ -388,64 +380,40 @@ class FeatureExtractionMixin(PushToHubMixin):
local_files_only = True
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
feature_extractor_file = pretrained_model_name_or_path
if os.path.isfile(pretrained_model_name_or_path):
resolved_feature_extractor_file = pretrained_model_name_or_path
is_local = True
else:
feature_extractor_file = hf_bucket_url(
pretrained_model_name_or_path, filename=FEATURE_EXTRACTOR_NAME, revision=revision, mirror=None
)
try:
# Load from URL or cache if already cached
resolved_feature_extractor_file = cached_path(
feature_extractor_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
)
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
"`use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
"available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {FEATURE_EXTRACTOR_NAME}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
f" containing a {FEATURE_EXTRACTOR_NAME} file.\nCheckout your internet connection or see how to run"
" the library in offline mode at"
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load it "
"from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a {FEATURE_EXTRACTOR_NAME} file"
)
feature_extractor_file = FEATURE_EXTRACTOR_NAME
try:
# Load from local folder or from cache or download from model Hub and cache
resolved_feature_extractor_file = cached_file(
pretrained_model_name_or_path,
feature_extractor_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
)
except EnvironmentError:
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
# the original exception.
raise
except Exception:
# For any other exception, we throw a generic error.
raise EnvironmentError(
f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load"
" it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
f" directory containing a {FEATURE_EXTRACTOR_NAME} file"
)
try:
# Load feature_extractor dict
@ -458,12 +426,11 @@ class FeatureExtractionMixin(PushToHubMixin):
f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file."
)
if resolved_feature_extractor_file == feature_extractor_file:
logger.info(f"loading feature extractor configuration file {feature_extractor_file}")
if is_local:
logger.info(f"loading configuration file {resolved_feature_extractor_file}")
else:
logger.info(
f"loading feature extractor configuration file {feature_extractor_file} from cache at"
f" {resolved_feature_extractor_file}"
f"loading configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}"
)
return feature_extractor_dict, kwargs

View File

@ -32,7 +32,6 @@ from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey
from requests import HTTPError
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
@ -41,20 +40,14 @@ from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_d
from .utils import (
FLAX_WEIGHTS_INDEX_NAME,
FLAX_WEIGHTS_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
WEIGHTS_NAME,
EntryNotFoundError,
PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
add_code_sample_docstrings,
add_start_docstrings_to_model_forward,
cached_path,
cached_file,
copy_func,
has_file,
hf_bucket_url,
is_offline_mode,
is_remote_url,
logging,
replace_return_docstrings,
)
@ -557,6 +550,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
@ -598,6 +594,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
_do_init = kwargs.pop("_do_init", True)
subfolder = kwargs.pop("subfolder", "")
if trust_remote_code is True:
logger.warning(
@ -642,6 +639,8 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
# Load model
if pretrained_model_name_or_path is not None:
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint
@ -665,65 +664,44 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
f"{pretrained_model_name_or_path}."
)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
elif os.path.isfile(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
is_local = True
else:
filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=filename,
revision=revision,
)
try:
# Load from URL or cache if already cached
cached_file_kwargs = dict(
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
)
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
)
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login` and pass `use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
if filename == FLAX_WEIGHTS_NAME:
try:
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an expection but a None
# result when internet is up, the repo and revision exist, but the file does not.
if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME:
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=FLAX_WEIGHTS_INDEX_NAME,
revision=revision,
resolved_archive_file = cached_file(
pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME, **cached_file_kwargs
)
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
)
is_sharded = True
except EntryNotFoundError:
has_file_kwargs = {"revision": revision, "proxies": proxies, "use_auth_token": use_auth_token}
if resolved_archive_file is not None:
is_sharded = True
if resolved_archive_file is None:
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
# message.
has_file_kwargs = {
"revision": revision,
"proxies": proxies,
"use_auth_token": use_auth_token,
}
if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named"
@ -735,35 +713,24 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
f"{pretrained_model_name_or_path} does not appear to have a file named"
f" {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
)
else:
except EnvironmentError:
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
# to the original exception.
raise
except Exception:
# For any other exception, we throw a generic error.
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
" from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
f"{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your"
" internet connection or see how to run the library in offline mode at"
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
)
if resolved_archive_file == archive_file:
if is_local:
logger.info(f"loading weights file {archive_file}")
resolved_archive_file = archive_file
else:
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
else:
resolved_archive_file = None

View File

@ -37,7 +37,6 @@ from tensorflow.python.keras.saving import hdf5_format
from huggingface_hub import Repository, list_repo_files
from keras.saving.hdf5_format import save_attributes_to_hdf5_group
from requests import HTTPError
from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from . import DataCollatorWithPadding, DefaultDataCollator
@ -48,22 +47,16 @@ from .generation_tf_utils import TFGenerationMixin
from .tf_utils import shape_list
from .utils import (
DUMMY_INPUTS,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
TF2_WEIGHTS_INDEX_NAME,
TF2_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
EntryNotFoundError,
ModelOutput,
PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
cached_path,
cached_file,
find_labels,
has_file,
hf_bucket_url,
is_offline_mode,
is_remote_url,
logging,
requires_backends,
working_or_temp_dir,
@ -2112,6 +2105,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
@ -2164,6 +2160,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
load_weight_prefix = kwargs.pop("load_weight_prefix", None)
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
subfolder = kwargs.pop("subfolder", "")
if trust_remote_code is True:
logger.warning(
@ -2202,9 +2199,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# index of the files.
is_sharded = False
sharded_metadata = None
# Load model
if pretrained_model_name_or_path is not None:
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint in priority if from_pt
@ -2232,68 +2230,43 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
f"Error no file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
f"{pretrained_model_name_or_path}."
)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
elif os.path.isfile(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
is_local = True
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
archive_file = pretrained_model_name_or_path + ".index"
is_local = True
else:
# set correct filename
filename = WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=filename,
revision=revision,
mirror=mirror,
)
try:
# Load from URL or cache if already cached
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
)
try:
# Load from URL or cache if already cached
cached_file_kwargs = dict(
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
)
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login` and pass `use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
if filename == TF2_WEIGHTS_NAME:
try:
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an expection but a None
# result when internet is up, the repo and revision exist, but the file does not.
if resolved_archive_file is None and filename == TF2_WEIGHTS_NAME:
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=TF2_WEIGHTS_INDEX_NAME,
revision=revision,
mirror=mirror,
resolved_archive_file = cached_file(
pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME, **cached_file_kwargs
)
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
)
is_sharded = True
except EntryNotFoundError:
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
if resolved_archive_file is not None:
is_sharded = True
if resolved_archive_file is None:
# Otherwise, maybe there is a PyTorch or Flax model file. We try those to give a helpful error
# message.
has_file_kwargs = {
"revision": revision,
@ -2312,42 +2285,32 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
f"{pretrained_model_name_or_path} does not appear to have a file named"
f" {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
)
else:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
f"{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your internet"
" connection or see how to run the library in offline mode at"
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
)
if resolved_archive_file == archive_file:
except EnvironmentError:
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
# to the original exception.
raise
except Exception:
# For any other exception, we throw a generic error.
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
" from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
f" directory containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
)
if is_local:
logger.info(f"loading weights file {archive_file}")
resolved_archive_file = archive_file
else:
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
else:
resolved_archive_file = None
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
if is_sharded:
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
resolved_archive_file, _ = get_checkpoint_shard_files(
pretrained_model_name_or_path,
resolved_archive_file,
cache_dir=cache_dir,

View File

@ -31,7 +31,6 @@ from packaging import version
from torch import Tensor, device, nn
from torch.nn import CrossEntropyLoss
from requests import HTTPError
from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from transformers.utils.import_utils import is_sagemaker_mp_enabled
@ -51,24 +50,18 @@ from .pytorch_utils import ( # noqa: F401
from .utils import (
DUMMY_INPUTS,
FLAX_WEIGHTS_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
TF2_WEIGHTS_NAME,
TF_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
ContextManagers,
EntryNotFoundError,
ModelOutput,
PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
cached_path,
cached_file,
copy_func,
has_file,
hf_bucket_url,
is_accelerate_available,
is_offline_mode,
is_remote_url,
logging,
replace_return_docstrings,
)
@ -1868,7 +1861,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if pretrained_model_name_or_path is not None:
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
is_local = os.path.isdir(pretrained_model_name_or_path)
if is_local:
if from_tf and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
):
@ -1911,10 +1905,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
f"Error no file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or "
f"{FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
)
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)) or is_remote_url(
pretrained_model_name_or_path
):
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
archive_file = pretrained_model_name_or_path
is_local = True
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")):
if not from_tf:
raise ValueError(
@ -1922,6 +1915,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"from_tf to True to load from this checkpoint."
)
archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index")
is_local = True
else:
# set correct filename
if from_tf:
@ -1931,63 +1925,32 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else:
filename = WEIGHTS_NAME
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=filename,
revision=revision,
mirror=mirror,
subfolder=subfolder if len(subfolder) > 0 else None,
)
try:
# Load from URL or cache if already cached
cached_file_kwargs = dict(
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
)
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
try:
# Load from URL or cache if already cached
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
)
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login` and pass `use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
if filename == WEIGHTS_NAME:
try:
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an expection but a None
# result when internet is up, the repo and revision exist, but the file does not.
if resolved_archive_file is None and filename == WEIGHTS_NAME:
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=WEIGHTS_INDEX_NAME,
revision=revision,
mirror=mirror,
subfolder=subfolder if len(subfolder) > 0 else None,
resolved_archive_file = cached_file(
pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
)
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
)
is_sharded = True
except EntryNotFoundError:
if resolved_archive_file is not None:
is_sharded = True
if resolved_archive_file is None:
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
# message.
has_file_kwargs = {
@ -2013,42 +1976,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME},"
f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
)
else:
except EnvironmentError:
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
# to the original exception.
raise
except Exception:
# For any other exception, we throw a generic error.
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
" from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or"
f" {FLAX_WEIGHTS_NAME}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
f"{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or"
f" {FLAX_WEIGHTS_NAME}.\nCheckout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or "
f"{FLAX_WEIGHTS_NAME}."
)
if resolved_archive_file == archive_file:
if is_local:
logger.info(f"loading weights file {archive_file}")
resolved_archive_file = archive_file
else:
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
else:
resolved_archive_file = None
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
if is_sharded:
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
# rsolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
pretrained_model_name_or_path,
resolved_archive_file,

View File

@ -35,21 +35,16 @@ from packaging import version
from . import __version__
from .dynamic_module_utils import custom_object_save
from .utils import (
EntryNotFoundError,
ExplicitEnum,
PaddingStrategy,
PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
TensorType,
add_end_docstrings,
cached_path,
cached_file,
copy_func,
get_file_from_repo,
hf_bucket_url,
is_flax_available,
is_offline_mode,
is_remote_url,
is_tf_available,
is_tokenizers_available,
is_torch_available,
@ -1669,7 +1664,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
vocab_files = {}
init_configuration = {}
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isfile(pretrained_model_name_or_path):
if len(cls.vocab_files_names) > 1:
raise ValueError(
f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not "
@ -1689,9 +1685,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
}
vocab_files_target = {**cls.vocab_files_names, **additional_files_names}
vocab_files = {**cls.vocab_files_names, **additional_files_names}
if "tokenizer_file" in vocab_files_target:
if "tokenizer_file" in vocab_files:
# Try to get the tokenizer config to see if there are versioned tokenizer files.
fast_tokenizer_file = FULL_TOKENIZER_FILE
resolved_config_file = get_file_from_repo(
@ -1704,80 +1700,38 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
use_auth_token=use_auth_token,
revision=revision,
local_files_only=local_files_only,
subfolder=subfolder,
)
if resolved_config_file is not None:
with open(resolved_config_file, encoding="utf-8") as reader:
tokenizer_config = json.load(reader)
if "fast_tokenizer_files" in tokenizer_config:
fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
vocab_files_target["tokenizer_file"] = fast_tokenizer_file
# Look for the tokenizer files
for file_id, file_name in vocab_files_target.items():
if os.path.isdir(pretrained_model_name_or_path):
if subfolder is not None:
full_file_name = os.path.join(pretrained_model_name_or_path, subfolder, file_name)
else:
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
if not os.path.exists(full_file_name):
logger.info(f"Didn't find file {full_file_name}. We won't load it.")
full_file_name = None
else:
full_file_name = hf_bucket_url(
pretrained_model_name_or_path,
filename=file_name,
subfolder=subfolder,
revision=revision,
mirror=None,
)
vocab_files[file_id] = full_file_name
vocab_files["tokenizer_file"] = fast_tokenizer_file
# Get files from url, cache, or disk depending on the case
resolved_vocab_files = {}
unresolved_files = []
for file_id, file_path in vocab_files.items():
print(file_id, file_path)
if file_path is None:
resolved_vocab_files[file_id] = None
else:
try:
resolved_vocab_files[file_id] = cached_path(
file_path,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
)
except FileNotFoundError as error:
if local_files_only:
unresolved_files.append(file_id)
else:
raise error
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
"pass a token having permission to this repo with `use_auth_token` or log in with "
"`huggingface-cli login` and pass `use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
"for this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
logger.debug(f"{pretrained_model_name_or_path} does not contain a file named {file_path}.")
resolved_vocab_files[file_id] = None
except ValueError:
logger.debug(f"Connection problem to access {file_path} and it wasn't found in the cache.")
resolved_vocab_files[file_id] = None
resolved_vocab_files[file_id] = cached_file(
pretrained_model_name_or_path,
file_path,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
if len(unresolved_files) > 0:
logger.info(
@ -1797,7 +1751,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
if file_id not in resolved_vocab_files:
continue
if file_path == resolved_vocab_files[file_id]:
if is_local:
logger.info(f"loading file {file_path}")
else:
logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}")

View File

@ -60,6 +60,7 @@ from .hub import (
PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
cached_file,
cached_path,
default_cache_path,
define_sagemaker_information,
@ -76,6 +77,7 @@ from .hub import (
is_local_clone,
is_offline_mode,
is_remote_url,
move_cache,
send_example_telemetry,
url_to_filename,
)

View File

@ -19,11 +19,13 @@ import fnmatch
import io
import json
import os
import re
import shutil
import subprocess
import sys
import tarfile
import tempfile
import traceback
import warnings
from contextlib import contextmanager
from functools import partial
@ -34,9 +36,20 @@ from urllib.parse import urlparse
from uuid import uuid4
from zipfile import ZipFile, is_zipfile
import huggingface_hub
import requests
from filelock import FileLock
from huggingface_hub import CommitOperationAdd, HfFolder, create_commit, create_repo, list_repo_files, whoami
from huggingface_hub import (
CommitOperationAdd,
HfFolder,
create_commit,
create_repo,
hf_hub_download,
list_repo_files,
whoami,
)
from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests.exceptions import HTTPError
from requests.models import Response
from transformers.utils.logging import tqdm
@ -385,21 +398,6 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
return ua
class RepositoryNotFoundError(HTTPError):
"""
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
not have access to.
"""
class EntryNotFoundError(HTTPError):
"""Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename."""
class RevisionNotFoundError(HTTPError):
"""Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
def _raise_for_status(response: Response):
"""
Internal version of `request.raise_for_status()` that will refine a potential HTTPError.
@ -628,6 +626,213 @@ def get_from_cache(
return cache_path
def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None):
"""
Explores the cache to return the latest cached file for a given revision.
"""
if revision is None:
revision = "main"
model_id = repo_id.replace("/", "--")
model_cache = os.path.join(cache_dir, f"models--{model_id}")
if not os.path.isdir(model_cache):
# No cache for this model
return None
# Resolve refs (for instance to convert main to the associated commit sha)
cached_refs = os.listdir(os.path.join(model_cache, "refs"))
if revision in cached_refs:
with open(os.path.join(model_cache, "refs", revision)) as f:
revision = f.read()
cached_shas = os.listdir(os.path.join(model_cache, "snapshots"))
if revision not in cached_shas:
# No cache for this revision and we won't try to return a random revision
return None
cached_file = os.path.join(model_cache, "snapshots", revision, filename)
return cached_file if os.path.isfile(cached_file) else None
# If huggingface_hub changes the class of error for this to FileNotFoundError, we will be able to avoid that in the
# future.
LOCAL_FILES_ONLY_HF_ERROR = (
"Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable hf.co "
"look-ups and downloads online, set 'local_files_only' to False."
)
# In the future, this ugly contextmanager can be removed when huggingface_hub as a released version where we can
# activate/deactivate progress bars.
@contextmanager
def _patch_hf_hub_tqdm():
"""
A context manager to make huggingface hub use the tqdm version of Transformers (which is controlled by some utils)
in logging.
"""
old_tqdm = huggingface_hub.file_download.tqdm
huggingface_hub.file_download.tqdm = tqdm
yield
huggingface_hub.file_download.tqdm = old_tqdm
def cached_file(
path_or_repo_id: Union[str, os.PathLike],
filename: str,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
resume_download: bool = False,
proxies: Optional[Dict[str, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
subfolder: str = "",
user_agent: Optional[Union[str, Dict[str, str]]] = None,
_raise_exceptions_for_missing_entries=True,
_raise_exceptions_for_connection_errors=True,
):
"""
Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
Args:
path_or_repo_id (`str` or `os.PathLike`):
This can be either:
- a string, the *model id* of a model repo on huggingface.co.
- a path to a *directory* potentially containing the file.
filename (`str`):
The name of the file to locate in `path_or_repo`.
cache_dir (`str` or `os.PathLike`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force to (re-)download the configuration files and override the cached versions if they
exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the tokenizer configuration from local files.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
<Tip>
Passing `use_auth_token=True` is required when you want to use a private model.
</Tip>
Returns:
`Optional[str]`: Returns the resolved file (to the cache folder if downloaded from a repo).
Examples:
```python
# Download a model weight from the Hub and cache it.
model_weights_file = cached_file("bert-base-uncased", "pytorch_model.bin")
```"""
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
if subfolder is None:
subfolder = ""
path_or_repo_id = str(path_or_repo_id)
full_filename = os.path.join(subfolder, filename)
if os.path.isdir(path_or_repo_id):
resolved_file = os.path.join(os.path.join(path_or_repo_id, subfolder), filename)
if not os.path.isfile(resolved_file):
if _raise_exceptions_for_missing_entries:
raise EnvironmentError(f"Could not locate {full_filename} inside {path_or_repo_id}.")
else:
return None
return resolved_file
if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
user_agent = http_user_agent(user_agent)
try:
# Load from URL or cache if already cached
with _patch_hf_hub_tqdm():
resolved_file = hf_hub_download(
path_or_repo_id,
filename,
subfolder=None if len(subfolder) == 0 else subfolder,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
use_auth_token=use_auth_token,
local_files_only=local_files_only,
)
except RepositoryNotFoundError:
raise EnvironmentError(
f"{path_or_repo_id} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
"pass a token having permission to this repo with `use_auth_token` or log in with "
"`huggingface-cli login` and pass `use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
"for this model name. Check the model page at "
f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
)
except EntryNotFoundError:
if not _raise_exceptions_for_missing_entries:
return None
if revision is None:
revision = "main"
raise EnvironmentError(
f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
f"'https://huggingface.co/{path_or_repo_id}/{revision}' for available files."
)
except HTTPError as err:
# First we try to see if we have a cached version (not up to date):
resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, revision=revision)
if resolved_file is not None:
return resolved_file
if not _raise_exceptions_for_connection_errors:
return None
raise EnvironmentError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}")
except ValueError as err:
# HuggingFace Hub returns a ValueError for a missing file when local_files_only=True we need to catch it here
# This could be caught above along in `EntryNotFoundError` if hf_hub sent a different error message here
if LOCAL_FILES_ONLY_HF_ERROR in err.args[0] and local_files_only and not _raise_exceptions_for_missing_entries:
return None
# Otherwise we try to see if we have a cached version (not up to date):
resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, revision=revision)
if resolved_file is not None:
return resolved_file
if not _raise_exceptions_for_connection_errors:
return None
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at"
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
return resolved_file
def get_file_from_repo(
path_or_repo: Union[str, os.PathLike],
filename: str,
@ -638,6 +843,7 @@ def get_file_from_repo(
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
subfolder: str = "",
):
"""
Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
@ -670,6 +876,9 @@ def get_file_from_repo(
identifier allowed by git.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the tokenizer configuration from local files.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
<Tip>
@ -689,47 +898,20 @@ def get_file_from_repo(
# This model does not have a tokenizer config so the result will be None.
tokenizer_config = get_file_from_repo("xlm-roberta-base", "tokenizer_config.json")
```"""
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
path_or_repo = str(path_or_repo)
if os.path.isdir(path_or_repo):
resolved_file = os.path.join(path_or_repo, filename)
return resolved_file if os.path.isfile(resolved_file) else None
else:
resolved_file = hf_bucket_url(path_or_repo, filename=filename, revision=revision, mirror=None)
try:
# Load from URL or cache if already cached
resolved_file = cached_path(
resolved_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
)
except RepositoryNotFoundError:
raise EnvironmentError(
f"{path_or_repo} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
"pass a token having permission to this repo with `use_auth_token` or log in with "
"`huggingface-cli login` and pass `use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
"for this model name. Check the model page at "
f"'https://huggingface.co/{path_or_repo}' for available revisions."
)
except EnvironmentError:
# The repo and revision exist, but the file does not or there was a connection error fetching it.
return None
return resolved_file
return cached_file(
path_or_repo_id=path_or_repo,
filename=filename,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
use_auth_token=use_auth_token,
revision=revision,
local_files_only=local_files_only,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
def has_file(
@ -766,7 +948,7 @@ def has_file(
r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=10)
try:
_raise_for_status(r)
huggingface_hub.utils._errors._raise_for_status(r)
return True
except RepositoryNotFoundError as e:
logger.error(e)
@ -1196,3 +1378,183 @@ def get_checkpoint_shard_files(
cached_filenames.append(cached_filename)
return cached_filenames, sharded_metadata
# All what is below is for conversion between old cache format and new cache format.
def get_all_cached_files(cache_dir=None):
"""
Returns a list for all files cached with appropriate metadata.
"""
if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE
else:
cache_dir = str(cache_dir)
cached_files = []
for file in os.listdir(cache_dir):
meta_path = os.path.join(cache_dir, f"{file}.json")
if not os.path.isfile(meta_path):
continue
with open(meta_path, encoding="utf-8") as meta_file:
metadata = json.load(meta_file)
url = metadata["url"]
etag = metadata["etag"].replace('"', "")
cached_files.append({"file": file, "url": url, "etag": etag})
return cached_files
def get_hub_metadata(url, token=None):
"""
Returns the commit hash and associated etag for a given url.
"""
if token is None:
token = HfFolder.get_token()
headers = {"user-agent": http_user_agent()}
headers["authorization"] = f"Bearer {token}"
r = huggingface_hub.file_download._request_with_retry(
method="HEAD", url=url, headers=headers, allow_redirects=False
)
huggingface_hub.file_download._raise_for_status(r)
commit_hash = r.headers.get(HUGGINGFACE_HEADER_X_REPO_COMMIT)
etag = r.headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get("ETag")
if etag is not None:
etag = huggingface_hub.file_download._normalize_etag(etag)
return etag, commit_hash
def extract_info_from_url(url):
"""
Extract repo_name, revision and filename from an url.
"""
search = re.search(r"^https://huggingface\.co/(.*)/resolve/([^/]*)/(.*)$", url)
if search is None:
return None
repo, revision, filename = search.groups()
cache_repo = "--".join(["models"] + repo.split("/"))
return {"repo": cache_repo, "revision": revision, "filename": filename}
def clean_files_for(file):
"""
Remove, if they exist, file, file.json and file.lock
"""
for f in [file, f"{file}.json", f"{file}.lock"]:
if os.path.isfile(f):
os.remove(f)
def move_to_new_cache(file, repo, filename, revision, etag, commit_hash):
"""
Move file to repo following the new huggingface hub cache organization.
"""
os.makedirs(repo, exist_ok=True)
# refs
os.makedirs(os.path.join(repo, "refs"), exist_ok=True)
if revision != commit_hash:
ref_path = os.path.join(repo, "refs", revision)
with open(ref_path, "w") as f:
f.write(commit_hash)
# blobs
os.makedirs(os.path.join(repo, "blobs"), exist_ok=True)
# TODO: replace copy by move when all works well.
blob_path = os.path.join(repo, "blobs", etag)
shutil.move(file, blob_path)
# snapshots
os.makedirs(os.path.join(repo, "snapshots"), exist_ok=True)
os.makedirs(os.path.join(repo, "snapshots", commit_hash), exist_ok=True)
pointer_path = os.path.join(repo, "snapshots", commit_hash, filename)
huggingface_hub.file_download._create_relative_symlink(blob_path, pointer_path)
clean_files_for(file)
def move_cache(cache_dir=None, token=None):
if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE
if token is None:
token = HfFolder.get_token()
cached_files = get_all_cached_files(cache_dir=cache_dir)
print(f"Moving {len(cached_files)} files to the new cache system")
hub_metadata = {}
for file_info in tqdm(cached_files):
url = file_info.pop("url")
if url not in hub_metadata:
try:
hub_metadata[url] = get_hub_metadata(url, token=token)
except requests.HTTPError:
continue
etag, commit_hash = hub_metadata[url]
if etag is None or commit_hash is None:
continue
if file_info["etag"] != etag:
# Cached file is not up to date, we just throw it as a new version will be downloaded anyway.
clean_files_for(os.path.join(cache_dir, file_info["file"]))
continue
url_info = extract_info_from_url(url)
if url_info is None:
# Not a file from huggingface.co
continue
repo = os.path.join(cache_dir, url_info["repo"])
move_to_new_cache(
file=os.path.join(cache_dir, file_info["file"]),
repo=repo,
filename=url_info["filename"],
revision=url_info["revision"],
etag=etag,
commit_hash=commit_hash,
)
cache_version_file = os.path.join(TRANSFORMERS_CACHE, "version.txt")
if not os.path.isfile(cache_version_file):
cache_version = 0
else:
with open(cache_version_file) as f:
cache_version = int(f.read())
if cache_version < 1:
if is_offline_mode():
logger.warn(
"You are offline and the cache for model files in Transformers v4.22.0 has been updated while your local "
"cache seems to be the one of a previous version. It is very likely that all your calls to any "
"`from_pretrained()` method will fail. Remove the offline mode and enable internet connection to have "
"your cache be updated automatically, then you can go back to offline mode."
)
else:
logger.warn(
"The cache for model files in Transformers v4.22.0 has been udpated. Migrating your old cache. This is a "
"one-time only operation. You can interrupt this and resume the migration later on by calling "
"`transformers.utils.move_cache()`."
)
try:
move_cache()
except Exception as e:
trace = "\n".join(traceback.format_tb(e.__traceback__))
logger.error(
f"There was a problem when trying to move your cache:\n\n{trace}\n\nPlease file an issue at "
"https://github.com/huggingface/transformers/issues/new/choose and copy paste this whole message and we "
"will do our best to help."
)
try:
os.makedirs(TRANSFORMERS_CACHE, exist_ok=True)
with open(cache_version_file, "w") as f:
f.write("1")
except Exception:
logger.warn(
f"There was a problem when trying to write in your cache folder ({TRANSFORMERS_CACHE}). You should set "
"the environment variable TRANSFORMERS_CACHE to a writable directory."
)

View File

@ -345,14 +345,14 @@ class ConfigTestUtils(unittest.TestCase):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
# This check we did call the fake head request
mock_head.assert_called()

View File

@ -170,13 +170,13 @@ class FeatureExtractorUtilTester(unittest.TestCase):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
# This check we did call the fake head request
mock_head.assert_called()

View File

@ -2925,14 +2925,14 @@ class ModelUtilsTest(TestCasePlus):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# This check we did call the fake head request
mock_head.assert_called()

View File

@ -1922,14 +1922,14 @@ class UtilsFunctionsTest(unittest.TestCase):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# This check we did call the fake head request
mock_head.assert_called()

View File

@ -3829,14 +3829,14 @@ class TokenizerUtilTester(unittest.TestCase):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
# This check we did call the fake head request
mock_head.assert_called()