[WIP] Hard error when ignoring tensors. (#27484)
* [WIP] Hard error when ignoring tensors.

* Better selection/error when saving a checkpoint.

  - Find all names we should normally drop (those are in the transformers config).
  - Find all disjoint tensors (for those we can safely trigger a copy to get rid of the sharing before saving).
  - Clone those disjoint tensors, getting rid of the issue.
  - Find all identical names (those should be declared in the config, but we try to find them all anyway).
  - For all identical names:
    - If they are in the config, just ignore them; everything is fine.
    - If they are not, warn about them.
  - For all remaining tensors which are shared yet neither identical nor disjoint, raise a hard error.

* Adding a failing test on `main` that passes here.

* We don't need to keep the subfolder logic in this test.

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in: parent 0466fd5ca2, commit 2da28c4b41.
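The diff below changes how `save_pretrained` handles tensors that share memory: disjoint slices of a common buffer are cloned apart, undeclared-but-identical aliases are dropped with a warning, and any remaining overlap raises a hard error (safetensors refuses to serialize aliased tensors, so sharing has to be resolved before saving). For background, a minimal sketch of how sharing shows up at the pointer level, in plain PyTorch:

```python
import torch

a = torch.zeros(10)
b = a[:5]       # a view: shares `a`'s underlying storage
c = a.clone()   # a copy: owns its own storage

print(b.data_ptr() == a.data_ptr())  # True  -> `b` starts at `a`'s address
print(c.data_ptr() == a.data_ptr())  # False -> independent memory
```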
```diff
@@ -29,7 +29,7 @@ import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial, wraps
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 from zipfile import is_zipfile
 
 import torch
```
```diff
@@ -570,6 +570,65 @@ def set_initialized_submodules(model, state_dict_keys):
     return not_initialized_submodules
 
 
+def _end_ptr(tensor: torch.Tensor) -> int:
+    # extract the end of the pointer if the tensor is a slice of a bigger tensor
+    if tensor.nelement():
+        stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
+    else:
+        stop = tensor.data_ptr()
+    return stop
```
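`_end_ptr` computes the address one past a tensor's last byte, so every tensor can be treated as a half-open address interval `[data_ptr, end_ptr)`. A small sketch of the arithmetic, assuming a contiguous CPU tensor:

```python
import torch

base = torch.arange(6, dtype=torch.float32)  # one storage, 6 * 4 bytes
first, second = base[:3], base[3:]           # two adjacent, non-overlapping views

end_of_first = first.view(-1)[-1].data_ptr() + first.element_size()
print(end_of_first - first.data_ptr())    # 12 (3 elements * 4 bytes each)
print(end_of_first == second.data_ptr())  # True: the intervals touch but don't overlap
```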
```diff
+def _find_disjoint(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
+    filtered_tensors = []
+    for shared in tensors:
+        if len(shared) < 2:
+            filtered_tensors.append(shared)
+            continue
+
+        areas = []
+        for name in shared:
+            tensor = state_dict[name]
+            areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
+        areas.sort()
+
+        _, last_stop, last_name = areas[0]
+        filtered_tensors.append({last_name})
+        for start, stop, name in areas[1:]:
+            if start >= last_stop:
+                filtered_tensors.append({name})
+            else:
+                filtered_tensors[-1].add(name)
+            last_stop = stop
+    disjoint_tensors = []
+    shared_tensors = []
+    for tensors in filtered_tensors:
+        if len(tensors) == 1:
+            disjoint_tensors.append(tensors.pop())
+        else:
+            shared_tensors.append(tensors)
+    return shared_tensors, disjoint_tensors
```
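`_find_disjoint` sorts each group's `(start, end, name)` intervals and sweeps them left to right: a name whose interval overlaps nothing ends up in a singleton set and is reported as safely clonable. A hedged usage sketch (names and shapes are illustrative; assumes the helper above is in scope):

```python
import torch

qkv = torch.zeros(3, 4)  # one fused buffer
state_dict = {"q": qkv[0], "k": qkv[1], "v": qkv[2]}  # three non-overlapping rows

shared, disjoint = _find_disjoint([{"q", "k", "v"}], state_dict)
print(shared)            # []              -> nothing truly overlapping
print(sorted(disjoint))  # ['k', 'q', 'v'] -> each can be cloned before saving
```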
```diff
+def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
+    shared_tensors = []
+    identical = []
+    for shared in tensors:
+        if len(shared) < 2:
+            continue
+
+        areas = collections.defaultdict(set)
+        for name in shared:
+            tensor = state_dict[name]
+            area = (tensor.device, tensor.data_ptr(), _end_ptr(tensor))
+            areas[area].add(name)
+        if len(areas) == 1:
+            identical.append(shared)
+        else:
+            shared_tensors.append(shared)
+    return shared_tensors, identical
```
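`_find_identical` buckets names by the exact `(device, start, end)` triple; if a whole group collapses into one bucket, every name refers to the same bytes, which is the classic weight-tying case. A hedged sketch with illustrative names:

```python
import torch

weight = torch.zeros(10, 4)
state_dict = {"embeddings.weight": weight, "lm_head.weight": weight}  # tied

shared, identical = _find_identical([{"embeddings.weight", "lm_head.weight"}], state_dict)
print(identical)  # [{'embeddings.weight', 'lm_head.weight'}] -> duplicates to drop
print(shared)     # []                                        -> nothing to error on
```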
```diff
 def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
     # Convert old format to new format if needed from a PyTorch state_dict
     old_keys = []
```
```diff
@@ -2382,6 +2441,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         # These are all the pointers of shared tensors.
         shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
         warn_names = set()
+        error_names = set()
+        to_delete_names = set()
         for names in shared_ptrs.values():
             # Removing the keys which are declared as known duplicates on
             # load. This allows to make sure the name which is kept is consistent.
```
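The next hunk resumes inside this loop, after the elided context has matched each name against the model's declared `self._tied_weights_keys` patterns. The matching itself is ordinary `re.search`; a runnable, purely illustrative sketch with a made-up pattern:

```python
import re

tied_weights_keys = [r"lm_head\.weight"]  # hypothetical declared duplicate
names = {"model.embed_tokens.weight", "lm_head.weight"}

for name in sorted(names):
    matches_pattern = any(re.search(pat, name) for pat in tied_weights_keys)
    print(name, "->", "declared duplicate" if matches_pattern else "keep")
# lm_head.weight -> declared duplicate
# model.embed_tokens.weight -> keep
```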
```diff
@@ -2392,25 +2453,42 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                     if matches_pattern and name in state_dict:
                         found += 1
                         if found < len(names):
                             del state_dict[name]
+                            to_delete_names.add(name)
+        # We are entering a place where the weights and the transformers configuration do NOT match.
+        shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
+        # Those are actually tensor sharing but disjoint from each other, we can safely clone them
+        # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
+        for name in disjoint_names:
+            state_dict[name] = state_dict[name].clone()
+
+        # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
+        # If the link between tensors was done at runtime then `from_pretrained` will not get
+        # the key back leading to random tensor. A proper warning will be shown
+        # during reload (if applicable), but since the file is not necessarily compatible with
+        # the config, better show a proper warning.
+        shared_names, identical_names = _find_identical(shared_names, state_dict)
+        # delete tensors that have identical storage
+        for inames in identical_names:
+            known = inames.intersection(to_delete_names)
+            for name in known:
+                del state_dict[name]
+            unknown = sorted(inames.difference(to_delete_names))
+            for name in unknown[1:]:
+                del state_dict[name]
+                warn_names.add(name)
+
+        error_names.update(shared_names)
-            # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
-            # If the link between tensors was done at runtime then `from_pretrained` will not get
-            # the key back leading to random tensor. A proper warning will be shown
-            # during reload (if applicable), but since the file is not necessarily compatible with
-            # the config, better show a proper warning.
-            found = 0
-            for name in names:
-                if name in state_dict:
-                    found += 1
-                    if found > 1:
-                        del state_dict[name]
-                        warn_names.add(name)
         if len(warn_names) > 0:
             logger.warning_once(
                 f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
             )
+
+        if len(error_names) > 0:
+            raise RuntimeError(
+                f"The weights trying to be saved contained shared tensors {error_names} that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.",
+            )
 
         # Shard the model if it is too big.
         if not _hf_peft_config_loaded:
             weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
```
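Taken together, the save path now sorts shared tensors into three buckets: clonable (disjoint), droppable (identical), and everything else, which raises. A hedged end-to-end sketch of the error bucket (illustrative names; assumes the two helpers from the first hunk are in scope):

```python
import torch

big = torch.zeros(10)
state_dict = {"a": big[:6], "b": big[4:]}  # overlapping slices: bytes 16..24 are shared

shared, disjoint = _find_disjoint([{"a", "b"}], state_dict)
print(shared, disjoint)   # [{'a', 'b'}] [] -> cannot be cloned apart
shared, identical = _find_identical(shared, state_dict)
print(shared, identical)  # [{'a', 'b'}] [] -> not the same bytes either, so this
                          # group ends up in error_names and triggers the RuntimeError
```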
```diff
@@ -257,6 +257,26 @@ class ModelUtilsTest(TestCasePlus):
 
         self.assertTrue(check_models_equal(model, model_loaded))
 
+    def test_model_manually_shared_disjointed_tensors_optimum(self):
+        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
+        model = BertModel(config)
+
+        # Let's fuse qkv
+        attn = model.encoder.layer[0].attention.self
+        q = attn.query.weight
+        k = attn.key.weight
+        v = attn.value.weight
+        # Force some shared storage
+        qkv = torch.stack([q, k, v], dim=0)
+        attn.query.weight = torch.nn.Parameter(qkv[0])
+        attn.key.weight = torch.nn.Parameter(qkv[1])
+        attn.value.weight = torch.nn.Parameter(qkv[2])
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir)
+            model_loaded = BertModel.from_pretrained(tmp_dir)
+
+        self.assertTrue(check_models_equal(model, model_loaded))
+
     def test_model_from_pretrained_subfolder_sharded(self):
         config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
         model = BertModel(config)
```
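The new test builds exactly the disjoint case: `torch.stack` materializes one fused buffer, and the three parameters become non-overlapping views of it. A hedged check one could run inside the test, before and after the round-trip (PyTorch 2.x `untyped_storage`):

```python
attn = model.encoder.layer[0].attention.self
q, k = attn.query.weight, attn.key.weight
# Before saving: one storage behind q/k/v.
print(q.untyped_storage().data_ptr() == k.untyped_storage().data_ptr())  # True

attn2 = model_loaded.encoder.layer[0].attention.self
q2, k2 = attn2.query.weight, attn2.key.weight
# After reloading: saving cloned the disjoint slices, so storages now differ,
# while the values are unchanged (which is what check_models_equal asserts).
print(q2.untyped_storage().data_ptr() == k2.untyped_storage().data_ptr())  # False
```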