clean up pr
This commit is contained in:
parent
f4d79f44c9
commit
fa0c5a2ea1
|
@ -26,14 +26,35 @@ import numpy as np
|
|||
|
||||
from modeling import BertConfig, BertModel
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
def convert(config_path, ckpt_path, out_path=None):
|
||||
## Required parameters
|
||||
parser.add_argument("--tf_checkpoint_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path the TensorFlow checkpoint path.")
|
||||
parser.add_argument("--bert_config_file",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "The config json file corresponding to the pre-trained BERT model. \n"
|
||||
"This specifies the model architecture.")
|
||||
parser.add_argument("--pytorch_dump_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the output PyTorch model.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
def convert():
|
||||
# Initialise PyTorch model
|
||||
config = BertConfig.from_json_file(config_path)
|
||||
config = BertConfig.from_json_file(args.bert_config_file)
|
||||
model = BertModel(config)
|
||||
|
||||
# Load weights from TF model
|
||||
path = ckpt_path
|
||||
path = args.tf_checkpoint_path
|
||||
print("Converting TensorFlow checkpoint from {}".format(path))
|
||||
|
||||
init_vars = tf.train.list_variables(path)
|
||||
|
@ -47,17 +68,11 @@ def convert(config_path, ckpt_path, out_path=None):
|
|||
arrays.append(array)
|
||||
|
||||
for name, array in zip(names, arrays):
|
||||
if not name.startswith("bert"):
|
||||
print("Skipping {}".format(name))
|
||||
continue
|
||||
else:
|
||||
name = name.replace("bert/", "") # skip "bert/"
|
||||
name = name[5:] # skip "bert/"
|
||||
print("Loading {}".format(name))
|
||||
name = name.split('/')
|
||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||
# which are not required for using pretrained model
|
||||
if name[0] in ['redictions', 'eq_relationship'] or name[-1] == "adam_v" or name[-1] == "adam_m":
|
||||
print("Skipping {}".format("/".join(name)))
|
||||
if name[0] in ['redictions', 'eq_relationship']:
|
||||
print("Skipping")
|
||||
continue
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
|
@ -84,32 +99,7 @@ def convert(config_path, ckpt_path, out_path=None):
|
|||
pointer.data = torch.from_numpy(array)
|
||||
|
||||
# Save pytorch-model
|
||||
if out_path is not None:
|
||||
torch.save(model.state_dict(), out_path)
|
||||
return model
|
||||
|
||||
torch.save(model.state_dict(), args.pytorch_dump_path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
parser.add_argument("--tf_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path the TensorFlow checkpoint path.")
|
||||
parser.add_argument("--bert_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the pre-trained BERT model. \n"
|
||||
"This specifies the model architecture.")
|
||||
parser.add_argument("--pytorch_dump_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=False,
|
||||
help="Path to the output PyTorch model.")
|
||||
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
convert(args.bert_config_file, args.tf_checkpoint_path, args.pytorch_dump_path)
|
||||
convert()
|
||||
|
|
Loading…
Reference in New Issue