Add MultiBERTs conversion script (#13077)

* Init multibert checkpoint conversion script

* Rename conversion script

* Fix MultiBerts Conversion Script

* Apply suggestions from code review

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Gunjan Chhablani 2021-09-30 22:18:56 +05:30 committed by GitHub
parent e1d1c7c087
commit 9a9805fccf
1 changed file with 128 additions and 0 deletions


@@ -0,0 +1,128 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script can be used to convert a head-less TF 2.x MultiBERTs model to PyTorch, as published on the official GitHub:
https://github.com/tensorflow/models/tree/master/official/nlp/bert
"""
import argparse
import os

import tensorflow as tf
import torch

from transformers import BertConfig, BertForPreTraining
from transformers.utils import logging

logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def convert_multibert_checkpoint_to_pytorch(tf_checkpoint_path, config_path, save_path):
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")

    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    config = BertConfig.from_pretrained(config_path)
    model = BertForPreTraining(config)
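    # Collect the encoder layer indices that appear in variable names so the
    # depth can be sanity-checked against the config below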
    layer_nums = []
    for full_name, shape in init_vars:
        array = tf.train.load_variable(tf_path, full_name)
        names.append(full_name)
        split_names = full_name.split("/")
        for name in split_names:
            if name.startswith("layer_"):
                layer_nums.append(int(name.split("_")[-1]))
        arrays.append(array)

    logger.info(f"Read a total of {len(arrays):,} variables")
    name_to_array = dict(zip(names, arrays))
    # Check that the number of layers matches the config
    assert config.num_hidden_layers == len(set(layer_nums))

    state_dict = model.state_dict()

    # Need to copy this explicitly, as it is a buffer rather than a parameter
    position_ids = state_dict["bert.embeddings.position_ids"]
    new_state_dict = {"bert.embeddings.position_ids": position_ids}
    # Encoder layers
    for weight_name in names:
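        # TF2 uses kernel/gamma/beta; the corresponding PyTorch parameters are weight/weight/bias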
        pt_weight_name = weight_name.replace("kernel", "weight").replace("gamma", "weight").replace("beta", "bias")
        name_split = pt_weight_name.split("/")
        for name_idx, name in enumerate(name_split):
            if name.startswith("layer_"):
                name_split[name_idx] = name.replace("_", ".")
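        # TF embedding tables are bare variables; the PyTorch nn.Embedding parameter is named ".weight"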
        if name_split[-1].endswith("embeddings"):
            name_split.append("weight")

        if name_split[0] == "cls":
            if name_split[-1] == "output_bias":
                name_split[-1] = "bias"
            if name_split[-1] == "output_weights":
                name_split[-1] = "weight"
if name_split[-1] == "weight" and name_split[-2] == "dense":
name_to_array[weight_name] = name_to_array[weight_name].T
pt_weight_name = ".".join(name_split)
new_state_dict[pt_weight_name] = torch.from_numpy(name_to_array[weight_name])
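    # The MLM decoder is tied to the input word embeddings, so copy both tensors over explicitly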
new_state_dict["cls.predictions.decoder.weight"] = new_state_dict["bert.embeddings.word_embeddings.weight"].clone()
new_state_dict["cls.predictions.decoder.bias"] = new_state_dict["cls.predictions.bias"].clone().T
    # Load the converted state dict
    model.load_state_dict(new_state_dict)

    # Save the pretrained model
    logger.info(f"Saving pretrained model to {save_path}")
    model.save_pretrained(save_path)

    return model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--tf_checkpoint_path",
type=str,
default="./seed_0/bert.ckpt",
required=False,
help="Path to the TensorFlow 2.x checkpoint path.",
)
parser.add_argument(
"--bert_config_file",
type=str,
default="./bert_config.json",
required=False,
help="The config json file corresponding to the BERT model. This specifies the model architecture.",
)
parser.add_argument(
"--save_path",
type=str,
required=True,
help="Path to the output PyTorch model (must include filename).",
)
args = parser.parse_args()
convert_multibert_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.save_path)
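
A quick usage sketch (not part of the diff above): once the script has written a converted checkpoint, it loads like any other BERT model. The output directory below is hypothetical, and using the stock uncased tokenizer is an assumption based on MultiBERTs being BERT-base uncased reproductions.

from transformers import BertForPreTraining, BertTokenizer

# Hypothetical directory written by an earlier run of the conversion script
model = BertForPreTraining.from_pretrained("./multiberts_seed_0")

# Assumed tokenizer: MultiBERTs checkpoints are BERT-base uncased reproductions
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

inputs = tokenizer("Hello, MultiBERTs!", return_tensors="pt")
outputs = model(**inputs)
print(outputs.prediction_logits.shape)  # (batch_size, sequence_length, vocab_size)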