special edition script
This commit is contained in:
parent
25f73add07
commit
04287a4d68
|
@ -0,0 +1,99 @@
|
|||
# coding=utf-8
|
||||
"""Convert BERT checkpoint."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import re
|
||||
import argparse
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from modeling_pytorch import BertConfig, BertForSequenceClassification
|
||||
|
||||
# Command-line interface: all three paths are mandatory.
parser = argparse.ArgumentParser()

## Required parameters
parser.add_argument(
    "--tf_checkpoint_path",
    default=None,
    type=str,
    required=True,
    help="Path the TensorFlow checkpoint path.",
)
parser.add_argument(
    "--bert_config_file",
    default=None,
    type=str,
    required=True,
    help="The config json file corresponding to the pre-trained BERT model. \n"
         "This specifies the model architecture.",
)
parser.add_argument(
    "--pytorch_dump_path",
    default=None,
    type=str,
    required=True,
    help="Path to the output PyTorch model.",
)

args = parser.parse_args()
|
||||
|
||||
def convert():
    """Convert a TensorFlow BERT checkpoint into a PyTorch state dict.

    Builds a ``BertForSequenceClassification`` from ``args.bert_config_file``,
    copies every compatible variable found in ``args.tf_checkpoint_path`` into
    it, and saves the resulting state dict to ``args.pytorch_dump_path``.

    Relies on the module-level ``args`` namespace produced by argparse.
    """
    # Initialise PyTorch model
    config = BertConfig.from_json_file(args.bert_config_file)
    model = BertForSequenceClassification(config, num_labels=2)

    # Load weights from TF model
    path = args.tf_checkpoint_path
    print("Converting TensorFlow checkpoint from {}".format(path))

    # First pass: materialise every checkpoint variable as a numpy array.
    init_vars = tf.train.list_variables(path)
    names = []
    arrays = []
    for name, shape in init_vars:
        print("Loading {} with shape {}".format(name, shape))
        array = tf.train.load_variable(path, name)
        print("Numpy array shape {}".format(array.shape))
        names.append(name)
        arrays.append(array)

    # Second pass: walk each slash-separated TF variable name down the
    # PyTorch module tree and copy the data in.
    for name, array in zip(names, arrays):
        # name = name[5:]  # skip "bert/"
        # Fixed typo in the progress message: "or shape" -> "of shape".
        print("Loading {} of shape {}".format(name, array.shape))
        name = name.split('/')
        if name[0] in ['cls']:
            if name[1] in ['predictions']:
                # The masked-LM head has no counterpart in the
                # sequence-classification model, so drop it.
                print("Skipping")
                continue
            elif name[1] in ['seq_relationship']:
                # Map cls/seq_relationship/output_* onto the classifier head:
                # keep only the last component and strip its "output_" prefix.
                name = name[2:]
                assert len(name) == 1
                name[0] = name[0][7:]
                pointer = model.classifier
        else:
            pointer = model
        for m_name in name:
            # Names like "layer_0" index into a ModuleList: split into
            # ["layer", "0", ""] so the digits become the index below.
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                parts = re.split(r'_(\d+)', m_name)
            else:
                parts = [m_name]
            # TF calls Linear weights "kernel"/"weights"; PyTorch calls
            # them "weight".
            if parts[0] in ['kernel', 'weights']:
                pointer = getattr(pointer, 'weight')
            else:
                pointer = getattr(pointer, parts[0])
            if len(parts) >= 2:
                num = int(parts[1])
                pointer = pointer[num]
        # Embedding tables store their data on the .weight attribute.
        if m_name[-11:] == '_embeddings':
            pointer = getattr(pointer, 'weight')
        elif m_name == 'kernel':
            # TF stores Linear kernels transposed relative to PyTorch.
            array = np.transpose(array)
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            # Attach both shapes so a mismatch is debuggable from the trace.
            e.args += (pointer.shape, array.shape)
            raise
        pointer.data = torch.from_numpy(array)

    # Save pytorch-model
    torch.save(model.state_dict(), args.pytorch_dump_path)


if __name__ == "__main__":
    convert()
|
|
@ -482,9 +482,14 @@ class BertForQuestionAnswering(nn.Module):
|
|||
def init_weights(m):
    """Initialize module weights following the TF BERT scheme.

    Linear/Embedding weights are drawn from a normal distribution whose
    std is ``config.initializer_range``; Linear biases are zeroed.
    """
    if isinstance(m, (nn.Linear, nn.Embedding)):
        print("Initializing {}".format(m))
        # Slight difference here with the TF version which uses
        # truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        # BUG FIX: Tensor.normal_'s first positional argument is the
        # *mean* — initializer_range must be passed as the std.
        m.weight.data.normal_(mean=0.0, std=config.initializer_range)
    elif isinstance(m, BERTLayerNorm):
        # NOTE(review): LayerNorm gain/bias are conventionally initialized
        # to 1/0; the random-normal init is kept from the original — confirm.
        m.beta.data.normal_(mean=0.0, std=config.initializer_range)
        # BUG FIX: the attribute is `gamma`, not `gamme` (AttributeError).
        m.gamma.data.normal_(mean=0.0, std=config.initializer_range)
    if isinstance(m, nn.Linear):
        m.bias.data.zero_()
|
||||
self.apply(init_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None):
|
||||
|
|
|
@ -480,9 +480,9 @@ def main():
|
|||
|
||||
model = BertForSequenceClassification(bert_config, len(label_list))
|
||||
if args.init_checkpoint is not None:
|
||||
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
|
||||
model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
|
||||
model.to(device)
|
||||
|
||||
|
||||
if n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
|
@ -575,7 +575,7 @@ def main():
|
|||
|
||||
eval_loss += tmp_eval_loss.item()
|
||||
eval_accuracy += tmp_eval_accuracy
|
||||
|
||||
|
||||
nb_eval_examples += input_ids.size(0)
|
||||
|
||||
eval_loss = eval_loss / nb_eval_examples #len(eval_dataloader)
|
||||
|
|
Loading…
Reference in New Issue