[From pretrained] Allow download from subfolder inside model repo (#18184)
* add first generation tutorial * [from_pretrained] Allow loading models from subfolders * remove gen file * add doc strings * allow download from subfolder * add tests * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * apply comments * correct doc string Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
ce0152819d
commit
3bb6356d4d
|
@ -494,6 +494,9 @@ class PretrainedConfig(PushToHubMixin):
|
|||
If `True`, then this function returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
|
||||
dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
|
||||
part of `kwargs` which has not been used to update `config` and is otherwise ignored.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
|
||||
specify the folder name here.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
The values in kwargs of any keys which are configuration attributes will be used to override the loaded
|
||||
values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
|
||||
|
@ -577,6 +580,7 @@ class PretrainedConfig(PushToHubMixin):
|
|||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", "")
|
||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||
from_auto_class = kwargs.pop("_from_auto", False)
|
||||
|
||||
|
@ -589,16 +593,22 @@ class PretrainedConfig(PushToHubMixin):
|
|||
local_files_only = True
|
||||
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)) or is_remote_url(
|
||||
pretrained_model_name_or_path
|
||||
):
|
||||
config_file = pretrained_model_name_or_path
|
||||
else:
|
||||
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
|
||||
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
|
||||
if os.path.isdir(os.path.join(pretrained_model_name_or_path, subfolder)):
|
||||
config_file = os.path.join(pretrained_model_name_or_path, subfolder, configuration_file)
|
||||
else:
|
||||
config_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path, filename=configuration_file, revision=revision, mirror=None
|
||||
pretrained_model_name_or_path,
|
||||
filename=configuration_file,
|
||||
revision=revision,
|
||||
subfolder=subfolder if len(subfolder) > 0 else None,
|
||||
mirror=None,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
|
@ -1691,6 +1691,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
|
||||
RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
|
||||
`True` when there is some disk offload.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
|
||||
specify the folder name here.
|
||||
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||
|
@ -1777,6 +1780,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
max_memory = kwargs.pop("max_memory", None)
|
||||
offload_folder = kwargs.pop("offload_folder", None)
|
||||
offload_state_dict = kwargs.pop("offload_state_dict", None)
|
||||
subfolder = kwargs.pop("subfolder", "")
|
||||
|
||||
if device_map is not None:
|
||||
if low_cpu_mem_usage is None:
|
||||
|
@ -1820,6 +1824,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_from_auto=from_auto_class,
|
||||
_from_pipeline=from_pipeline,
|
||||
**kwargs,
|
||||
|
@ -1837,32 +1842,38 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
if pretrained_model_name_or_path is not None:
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
|
||||
if from_tf and os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
|
||||
):
|
||||
# Load from a TF 1.0 checkpoint in priority if from_tf
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
|
||||
elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
|
||||
elif from_tf and os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)
|
||||
):
|
||||
# Load from a TF 2.0 checkpoint in priority if from_tf
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
|
||||
elif from_flax and os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)
|
||||
elif from_flax and os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
|
||||
):
|
||||
# Load from a Flax checkpoint in priority if from_flax
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
|
||||
# Load from a PyTorch checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)):
|
||||
# Load from a sharded PyTorch checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
|
||||
is_sharded = True
|
||||
# At this stage we don't have a weight file so we will raise an error.
|
||||
elif os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
|
||||
) or os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
|
||||
) or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)):
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} but "
|
||||
"there is a file for TensorFlow weights. Use `from_tf=True` to load this model from those "
|
||||
"weights."
|
||||
)
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} but "
|
||||
"there is a file for Flax weights. Use `from_flax=True` to load this model from those "
|
||||
|
@ -1873,15 +1884,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
f"Error no file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or "
|
||||
f"{FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
|
||||
)
|
||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)) or is_remote_url(
|
||||
pretrained_model_name_or_path
|
||||
):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
||||
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")):
|
||||
if not from_tf:
|
||||
raise ValueError(
|
||||
f"We found a TensorFlow checkpoint at {pretrained_model_name_or_path + '.index'}, please set "
|
||||
"from_tf to True to load from this checkpoint."
|
||||
)
|
||||
archive_file = pretrained_model_name_or_path + ".index"
|
||||
archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index")
|
||||
else:
|
||||
# set correct filename
|
||||
if from_tf:
|
||||
|
@ -1892,7 +1905,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
filename = WEIGHTS_NAME
|
||||
|
||||
archive_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path, filename=filename, revision=revision, mirror=mirror
|
||||
pretrained_model_name_or_path,
|
||||
filename=filename,
|
||||
revision=revision,
|
||||
mirror=mirror,
|
||||
subfolder=subfolder if len(subfolder) > 0 else None,
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -1930,6 +1947,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
filename=WEIGHTS_INDEX_NAME,
|
||||
revision=revision,
|
||||
mirror=mirror,
|
||||
subfolder=subfolder if len(subfolder) > 0 else None,
|
||||
)
|
||||
resolved_archive_file = cached_path(
|
||||
archive_file,
|
||||
|
@ -2016,6 +2034,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
mirror=mirror,
|
||||
subfolder=subfolder,
|
||||
)
|
||||
|
||||
# load pt weights early so that we know which dtype to init the model under
|
||||
|
|
|
@ -1142,6 +1142,7 @@ def get_checkpoint_shard_files(
|
|||
user_agent=None,
|
||||
revision=None,
|
||||
mirror=None,
|
||||
subfolder="",
|
||||
):
|
||||
"""
|
||||
For a given model:
|
||||
|
@ -1167,14 +1168,18 @@ def get_checkpoint_shard_files(
|
|||
|
||||
# First, let's deal with local folder.
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
shard_filenames = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames]
|
||||
shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f) for f in shard_filenames]
|
||||
return shard_filenames, sharded_metadata
|
||||
|
||||
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
|
||||
cached_filenames = []
|
||||
for shard_filename in shard_filenames:
|
||||
shard_url = hf_bucket_url(
|
||||
pretrained_model_name_or_path, filename=shard_filename, revision=revision, mirror=mirror
|
||||
pretrained_model_name_or_path,
|
||||
filename=shard_filename,
|
||||
revision=revision,
|
||||
mirror=mirror,
|
||||
subfolder=subfolder if len(subfolder) > 0 else None,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
|
@ -157,6 +157,17 @@ class ConfigTester(object):
|
|||
|
||||
self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
|
||||
|
||||
def create_and_test_config_from_and_save_pretrained_subfolder(self):
    """Round-trip a config through `save_pretrained` / `from_pretrained` via a subfolder.

    The config is written to `<tmp>/test` and read back by pointing `from_pretrained` at `<tmp>` together with
    `subfolder="test"`; the reloaded config must serialize to the same dict as the one that was saved.
    """
    subfolder = "test"
    original_config = self.config_class(**self.inputs_dict)

    with tempfile.TemporaryDirectory() as root_dir:
        original_config.save_pretrained(os.path.join(root_dir, subfolder))
        reloaded_config = self.config_class.from_pretrained(root_dir, subfolder=subfolder)

    self.parent.assertEqual(reloaded_config.to_dict(), original_config.to_dict())
|
||||
|
||||
def create_and_test_config_with_num_labels(self):
|
||||
config = self.config_class(**self.inputs_dict, num_labels=5)
|
||||
self.parent.assertEqual(len(config.id2label), 5)
|
||||
|
@ -197,6 +208,7 @@ class ConfigTester(object):
|
|||
self.create_and_test_config_to_json_string()
|
||||
self.create_and_test_config_to_json_file()
|
||||
self.create_and_test_config_from_and_save_pretrained()
|
||||
self.create_and_test_config_from_and_save_pretrained_subfolder()
|
||||
self.create_and_test_config_with_num_labels()
|
||||
self.check_config_can_be_init_without_params()
|
||||
self.check_config_arguments_init()
|
||||
|
@ -308,6 +320,15 @@ class ConfigTestUtils(unittest.TestCase):
|
|||
f" {', '.join(keys_with_defaults)}."
|
||||
)
|
||||
|
||||
def test_from_pretrained_subfolder(self):
    """A config stored in a subfolder of a Hub repo loads only when `subfolder` is specified."""
    repo_id = "hf-internal-testing/tiny-random-bert-subfolder"

    # The config lives in a subfolder, so loading without naming that subfolder should not work.
    self.assertRaises(OSError, BertConfig.from_pretrained, repo_id)

    loaded_config = BertConfig.from_pretrained(repo_id, subfolder="bert")
    self.assertIsNotNone(loaded_config)
|
||||
|
||||
def test_cached_files_are_used_when_internet_is_down(self):
|
||||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
|
|
|
@ -2503,6 +2503,15 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
|
|||
return torch.tensor(data=values, dtype=torch.float, device=torch_device).view(shape).contiguous()
|
||||
|
||||
|
||||
def check_models_equal(model1, model2):
    """
    Return `True` when both models hold identical parameters, `False` otherwise.

    Parameters are compared pairwise, in the order yielded by `parameters()`. The models are reported as different
    when they expose a different number of parameters, or when any pair differs in shape or in value.

    Args:
        model1: First model; only its `parameters()` method is used.
        model2: Second model; only its `parameters()` method is used.

    Returns:
        `bool`: Whether every parameter of `model1` equals the matching parameter of `model2`.
    """
    params1 = list(model1.parameters())
    params2 = list(model2.parameters())

    # `zip` stops at the shorter sequence, so a model whose parameters are a strict prefix of the other's
    # would otherwise be reported as equal.
    if len(params1) != len(params2):
        return False

    # `torch.equal` yields False on a size mismatch (no broadcasting), and `all` stops at the first difference.
    return all(torch.equal(p1.data, p2.data) for p1, p2 in zip(params1, params2))
|
||||
|
||||
|
||||
@require_torch
|
||||
class ModelUtilsTest(TestCasePlus):
|
||||
@slow
|
||||
|
@ -2531,6 +2540,56 @@ class ModelUtilsTest(TestCasePlus):
|
|||
self.assertEqual(model.config.output_hidden_states, True)
|
||||
self.assertEqual(model.config, config)
|
||||
|
||||
def test_model_from_pretrained_subfolder(self):
    """A model saved into a subfolder of a local directory loads only when `subfolder` is passed."""
    subfolder = "bert"
    reference_model = BertModel(BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert"))

    with tempfile.TemporaryDirectory() as root_dir:
        reference_model.save_pretrained(os.path.join(root_dir, subfolder))

        # The model was saved under `subfolder` only, so loading from the root must raise.
        self.assertRaises(OSError, BertModel.from_pretrained, root_dir)

        reloaded_model = BertModel.from_pretrained(root_dir, subfolder=subfolder)

    self.assertTrue(check_models_equal(reference_model, reloaded_model))
|
||||
|
||||
def test_model_from_pretrained_subfolder_sharded(self):
    """Same as the non-sharded subfolder test, but saving with `max_shard_size="10KB"`."""
    subfolder = "bert"
    reference_model = BertModel(BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert"))

    with tempfile.TemporaryDirectory() as root_dir:
        reference_model.save_pretrained(os.path.join(root_dir, subfolder), max_shard_size="10KB")

        # The model was saved under `subfolder` only, so loading from the root must raise.
        self.assertRaises(OSError, BertModel.from_pretrained, root_dir)

        reloaded_model = BertModel.from_pretrained(root_dir, subfolder=subfolder)

    self.assertTrue(check_models_equal(reference_model, reloaded_model))
|
||||
|
||||
def test_model_from_pretrained_hub_subfolder(self):
    """Loading from the Hub test repo must fail without `subfolder` and succeed with `subfolder="bert"`."""
    repo_id = "hf-internal-testing/tiny-random-bert-subfolder"

    # Loading without `subfolder` is expected to raise.
    self.assertRaises(OSError, BertModel.from_pretrained, repo_id)

    loaded_model = BertModel.from_pretrained(repo_id, subfolder="bert")
    self.assertIsNotNone(loaded_model)
|
||||
|
||||
def test_model_from_pretrained_hub_subfolder_sharded(self):
    """Loading from the sharded Hub test repo must fail without `subfolder` and succeed with `subfolder="bert"`."""
    repo_id = "hf-internal-testing/tiny-random-bert-sharded-subfolder"

    # Loading without `subfolder` is expected to raise.
    self.assertRaises(OSError, BertModel.from_pretrained, repo_id)

    loaded_model = BertModel.from_pretrained(repo_id, subfolder="bert")
    self.assertIsNotNone(loaded_model)
|
||||
|
||||
def test_model_from_pretrained_with_different_pretrained_model_name(self):
|
||||
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
|
||||
self.assertIsNotNone(model)
|
||||
|
|
Loading…
Reference in New Issue