[WIP] Hard error when ignoring tensors. (#27484)

* [WIP] Hard error when ignoring tensors.

* Better selection/error when saving a checkpoint.

- Find all names we should normally drop (those are in the transformers
  config)
- Find all disjoint tensors (for those we can safely trigger a copy to
  get rid of the sharing before saving)
- Clone those disjoint tensors getting rid of the issue
- Find all identical names (those should be declared in the config
  but we try to find them all anyway.)
- For all identical names:
  - If they are in the config, just ignore them everything is fine
  - If they are not, warn about them.
- For all remainder tensors which are shared yet neither identical NOR
  disjoint. raise a hard error.

* Adding a failing test on `main` that passes here.

* We don't need to keep the subfolder logic in this test.

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Nicolas Patry 2024-02-05 09:17:24 +01:00 committed by GitHub
parent 0466fd5ca2
commit 2da28c4b41
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 112 additions and 14 deletions

View File

@ -29,7 +29,7 @@ import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial, wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from zipfile import is_zipfile
import torch
@ -570,6 +570,65 @@ def set_initialized_submodules(model, state_dict_keys):
return not_initialized_submodules
def _end_ptr(tensor: torch.Tensor) -> int:
# extract the end of the pointer if the tensor is a slice of a bigger tensor
if tensor.nelement():
stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
else:
stop = tensor.data_ptr()
return stop
def _find_disjoint(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
filtered_tensors = []
for shared in tensors:
if len(shared) < 2:
filtered_tensors.append(shared)
continue
areas = []
for name in shared:
tensor = state_dict[name]
areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
areas.sort()
_, last_stop, last_name = areas[0]
filtered_tensors.append({last_name})
for start, stop, name in areas[1:]:
if start >= last_stop:
filtered_tensors.append({name})
else:
filtered_tensors[-1].add(name)
last_stop = stop
disjoint_tensors = []
shared_tensors = []
for tensors in filtered_tensors:
if len(tensors) == 1:
disjoint_tensors.append(tensors.pop())
else:
shared_tensors.append(tensors)
return shared_tensors, disjoint_tensors
def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
shared_tensors = []
identical = []
for shared in tensors:
if len(shared) < 2:
continue
areas = collections.defaultdict(set)
for name in shared:
tensor = state_dict[name]
area = (tensor.device, tensor.data_ptr(), _end_ptr(tensor))
areas[area].add(name)
if len(areas) == 1:
identical.append(shared)
else:
shared_tensors.append(shared)
return shared_tensors, identical
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
@ -2382,6 +2441,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# These are all the pointers of shared tensors.
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
warn_names = set()
error_names = set()
to_delete_names = set()
for names in shared_ptrs.values():
# Removing the keys which are declared as known duplicates on
# load. This allows to make sure the name which is kept is consistent.
@ -2392,25 +2453,42 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if matches_pattern and name in state_dict:
found += 1
if found < len(names):
del state_dict[name]
to_delete_names.add(name)
# We are entering a place where the weights and the transformers configuration do NOT match.
shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
# Those are actually tensor sharing but disjoint from each other, we can safely clone them
# Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
for name in disjoint_names:
state_dict[name] = state_dict[name].clone()
# When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
# If the link between tensors was done at runtime then `from_pretrained` will not get
# the key back leading to random tensor. A proper warning will be shown
# during reload (if applicable), but since the file is not necessarily compatible with
# the config, better show a proper warning.
shared_names, identical_names = _find_identical(shared_names, state_dict)
# delete tensors that have identical storage
for inames in identical_names:
known = inames.intersection(to_delete_names)
for name in known:
del state_dict[name]
unknown = sorted(inames.difference(to_delete_names))
for name in unknown[1:]:
del state_dict[name]
warn_names.add(name)
error_names.update(shared_names)
# When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
# If the link between tensors was done at runtime then `from_pretrained` will not get
# the key back leading to random tensor. A proper warning will be shown
# during reload (if applicable), but since the file is not necessarily compatible with
# the config, better show a proper warning.
found = 0
for name in names:
if name in state_dict:
found += 1
if found > 1:
del state_dict[name]
warn_names.add(name)
if len(warn_names) > 0:
logger.warning_once(
f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
)
if len(error_names) > 0:
raise RuntimeError(
f"The weights trying to be saved contained shared tensors {error_names} that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.",
)
# Shard the model if it is too big.
if not _hf_peft_config_loaded:
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME

View File

@ -257,6 +257,26 @@ class ModelUtilsTest(TestCasePlus):
self.assertTrue(check_models_equal(model, model_loaded))
def test_model_manually_shared_disjointed_tensors_optimum(self):
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
model = BertModel(config)
# Let's fuse qkv
attn = model.encoder.layer[0].attention.self
q = attn.query.weight
k = attn.key.weight
v = attn.value.weight
# Force some shared storage
qkv = torch.stack([q, k, v], dim=0)
attn.query.weight = torch.nn.Parameter(qkv[0])
attn.key.weight = torch.nn.Parameter(qkv[1])
attn.value.weight = torch.nn.Parameter(qkv[2])
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
model_loaded = BertModel.from_pretrained(tmp_dir)
self.assertTrue(check_models_equal(model, model_loaded))
def test_model_from_pretrained_subfolder_sharded(self):
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
model = BertModel(config)