Sharding fails in TF when the absolute scope is modified by a `.` in a layer name (#19124)

* simplify loop

* fix layer map split

* update

* update for special variables

* add rag test

* fixup

* revert change: for next PR
Arthur 2022-10-14 18:34:33 +02:00 committed by GitHub
parent 614f7d28a8
commit 2bd2de62c9
2 changed files with 32 additions and 9 deletions


@@ -707,8 +707,15 @@ def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, s
     # Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load
     # the weight, we have to get rid of the first prefix of the name of the layer.
-    model_keys = set("/".join(k.name.split("/")[1:]) for k in model.weights)
-    model_layer_map = {"/".join(k.name.split("/")[1:]): i for i, k in enumerate(model.weights)}
+    model_keys = set()
+    model_layer_map = dict()
+    for i, k in enumerate(model.weights):
+        if "model." in k.name or len(k.name.split("/")) == 1:
+            layer_name = k.name
+        else:
+            layer_name = "/".join(k.name.split("/")[1:])
+        model_keys.add(layer_name)
+        model_layer_map[layer_name] = i
 
     for shard_file in shard_files:
         state_dict = tf.io.read_file(shard_file)
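
The loop above normalizes each weight name before using it as a lookup key: names that contain "model." or that have no scope prefix are kept whole, while every other name loses the leading class-name scope that TF prepends. A minimal standalone sketch of that rule (not part of the commit; the weight names below are purely illustrative):

```python
# Sketch of the name-normalization rule introduced in the loop above.
def normalize_weight_name(name: str) -> str:
    # "Special" variables keep their full name: either the name contains "model."
    # (e.g. a sub-module attribute with a dot in it) or there is no "/"-separated scope.
    if "model." in name or len(name.split("/")) == 1:
        return name
    # Otherwise drop the leading class-name scope that TF prepends to every weight.
    return "/".join(name.split("/")[1:])


# Illustrative names only:
print(normalize_weight_name("tf_bert_model/bert/embeddings/word_embeddings/weight:0"))
# -> bert/embeddings/word_embeddings/weight:0
print(normalize_weight_name("unscoped_weight:0"))
# -> unscoped_weight:0 (kept whole: no scope prefix)
print(normalize_weight_name("tf_rag_model/rag/generator/model.decoder/embed_tokens/weight:0"))
# -> kept whole because "model." appears in the name
```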
@@ -2211,17 +2218,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             )
             for shard_file, shard in shards.items():
                 with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_file:
-                    save_attributes_to_hdf5_group(
-                        shard_file,
-                        "layer_names",
-                        ["/".join(layer.name.split("/")[1:]).encode("utf8") for layer in shard],
-                    )
+                    layers = []
                     for layer in sorted(shard, key=lambda x: x.name):
+                        if "model." in layer.name or len(layer.name.split("/")) == 1:
+                            layer_name = layer.name
+                            print(layer_name)
+                        else:
+                            layer_name = "/".join(layer.name.split("/")[1:])
                         param_dset = shard_file.create_dataset(
-                            "/".join(layer.name.split("/")[1:]), layer.numpy().shape, dtype=layer.numpy().dtype
+                            layer_name, layer.numpy().shape, dtype=layer.numpy().dtype
                         )
                         param_dset[:] = layer.numpy()
+                        layers.append(layer_name.encode("utf8"))
+                    save_attributes_to_hdf5_group(shard_file, "layer_names", layers)
 
             if push_to_hub:
                 self._upload_modified_files(
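
On the saving side, the "layer_names" attribute is now built from the same normalized names that are used as dataset keys, so the loader above can resolve them again. A simplified sketch of the resulting shard layout (not part of the commit, using plain h5py instead of Keras' save_attributes_to_hdf5_group helper, with made-up names and a made-up shard filename):

```python
import h5py
import numpy as np

# Illustrative weights; the second name contains "model." and so would be
# kept whole by the special-case branch added above.
weights = {
    "bert/embeddings/word_embeddings/weight:0": np.zeros((8, 4), dtype="float32"),
    "model.decoder/embed_tokens/weight:0": np.ones((8, 4), dtype="float32"),
}

with h5py.File("tf_model-00001-of-00002.h5", mode="w") as shard:
    layer_names = []
    for name, value in sorted(weights.items()):
        # One dataset per weight, keyed by the normalized layer name.
        dset = shard.create_dataset(name, value.shape, dtype=value.dtype)
        dset[:] = value
        layer_names.append(name.encode("utf8"))
    # The loader reads this attribute back to know which weights live in the shard.
    shard.attrs["layer_names"] = layer_names
```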


@@ -77,9 +77,11 @@ if is_tf_available():
         TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
         TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
         BertConfig,
+        RagRetriever,
         TFAutoModel,
         TFAutoModelForSequenceClassification,
         TFBertModel,
+        TFRagModel,
         TFSharedEmbeddings,
     )
     from transformers.generation_tf_utils import (
@@ -2167,6 +2169,18 @@ class UtilsFunctionsTest(unittest.TestCase):
             },
         )
 
+    @slow
+    def test_special_layer_name_sharding(self):
+        retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
+        model = TFRagModel.from_pretrained("facebook/rag-token-nq", retriever=retriever)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
+                model.save_pretrained(tmp_dir, max_shard_size=max_size)
+                ref_model = TFRagModel.from_pretrained(tmp_dir, retriever=retriever)
+                for p1, p2 in zip(model.weights, ref_model.weights):
+                    assert np.allclose(p1.numpy(), p2.numpy())
+
     def test_checkpoint_sharding_local(self):
         model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
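
To see which weight names actually take the new "special" branch, and hence why TFRagModel serves as the regression test, one can inspect the weight names directly. A small sketch (not part of the commit; it downloads the same checkpoint as the slow test above):

```python
from transformers import RagRetriever, TFRagModel

retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
)
model = TFRagModel.from_pretrained("facebook/rag-token-nq", retriever=retriever)

# Count how many weight names would be kept whole by the special-case branch.
special = [
    w.name for w in model.weights if "model." in w.name or len(w.name.split("/")) == 1
]
print(f"{len(special)} of {len(model.weights)} weight names keep their full scope")
print(special[:5])
```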