Sharding fails in TF when absolute scope was modified if `.` in layer name (#19124)
* simplify loop * fix layer map split * update * update for special variables * add rag test * fixup * revert change : for next PR
This commit is contained in:
parent
614f7d28a8
commit
2bd2de62c9
|
@ -707,8 +707,15 @@ def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, s
|
|||
|
||||
# Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load
|
||||
# the weight, we have to get rid of the first prefix of the name of the layer.
|
||||
model_keys = set("/".join(k.name.split("/")[1:]) for k in model.weights)
|
||||
model_layer_map = {"/".join(k.name.split("/")[1:]): i for i, k in enumerate(model.weights)}
|
||||
model_keys = set()
|
||||
model_layer_map = dict()
|
||||
for i, k in enumerate(model.weights):
|
||||
if "model." in k.name or len(k.name.split("/")) == 1:
|
||||
layer_name = k.name
|
||||
else:
|
||||
layer_name = "/".join(k.name.split("/")[1:])
|
||||
model_keys.add(layer_name)
|
||||
model_layer_map[layer_name] = i
|
||||
|
||||
for shard_file in shard_files:
|
||||
state_dict = tf.io.read_file(shard_file)
|
||||
|
@ -2211,17 +2218,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||
)
|
||||
for shard_file, shard in shards.items():
|
||||
with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_file:
|
||||
save_attributes_to_hdf5_group(
|
||||
shard_file,
|
||||
"layer_names",
|
||||
["/".join(layer.name.split("/")[1:]).encode("utf8") for layer in shard],
|
||||
)
|
||||
|
||||
layers = []
|
||||
for layer in sorted(shard, key=lambda x: x.name):
|
||||
if "model." in layer.name or len(layer.name.split("/")) == 1:
|
||||
layer_name = layer.name
|
||||
print(layer_name)
|
||||
else:
|
||||
layer_name = "/".join(layer.name.split("/")[1:])
|
||||
param_dset = shard_file.create_dataset(
|
||||
"/".join(layer.name.split("/")[1:]), layer.numpy().shape, dtype=layer.numpy().dtype
|
||||
layer_name, layer.numpy().shape, dtype=layer.numpy().dtype
|
||||
)
|
||||
param_dset[:] = layer.numpy()
|
||||
layers.append(layer_name.encode("utf8"))
|
||||
save_attributes_to_hdf5_group(shard_file, "layer_names", layers)
|
||||
|
||||
if push_to_hub:
|
||||
self._upload_modified_files(
|
||||
|
|
|
@ -77,9 +77,11 @@ if is_tf_available():
|
|||
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
BertConfig,
|
||||
RagRetriever,
|
||||
TFAutoModel,
|
||||
TFAutoModelForSequenceClassification,
|
||||
TFBertModel,
|
||||
TFRagModel,
|
||||
TFSharedEmbeddings,
|
||||
)
|
||||
from transformers.generation_tf_utils import (
|
||||
|
@ -2167,6 +2169,18 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||
},
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_special_layer_name_shardind(self):
|
||||
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
|
||||
model = TFRagModel.from_pretrained("facebook/rag-token-nq", retriever=retriever)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
|
||||
model.save_pretrained(tmp_dir, max_shard_size=max_size)
|
||||
ref_model = TFRagModel.from_pretrained(tmp_dir, retriever=retriever)
|
||||
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||
assert np.allclose(p1.numpy(), p2.numpy())
|
||||
|
||||
def test_checkpoint_sharding_local(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
|
|
Loading…
Reference in New Issue