Fix offload disk for loading derivated model checkpoint into base model (#27253)

* fix

* style

* add test
This commit is contained in:
Marc Sun 2023-11-15 20:58:08 +01:00 committed by GitHub
parent b71c38a094
commit 1ac599d90f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 54 additions and 8 deletions

View File

@ -3793,8 +3793,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else:
folder = None
if device_map is not None and is_safetensors:
param_device_map = expand_device_map(device_map, original_loaded_keys)
param_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
if sharded_metadata is None:
archive_file = (
@ -3806,9 +3805,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else:
weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
offload_index = {
p: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
for p, f in weight_map.items()
if param_device_map[p] == "disk"
if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
}
if state_dict is not None:
@ -3842,7 +3841,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
state_dict_index = None
if is_sharded_safetensors:
disk_only_shard_files = get_disk_only_shard_files(device_map, sharded_metadata=sharded_metadata)
disk_only_shard_files = get_disk_only_shard_files(
device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix
)
disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
else:
disk_only_shard_files = []
@ -4576,11 +4577,12 @@ def unwrap_model(model: nn.Module) -> nn.Module:
return model
def expand_device_map(device_map, param_names):
def expand_device_map(device_map, param_names, start_prefix):
"""
Expand a device map to return the correspondance parameter name to device.
"""
new_device_map = {}
param_names = [p[len(start_prefix) :] for p in param_names if p.startswith(start_prefix)]
for module, device in device_map.items():
new_device_map.update(
{p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
@ -4588,12 +4590,16 @@ def expand_device_map(device_map, param_names):
return new_device_map
def get_disk_only_shard_files(device_map, sharded_metadata):
def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):
"""
Returns the list of shard files containing only weights offloaded to disk.
"""
weight_map = {
p[len(start_prefix) :]: v for p, v in sharded_metadata["weight_map"].items() if p.startswith(start_prefix)
}
files_content = collections.defaultdict(list)
for weight_name, filename in sharded_metadata["weight_map"].items():
for weight_name, filename in weight_map.items():
while len(weight_name) > 0 and weight_name not in device_map:
weight_name = ".".join(weight_name.split(".")[:-1])
files_content[filename].append(device_map[weight_name])

View File

@ -750,6 +750,46 @@ class ModelUtilsTest(TestCasePlus):
self.assertTrue(torch.allclose(outputs1.logits.cpu(), outputs2.logits.cpu()))
@require_accelerate
@mark.accelerate_tests
@require_torch_accelerator
def test_from_pretrained_disk_offload_derived_to_base_model(self):
derived_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
device_map = {
"wte": 0,
"wpe": 0,
"h.0": "cpu",
"h.1": "cpu",
"h.2": "cpu",
"h.3": "disk",
"h.4": "disk",
"ln_f": 0,
}
with tempfile.TemporaryDirectory() as tmp_dir:
inputs = torch.tensor([[1, 2, 3]]).to(0)
derived_model.save_pretrained(tmp_dir, use_safetensors=True)
base_model = AutoModel.from_pretrained(tmp_dir)
outputs1 = base_model.to(0)(inputs)
# with disk offload
offload_folder = os.path.join(tmp_dir, "offload")
base_model_with_offload = AutoModel.from_pretrained(
tmp_dir, device_map=device_map, offload_folder=offload_folder
)
outputs2 = base_model_with_offload(inputs)
self.assertTrue(torch.allclose(outputs1[0].cpu(), outputs2[0].cpu()))
# With state dict temp offload
new_model_with_offload = AutoModel.from_pretrained(
tmp_dir,
device_map=device_map,
offload_folder=offload_folder,
offload_state_dict=True,
)
outputs2 = new_model_with_offload(inputs)
self.assertTrue(torch.allclose(outputs1[0].cpu(), outputs2[0].cpu()))
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()