convert all models

thomwolf 2019-09-12 13:14:07 +02:00
parent 969d3ae95e
commit a84adddd1b
2 changed files with 1189 additions and 32 deletions


@@ -18,10 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 import os
 import argparse
 import tensorflow as tf
 
-from pytorch_transformers import is_torch_available
+from pytorch_transformers import is_torch_available, cached_path
 
 from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2,
                                   GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2,
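For orientation: cached_path, newly imported in this hunk, resolves a URL or filename to a file in the local download cache, fetching it if needed. A minimal sketch of the call pattern the commit uses further down, with a deliberately hypothetical URL (the script itself takes the real URLs from the *_ARCHIVE_MAP dicts):

    from pytorch_transformers import cached_path

    config_url = "https://example.com/bert-config.json"  # hypothetical URL
    local_path = cached_path(config_url, force_download=True)  # downloads and returns a local file path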
@@ -31,26 +32,36 @@ from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt
 if is_torch_available():
     import torch
     import numpy as np
-    from pytorch_transformers import BertForPreTraining, GPT2LMHeadModel, XLNetLMHeadModel, XLMWithLMHeadModel
+    from pytorch_transformers import (BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                                      GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                                      XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                                      XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,)
 else:
-    BertForPreTraining, GPT2LMHeadModel = None, None
+    (BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+     GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
+     XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
+     XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,) = (
+        None, None, None,
+        None, None, None,
+        None, None, None,
+        None, None, None,)
 
 import logging
 logging.basicConfig(level=logging.INFO)
 
 MODEL_CLASSES = {
-    'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining),
-    'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel),
-    'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel),
-    'xlm': (XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLMWithLMHeadModel),
+    'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'xlm': (XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
 }
 
 def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False):
     if model_type not in MODEL_CLASSES:
         raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))
 
-    config_class, model_class, loading_fct, pt_model_class = MODEL_CLASSES[model_type]
+    config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
 
     # Initialise TF model
     config = config_class.from_json_file(config_file)
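Each MODEL_CLASSES entry is now a 6-tuple that also carries the AWS archive maps, so callers can enumerate every published checkpoint for a model type. A minimal sketch of how the registry is meant to be consumed, mirroring the unpacking in the hunk above:

    config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES['bert']
    for shortcut_name in aws_config_map:             # e.g. 'bert-base-uncased'
        config_url = aws_config_map[shortcut_name]   # URL of the config JSON
        weights_url = aws_model_maps[shortcut_name]  # URL of the PyTorch checkpoint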
@@ -68,8 +79,8 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
     tfo = tf_model(tf_inputs, training=False)  # build the network
 
         pt_model = pt_model_class.from_pretrained(None,
                                                   config=config,
                                                   state_dict=torch.load(pytorch_checkpoint_path,
                                                                         map_location='cpu'))
         pt_inputs = torch.tensor(inputs_list)
         with torch.no_grad():
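Note the loading pattern above: passing None as the model name to from_pretrained, together with an explicit config and state_dict, builds the PyTorch model directly from in-memory weights instead of resolving a shortcut name. A condensed sketch, assuming a hypothetical local checkpoint path:

    # map_location='cpu' keeps the comparison runnable on machines without a GPU.
    state_dict = torch.load("/path/to/pytorch_model.bin", map_location='cpu')  # hypothetical path
    pt_model = pt_model_class.from_pretrained(None, config=config, state_dict=state_dict)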
@@ -79,42 +90,80 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
         np_tf = tfo[0].numpy()
         diff = np.amax(np.abs(np_pt - np_tf))
         print("Max absolute difference between model outputs {}".format(diff))
         assert diff <= 1e-3, "Error, model absolute difference is >1e-3"
 
     # Save the TF model
     print("Save TensorFlow model to {}".format(tf_dump_path))
-    tf_model.save_weights(tf_dump_path)
+    tf_model.save_weights(tf_dump_path, save_format='h5')
+
+def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with_pt_model=False):
+    assert os.path.isdir(tf_dump_path), "--tf_dump_path should be a directory"
+    if args_model_type is None:
+        model_types = list(MODEL_CLASSES.keys())
+    else:
+        model_types = [args_model_type]
+
+    for j, model_type in enumerate(model_types, start=1):
+        print("=" * 100)
+        print(" Converting model type {}/{}: {}".format(j, len(model_types), model_type))
+        print("=" * 100)
+        if model_type not in MODEL_CLASSES:
+            raise ValueError("Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys())))
+        config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
+
+        for i, shortcut_name in enumerate(aws_config_map.keys(), start=1):
+            print("-" * 100)
+            print(" Converting checkpoint {}/{}: {}".format(i, len(aws_config_map), shortcut_name))
+            print("-" * 100)
+            config_file = cached_path(aws_config_map[shortcut_name], force_download=True)
+            model_file = cached_path(aws_model_maps[shortcut_name], force_download=True)
+
+            convert_pt_checkpoint_to_tf(model_type,
+                                        model_file,
+                                        config_file,
+                                        os.path.join(tf_dump_path, shortcut_name + '-tf_model.h5'),
+                                        compare_with_pt_model=compare_with_pt_model)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ## Required parameters
-    parser.add_argument("--model_type",
-                        default = None,
-                        type = str,
-                        required = True,
-                        help = "Model type selected in the list of {}.".format(list(MODEL_CLASSES.keys())))
-    parser.add_argument("--pytorch_checkpoint_path",
-                        default = None,
-                        type = str,
-                        required = True,
-                        help = "Path to the PyTorch checkpoint path.")
-    parser.add_argument("--config_file",
-                        default = None,
-                        type = str,
-                        required = True,
-                        help = "The config json file corresponding to the pre-trained model. \n"
-                               "This specifies the model architecture.")
     parser.add_argument("--tf_dump_path",
                         default = None,
                         type = str,
                         required = True,
                         help = "Path to the output TensorFlow dump file.")
parser.add_argument("--model_type",
default = None,
type = str,
help = "Model type selected in the list of {}. If not given, will download and convert all the models from AWS.".format(list(MODEL_CLASSES.keys())))
parser.add_argument("--pytorch_checkpoint_path",
default = None,
type = str,
help = "Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
"If not given, will download and convert all the checkpoints from AWS.")
parser.add_argument("--config_file",
default = None,
type = str,
help = "The config json file corresponding to the pre-trained model. \n"
"This specifies the model architecture. If not given and "
"--pytorch_checkpoint_path is not given or is a shortcut name"
"use the configuration associated to teh shortcut name on the AWS")
parser.add_argument("--compare_with_pt_model",
action='store_true',
help = "Compare Tensorflow and PyTorch model predictions.")
     args = parser.parse_args()
 
-    convert_pt_checkpoint_to_tf(args.model_type.lower(),
-                                args.pytorch_checkpoint_path,
-                                args.config_file,
-                                args.tf_dump_path,
-                                compare_with_pt_model=args.compare_with_pt_model)
+    if args.pytorch_checkpoint_path is not None:
+        convert_pt_checkpoint_to_tf(args.model_type.lower(),
+                                    args.pytorch_checkpoint_path,
+                                    args.config_file,
+                                    args.tf_dump_path,
+                                    compare_with_pt_model=args.compare_with_pt_model)
+    else:
+        convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None,
+                                         args.tf_dump_path,
+                                         compare_with_pt_model=args.compare_with_pt_model)

File diff suppressed because it is too large