clean up pr

lukovnikov 2018-11-13 16:24:53 +01:00
parent f4d79f44c9
commit fa0c5a2ea1
1 changed file with 29 additions and 39 deletions


@@ -26,14 +26,35 @@ import numpy as np
 from modeling import BertConfig, BertModel
 
-def convert(config_path, ckpt_path, out_path=None):
+parser = argparse.ArgumentParser()
+
+## Required parameters
+parser.add_argument("--tf_checkpoint_path",
+                    default = None,
+                    type = str,
+                    required = True,
+                    help = "Path the TensorFlow checkpoint path.")
+parser.add_argument("--bert_config_file",
+                    default = None,
+                    type = str,
+                    required = True,
+                    help = "The config json file corresponding to the pre-trained BERT model. \n"
+                           "This specifies the model architecture.")
+parser.add_argument("--pytorch_dump_path",
+                    default = None,
+                    type = str,
+                    required = True,
+                    help = "Path to the output PyTorch model.")
+
+args = parser.parse_args()
+
+def convert():
     # Initialise PyTorch model
-    config = BertConfig.from_json_file(config_path)
+    config = BertConfig.from_json_file(args.bert_config_file)
     model = BertModel(config)
 
     # Load weights from TF model
-    path = ckpt_path
+    path = args.tf_checkpoint_path
     print("Converting TensorFlow checkpoint from {}".format(path))
     init_vars = tf.train.list_variables(path)
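For context, the lines elided between this hunk and the next read every variable out of the TF checkpoint into numpy arrays. A minimal sketch of that loop, assuming the TF 1.x checkpoint API the script already uses (the exact diagnostics printed there are an assumption):

```python
import tensorflow as tf  # TF 1.x, as the script assumes

init_vars = tf.train.list_variables(path)  # [(variable_name, shape), ...]
names, arrays = [], []
for name, shape in init_vars:
    # load_variable returns the variable's value as a numpy ndarray
    array = tf.train.load_variable(path, name)
    names.append(name)
    arrays.append(array)
```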
@@ -47,17 +68,11 @@ def convert(config_path, ckpt_path, out_path=None):
         arrays.append(array)
 
     for name, array in zip(names, arrays):
-        if not name.startswith("bert"):
-            print("Skipping {}".format(name))
-            continue
-        else:
-            name = name.replace("bert/", "") # skip "bert/"
+        name = name[5:]  # skip "bert/"
         print("Loading {}".format(name))
         name = name.split('/')
-        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
-        # which are not required for using pretrained model
-        if name[0] in ['redictions', 'eq_relationship'] or name[-1] == "adam_v" or name[-1] == "adam_m":
-            print("Skipping {}".format("/".join(name)))
+        if name[0] in ['redictions', 'eq_relationship']:
+            print("Skipping")
             continue
         pointer = model
         for m_name in name:
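The `pointer = model` walk that this hunk leads into resolves each `/`-separated segment of a TF variable name to the matching PyTorch submodule or parameter via `getattr`. A hedged sketch of the idea (simplified; the actual file additionally maps `_embeddings` segments to `weight` and transposes `kernel` arrays before assignment):

```python
import re

pointer = model
for m_name in name:  # name has already been split on '/'
    if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
        # segments like "layer_0" index into a module list
        scope, num = re.split(r"_(\d+)", m_name)[:2]
        pointer = getattr(pointer, scope)[int(num)]
    elif m_name == "kernel":
        # a TF dense "kernel" corresponds to a PyTorch "weight" tensor
        pointer = getattr(pointer, "weight")
    else:
        pointer = getattr(pointer, m_name)
```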
@@ -84,32 +99,7 @@ def convert(config_path, ckpt_path, out_path=None):
         pointer.data = torch.from_numpy(array)
 
     # Save pytorch-model
-    if out_path is not None:
-        torch.save(model.state_dict(), out_path)
-    return model
+    torch.save(model.state_dict(), args.pytorch_dump_path)
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-
-    ## Required parameters
-    parser.add_argument("--tf_checkpoint_path",
-                        default=None,
-                        type=str,
-                        required=True,
-                        help="Path the TensorFlow checkpoint path.")
-    parser.add_argument("--bert_config_file",
-                        default=None,
-                        type=str,
-                        required=True,
-                        help="The config json file corresponding to the pre-trained BERT model. \n"
-                             "This specifies the model architecture.")
-    parser.add_argument("--pytorch_dump_path",
-                        default=None,
-                        type=str,
-                        required=False,
-                        help="Path to the output PyTorch model.")
-    args = parser.parse_args()
-    print(args)
-    convert(args.bert_config_file, args.tf_checkpoint_path, args.pytorch_dump_path)
+    convert()
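Note that `--pytorch_dump_path` goes from `required=False` to `required=True`, so a run now always writes the dump file. A quick way to sanity-check the converted weights, assuming hypothetical file names for the config and dump:

```python
import torch
from modeling import BertConfig, BertModel

# both paths are hypothetical placeholders
config = BertConfig.from_json_file("bert_config.json")
model = BertModel(config)
# load the state dict written by convert() at --pytorch_dump_path
model.load_state_dict(torch.load("pytorch_model.bin"))
```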