Fix importing unofficial TF models with extra optimizer weights

This commit is contained in:
monologg 2020-01-27 23:39:44 +09:00 committed by Lysandre Debut
parent d7dabfeff5
commit 73368963b2
4 changed files with 19 additions and 4 deletions

View File

@ -117,7 +117,13 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
name = name.split("/")
# Ignore the gradients applied by the LAMB/ADAM optimizers.
if "adam_m" in name or "adam_v" in name or "global_step" in name:
if (
"adam_m" in name
or "adam_v" in name
or "AdamWeightDecayOptimizer" in name
or "AdamWeightDecayOptimizer_1" in name
or "global_step" in name
):
logger.info("Skipping {}".format("/".join(name)))
continue

View File

@ -86,7 +86,10 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
name = name.split("/")
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
# which are not required for using a pretrained model
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
if any(
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
for n in name
):
logger.info("Skipping {}".format("/".join(name)))
continue
pointer = model

View File

@ -79,7 +79,10 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
name = txt_name.split("/")
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
# which are not required for using a pretrained model
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
if any(
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
for n in name
):
logger.info("Skipping {}".format("/".join(name)))
tf_weights.pop(txt_name, None)
continue

View File

@ -76,7 +76,10 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
name = name.split("/")
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
# which are not required for using a pretrained model
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
if any(
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
for n in name
):
logger.info("Skipping {}".format("/".join(name)))
continue
pointer = model