From 73368963b200f2d70d2267bd49a3fa794850b3ff Mon Sep 17 00:00:00 2001
From: monologg
Date: Mon, 27 Jan 2020 23:39:44 +0900
Subject: [PATCH] Fix importing unofficial TF models with extra optimizer weights

---
 src/transformers/modeling_albert.py          | 8 +++++++-
 src/transformers/modeling_bert.py            | 5 ++++-
 src/transformers/modeling_t5.py              | 5 ++++-
 templates/adding_a_new_model/modeling_xxx.py | 5 ++++-
 4 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/src/transformers/modeling_albert.py b/src/transformers/modeling_albert.py
index d2a5d4878e..c7bf464f26 100644
--- a/src/transformers/modeling_albert.py
+++ b/src/transformers/modeling_albert.py
@@ -117,7 +117,13 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
         name = name.split("/")
 
         # Ignore the gradients applied by the LAMB/ADAM optimizers.
-        if "adam_m" in name or "adam_v" in name or "global_step" in name:
+        if (
+            "adam_m" in name
+            or "adam_v" in name
+            or "AdamWeightDecayOptimizer" in name
+            or "AdamWeightDecayOptimizer_1" in name
+            or "global_step" in name
+        ):
             logger.info("Skipping {}".format("/".join(name)))
             continue
 
diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py
index 62e752b89c..5c032f05e5 100644
--- a/src/transformers/modeling_bert.py
+++ b/src/transformers/modeling_bert.py
@@ -86,7 +86,10 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
         name = name.split("/")
         # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
         # which are not required for using pretrained model
-        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
+        if any(
+            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+            for n in name
+        ):
             logger.info("Skipping {}".format("/".join(name)))
             continue
         pointer = model
diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py
index 405ebe5667..2c8e7d8273 100644
--- a/src/transformers/modeling_t5.py
+++ b/src/transformers/modeling_t5.py
@@ -79,7 +79,10 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
         name = txt_name.split("/")
         # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
         # which are not required for using pretrained model
-        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
+        if any(
+            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+            for n in name
+        ):
             logger.info("Skipping {}".format("/".join(name)))
             tf_weights.pop(txt_name, None)
             continue
diff --git a/templates/adding_a_new_model/modeling_xxx.py b/templates/adding_a_new_model/modeling_xxx.py
index f9f4daa950..a92f3cbe55 100644
--- a/templates/adding_a_new_model/modeling_xxx.py
+++ b/templates/adding_a_new_model/modeling_xxx.py
@@ -76,7 +76,10 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
         name = name.split("/")
         # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
         # which are not required for using pretrained model
-        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
+        if any(
+            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+            for n in name
+        ):
             logger.info("Skipping {}".format("/".join(name)))
             continue
         pointer = model
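
All four loaders apply the same fix: while walking the flattened TF checkpoint, any variable whose slash-separated path contains an optimizer bookkeeping name (Adam moment estimates, the AdamWeightDecayOptimizer slot variables, or the global step counter) is skipped rather than mapped onto the PyTorch model. Below is a minimal, standalone sketch of that membership check; it is not part of the patch, and the checkpoint variable names are hypothetical examples:

    # Skip list matching the patched loaders above.
    SKIP_NAMES = ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]

    # Hypothetical TF checkpoint variable names for illustration.
    checkpoint_variables = [
        "bert/encoder/layer_0/attention/self/query/kernel",
        "bert/encoder/layer_0/attention/self/query/kernel/adam_m",
        "bert/embeddings/word_embeddings/AdamWeightDecayOptimizer_1",
        "global_step",
    ]

    for full_name in checkpoint_variables:
        parts = full_name.split("/")
        # any(...) tests whole path components against the skip list, so a
        # model weight whose name merely contains one of these strings as a
        # substring would still be loaded.
        if any(part in SKIP_NAMES for part in parts):
            print("Skipping {}".format(full_name))
        else:
            print("Loading {}".format(full_name))

Matching whole path components, rather than substrings of the full variable name, is what keeps legitimate weights whose names happen to contain these strings from being dropped.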