Fix offload disk for loading derivated model checkpoint into base model (#27253)
* fix * style * add test
This commit is contained in:
parent
b71c38a094
commit
1ac599d90f
|
@ -3793,8 +3793,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
else:
|
||||
folder = None
|
||||
if device_map is not None and is_safetensors:
|
||||
param_device_map = expand_device_map(device_map, original_loaded_keys)
|
||||
|
||||
param_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
|
||||
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
|
||||
if sharded_metadata is None:
|
||||
archive_file = (
|
||||
|
@ -3806,9 +3805,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
else:
|
||||
weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
|
||||
offload_index = {
|
||||
p: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
|
||||
p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
|
||||
for p, f in weight_map.items()
|
||||
if param_device_map[p] == "disk"
|
||||
if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
|
||||
}
|
||||
|
||||
if state_dict is not None:
|
||||
|
@ -3842,7 +3841,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
state_dict_index = None
|
||||
|
||||
if is_sharded_safetensors:
|
||||
disk_only_shard_files = get_disk_only_shard_files(device_map, sharded_metadata=sharded_metadata)
|
||||
disk_only_shard_files = get_disk_only_shard_files(
|
||||
device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix
|
||||
)
|
||||
disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
|
||||
else:
|
||||
disk_only_shard_files = []
|
||||
|
@ -4576,11 +4577,12 @@ def unwrap_model(model: nn.Module) -> nn.Module:
|
|||
return model
|
||||
|
||||
|
||||
def expand_device_map(device_map, param_names):
|
||||
def expand_device_map(device_map, param_names, start_prefix):
|
||||
"""
|
||||
Expand a device map to return the correspondance parameter name to device.
|
||||
"""
|
||||
new_device_map = {}
|
||||
param_names = [p[len(start_prefix) :] for p in param_names if p.startswith(start_prefix)]
|
||||
for module, device in device_map.items():
|
||||
new_device_map.update(
|
||||
{p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
|
||||
|
@ -4588,12 +4590,16 @@ def expand_device_map(device_map, param_names):
|
|||
return new_device_map
|
||||
|
||||
|
||||
def get_disk_only_shard_files(device_map, sharded_metadata):
|
||||
def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):
|
||||
"""
|
||||
Returns the list of shard files containing only weights offloaded to disk.
|
||||
"""
|
||||
|
||||
weight_map = {
|
||||
p[len(start_prefix) :]: v for p, v in sharded_metadata["weight_map"].items() if p.startswith(start_prefix)
|
||||
}
|
||||
files_content = collections.defaultdict(list)
|
||||
for weight_name, filename in sharded_metadata["weight_map"].items():
|
||||
for weight_name, filename in weight_map.items():
|
||||
while len(weight_name) > 0 and weight_name not in device_map:
|
||||
weight_name = ".".join(weight_name.split(".")[:-1])
|
||||
files_content[filename].append(device_map[weight_name])
|
||||
|
|
|
@ -750,6 +750,46 @@ class ModelUtilsTest(TestCasePlus):
|
|||
|
||||
self.assertTrue(torch.allclose(outputs1.logits.cpu(), outputs2.logits.cpu()))
|
||||
|
||||
@require_accelerate
|
||||
@mark.accelerate_tests
|
||||
@require_torch_accelerator
|
||||
def test_from_pretrained_disk_offload_derived_to_base_model(self):
|
||||
derived_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
|
||||
device_map = {
|
||||
"wte": 0,
|
||||
"wpe": 0,
|
||||
"h.0": "cpu",
|
||||
"h.1": "cpu",
|
||||
"h.2": "cpu",
|
||||
"h.3": "disk",
|
||||
"h.4": "disk",
|
||||
"ln_f": 0,
|
||||
}
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
inputs = torch.tensor([[1, 2, 3]]).to(0)
|
||||
derived_model.save_pretrained(tmp_dir, use_safetensors=True)
|
||||
base_model = AutoModel.from_pretrained(tmp_dir)
|
||||
outputs1 = base_model.to(0)(inputs)
|
||||
|
||||
# with disk offload
|
||||
offload_folder = os.path.join(tmp_dir, "offload")
|
||||
base_model_with_offload = AutoModel.from_pretrained(
|
||||
tmp_dir, device_map=device_map, offload_folder=offload_folder
|
||||
)
|
||||
outputs2 = base_model_with_offload(inputs)
|
||||
self.assertTrue(torch.allclose(outputs1[0].cpu(), outputs2[0].cpu()))
|
||||
|
||||
# With state dict temp offload
|
||||
new_model_with_offload = AutoModel.from_pretrained(
|
||||
tmp_dir,
|
||||
device_map=device_map,
|
||||
offload_folder=offload_folder,
|
||||
offload_state_dict=True,
|
||||
)
|
||||
outputs2 = new_model_with_offload(inputs)
|
||||
self.assertTrue(torch.allclose(outputs1[0].cpu(), outputs2[0].cpu()))
|
||||
|
||||
def test_cached_files_are_used_when_internet_is_down(self):
|
||||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
|
|
Loading…
Reference in New Issue