Fix importing unofficial TF models with extra optimizer weights
This commit is contained in:
parent
d7dabfeff5
commit
73368963b2
|
@ -117,7 +117,13 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
|
|||
name = name.split("/")
|
||||
|
||||
# Ignore the gradients applied by the LAMB/ADAM optimizers.
|
||||
if "adam_m" in name or "adam_v" in name or "global_step" in name:
|
||||
if (
|
||||
"adam_m" in name
|
||||
or "adam_v" in name
|
||||
or "AdamWeightDecayOptimizer" in name
|
||||
or "AdamWeightDecayOptimizer_1" in name
|
||||
or "global_step" in name
|
||||
):
|
||||
logger.info("Skipping {}".format("/".join(name)))
|
||||
continue
|
||||
|
||||
|
|
|
@ -86,7 +86,10 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
|||
name = name.split("/")
|
||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
|
||||
# which are not required for using a pretrained model
|
||||
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
|
||||
if any(
|
||||
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
|
||||
for n in name
|
||||
):
|
||||
logger.info("Skipping {}".format("/".join(name)))
|
||||
continue
|
||||
pointer = model
|
||||
|
|
|
@ -79,7 +79,10 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
|
|||
name = txt_name.split("/")
|
||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
|
||||
# which are not required for using a pretrained model
|
||||
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
|
||||
if any(
|
||||
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
|
||||
for n in name
|
||||
):
|
||||
logger.info("Skipping {}".format("/".join(name)))
|
||||
tf_weights.pop(txt_name, None)
|
||||
continue
|
||||
|
|
|
@ -76,7 +76,10 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
|
|||
name = name.split("/")
|
||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
|
||||
# which are not required for using a pretrained model
|
||||
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
|
||||
if any(
|
||||
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
|
||||
for n in name
|
||||
):
|
||||
logger.info("Skipping {}".format("/".join(name)))
|
||||
continue
|
||||
pointer = model
|
||||
|
|
Loading…
Reference in New Issue