[From pretrained] Allow download from subfolder inside model repo (#18184)

* add first generation tutorial

* [from_pretrained] Allow loading models from subfolders

* remove gen file

* add doc strings

* allow download from subfolder

* add tests

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* apply comments

* correct doc string

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Patrick von Platen 2022-07-19 11:53:53 +02:00 committed by GitHub
parent ce0152819d
commit 3bb6356d4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 137 additions and 23 deletions

View File

@ -494,6 +494,9 @@ class PretrainedConfig(PushToHubMixin):
If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
part of `kwargs` which has not been used to update `config` and is otherwise ignored.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
kwargs (`Dict[str, Any]`, *optional*):
The values in kwargs of any keys which are configuration attributes will be used to override the loaded
values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
@ -577,6 +580,7 @@ class PretrainedConfig(PushToHubMixin):
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", "")
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
@ -589,16 +593,22 @@ class PretrainedConfig(PushToHubMixin):
local_files_only = True
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)) or is_remote_url(
pretrained_model_name_or_path
):
config_file = pretrained_model_name_or_path
else:
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
if os.path.isdir(pretrained_model_name_or_path):
config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
if os.path.isdir(os.path.join(pretrained_model_name_or_path, subfolder)):
config_file = os.path.join(pretrained_model_name_or_path, subfolder, configuration_file)
else:
config_file = hf_bucket_url(
pretrained_model_name_or_path, filename=configuration_file, revision=revision, mirror=None
pretrained_model_name_or_path,
filename=configuration_file,
revision=revision,
subfolder=subfolder if len(subfolder) > 0 else None,
mirror=None,
)
try:

View File

@ -1691,6 +1691,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
`True` when there is some disk offload.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
@ -1777,6 +1780,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
max_memory = kwargs.pop("max_memory", None)
offload_folder = kwargs.pop("offload_folder", None)
offload_state_dict = kwargs.pop("offload_state_dict", None)
subfolder = kwargs.pop("subfolder", "")
if device_map is not None:
if low_cpu_mem_usage is None:
@ -1820,6 +1824,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
**kwargs,
@ -1837,32 +1842,38 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if pretrained_model_name_or_path is not None:
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
if from_tf and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
):
# Load from a TF 1.0 checkpoint in priority if from_tf
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
elif from_tf and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)
):
# Load from a TF 2.0 checkpoint in priority if from_tf
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
elif from_flax and os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)
elif from_flax and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
):
# Load from a Flax checkpoint in priority if from_flax
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)):
# Load from a sharded PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
is_sharded = True
# At this stage we don't have a weight file so we will raise an error.
elif os.path.isfile(
os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
) or os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
) or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)):
raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} but "
"there is a file for TensorFlow weights. Use `from_tf=True` to load this model from those "
"weights."
)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} but "
"there is a file for Flax weights. Use `from_flax=True` to load this model from those "
@ -1873,15 +1884,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
f"Error no file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or "
f"{FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)) or is_remote_url(
pretrained_model_name_or_path
):
archive_file = pretrained_model_name_or_path
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")):
if not from_tf:
raise ValueError(
f"We found a TensorFlow checkpoint at {pretrained_model_name_or_path + '.index'}, please set "
"from_tf to True to load from this checkpoint."
)
archive_file = pretrained_model_name_or_path + ".index"
archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index")
else:
# set correct filename
if from_tf:
@ -1892,7 +1905,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
filename = WEIGHTS_NAME
archive_file = hf_bucket_url(
pretrained_model_name_or_path, filename=filename, revision=revision, mirror=mirror
pretrained_model_name_or_path,
filename=filename,
revision=revision,
mirror=mirror,
subfolder=subfolder if len(subfolder) > 0 else None,
)
try:
@ -1930,6 +1947,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
filename=WEIGHTS_INDEX_NAME,
revision=revision,
mirror=mirror,
subfolder=subfolder if len(subfolder) > 0 else None,
)
resolved_archive_file = cached_path(
archive_file,
@ -2016,6 +2034,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
user_agent=user_agent,
revision=revision,
mirror=mirror,
subfolder=subfolder,
)
# load pt weights early so that we know which dtype to init the model under

View File

@ -1142,6 +1142,7 @@ def get_checkpoint_shard_files(
user_agent=None,
revision=None,
mirror=None,
subfolder="",
):
"""
For a given model:
@ -1167,14 +1168,18 @@ def get_checkpoint_shard_files(
# First, let's deal with local folder.
if os.path.isdir(pretrained_model_name_or_path):
shard_filenames = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames]
shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f) for f in shard_filenames]
return shard_filenames, sharded_metadata
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
cached_filenames = []
for shard_filename in shard_filenames:
shard_url = hf_bucket_url(
pretrained_model_name_or_path, filename=shard_filename, revision=revision, mirror=mirror
pretrained_model_name_or_path,
filename=shard_filename,
revision=revision,
mirror=mirror,
subfolder=subfolder if len(subfolder) > 0 else None,
)
try:

View File

@ -157,6 +157,17 @@ class ConfigTester(object):
self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
def create_and_test_config_from_and_save_pretrained_subfolder(self):
config_first = self.config_class(**self.inputs_dict)
subfolder = "test"
with tempfile.TemporaryDirectory() as tmpdirname:
sub_tmpdirname = os.path.join(tmpdirname, subfolder)
config_first.save_pretrained(sub_tmpdirname)
config_second = self.config_class.from_pretrained(tmpdirname, subfolder=subfolder)
self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
def create_and_test_config_with_num_labels(self):
config = self.config_class(**self.inputs_dict, num_labels=5)
self.parent.assertEqual(len(config.id2label), 5)
@ -197,6 +208,7 @@ class ConfigTester(object):
self.create_and_test_config_to_json_string()
self.create_and_test_config_to_json_file()
self.create_and_test_config_from_and_save_pretrained()
self.create_and_test_config_from_and_save_pretrained_subfolder()
self.create_and_test_config_with_num_labels()
self.check_config_can_be_init_without_params()
self.check_config_arguments_init()
@ -308,6 +320,15 @@ class ConfigTestUtils(unittest.TestCase):
f" {', '.join(keys_with_defaults)}."
)
def test_from_pretrained_subfolder(self):
with self.assertRaises(OSError):
# config is in subfolder, the following should not work without specifying the subfolder
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder")
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder", subfolder="bert")
self.assertIsNotNone(config)
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()

View File

@ -2503,6 +2503,15 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
return torch.tensor(data=values, dtype=torch.float, device=torch_device).view(shape).contiguous()
def check_models_equal(model1, model2):
models_are_equal = True
for model1_p, model2_p in zip(model1.parameters(), model2.parameters()):
if model1_p.data.ne(model2_p.data).sum() > 0:
models_are_equal = False
return models_are_equal
@require_torch
class ModelUtilsTest(TestCasePlus):
@slow
@ -2531,6 +2540,56 @@ class ModelUtilsTest(TestCasePlus):
self.assertEqual(model.config.output_hidden_states, True)
self.assertEqual(model.config, config)
def test_model_from_pretrained_subfolder(self):
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
model = BertModel(config)
subfolder = "bert"
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(os.path.join(tmp_dir, subfolder))
with self.assertRaises(OSError):
_ = BertModel.from_pretrained(tmp_dir)
model_loaded = BertModel.from_pretrained(tmp_dir, subfolder=subfolder)
self.assertTrue(check_models_equal(model, model_loaded))
def test_model_from_pretrained_subfolder_sharded(self):
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
model = BertModel(config)
subfolder = "bert"
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(os.path.join(tmp_dir, subfolder), max_shard_size="10KB")
with self.assertRaises(OSError):
_ = BertModel.from_pretrained(tmp_dir)
model_loaded = BertModel.from_pretrained(tmp_dir, subfolder=subfolder)
self.assertTrue(check_models_equal(model, model_loaded))
def test_model_from_pretrained_hub_subfolder(self):
subfolder = "bert"
model_id = "hf-internal-testing/tiny-random-bert-subfolder"
with self.assertRaises(OSError):
_ = BertModel.from_pretrained(model_id)
model = BertModel.from_pretrained(model_id, subfolder=subfolder)
self.assertIsNotNone(model)
def test_model_from_pretrained_hub_subfolder_sharded(self):
subfolder = "bert"
model_id = "hf-internal-testing/tiny-random-bert-sharded-subfolder"
with self.assertRaises(OSError):
_ = BertModel.from_pretrained(model_id)
model = BertModel.from_pretrained(model_id, subfolder=subfolder)
self.assertIsNotNone(model)
def test_model_from_pretrained_with_different_pretrained_model_name(self):
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
self.assertIsNotNone(model)