added conversion script
This commit is contained in:
parent
90d360a7a9
commit
c5d532e5f6
|
@ -0,0 +1,82 @@
|
|||
# coding=utf-8
|
||||
"""Convert BERT checkpoint."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import re
|
||||
import argparse
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
from .modeling_pytorch import BertConfig, BertModel
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
parser.add_argument("--tf_checkpoint_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path the TensorFlow checkpoint path.")
|
||||
parser.add_argument("--bert_config_file",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "The config json file corresponding to the pre-trained BERT model. \n"
|
||||
"This specifies the model architecture.")
|
||||
parser.add_argument("--pytorch_dump_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the output PyTorch model.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
def convert():
|
||||
# Load weights from TF model
|
||||
path = args.tf_checkpoint_path
|
||||
print("Converting TensorFlow checkpoint from {}".format(path))
|
||||
|
||||
init_vars = tf.train.list_variables(path)
|
||||
names = []
|
||||
arrays = []
|
||||
for name, shape in init_vars:
|
||||
print("Loading {} with shape {}".format(name, shape))
|
||||
array = tf.train.load_variable(path, name)
|
||||
print("Numpy array shape {}".format(array.shape))
|
||||
names.append(name)
|
||||
arrays.append(array)
|
||||
|
||||
# Initialise PyTorch model and fill weights-in
|
||||
config = BertConfig.from_json_file(args.bert_config_file)
|
||||
model = BertModel(config)
|
||||
for name, array in zip(names, arrays):
|
||||
name = name[5:] # skip "bert/"
|
||||
assert name[-2:] == ":0"
|
||||
name = name[:-2]
|
||||
name = name.split('/')
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
if re.fullmatch(r'[A-Za-z]+\d+', m_name):
|
||||
l = re.split(r'(\d+)', m_name)
|
||||
else:
|
||||
l = [m_name]
|
||||
pointer = getattr(pointer, l[0])
|
||||
if len(l) >= 2:
|
||||
num = int(l[1])
|
||||
pointer = pointer[num]
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
pointer.data = torch.from_numpy(array)
|
||||
|
||||
# Save pytorch-model
|
||||
torch.save(model.state_dict(), args.pytorch_dump_path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
convert()
|
||||
return None
|
|
@ -119,7 +119,7 @@ class BERTLayerNorm(nn.Module):
|
|||
self.variance_epsilon = variance_epsilon
|
||||
|
||||
def forward(self, x):
|
||||
# TODO check it's identical to TF implementation in details
|
||||
# TODO check it's identical to TF implementation in details (epsilon and axes)
|
||||
u = x.mean(-1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(-1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
|
||||
|
@ -128,9 +128,7 @@ class BERTLayerNorm(nn.Module):
|
|||
# inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
|
||||
|
||||
class BERTEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_size, vocab_size,
|
||||
token_type_vocab_size, max_position_embeddings,
|
||||
config):
|
||||
def __init__(self, config):
|
||||
|
||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size)
|
||||
|
||||
|
@ -323,27 +321,32 @@ class BERTEncoder(nn.Module):
|
|||
Return:
|
||||
float Tensor of shape [batch_size, seq_length, hidden_size]
|
||||
"""
|
||||
all_encoder_layers = []
|
||||
for layer_module in self.layer:
|
||||
hidden_states = layer_module(hidden_states, attention_mask)
|
||||
return hidden_states
|
||||
all_encoder_layers.append(hidden_states)
|
||||
return all_encoder_layers
|
||||
|
||||
|
||||
class BERTPooler(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BERTPooler, self).__init__()
|
||||
layer = BERTLayer(n_ctx, cfg, scale=True)
|
||||
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.activation = nn.Tanh()
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
def forward(self, hidden_states):
|
||||
"""
|
||||
Args:
|
||||
hidden_states: float Tensor of shape [batch_size, seq_length, hidden_size]
|
||||
Return:
|
||||
float Tensor of shape [batch_size, seq_length, hidden_size]
|
||||
float Tensor of shape [batch_size, hidden_size]
|
||||
"""
|
||||
for layer_module in self.layer:
|
||||
hidden_states = layer_module(hidden_states, attention_mask)
|
||||
return hidden_states
|
||||
# We "pool" the model by simply taking the hidden state corresponding
|
||||
# to the first token. We assume that this has been pre-trained
|
||||
first_token_tensor = hidden_states[:, 0]
|
||||
pooled_output = self.dense(first_token_tensor)
|
||||
pooled_output = self.activation(pooled_output)
|
||||
return pooled_output
|
||||
|
||||
|
||||
class BertModel(nn.Module):
|
||||
|
@ -381,14 +384,6 @@ class BertModel(nn.Module):
|
|||
is invalid.
|
||||
"""
|
||||
super(BertModel).__init__()
|
||||
config = copy.deepcopy(config)
|
||||
if not is_training:
|
||||
config.hidden_dropout_prob = 0.0
|
||||
config.attention_probs_dropout_prob = 0.0
|
||||
|
||||
batch_size = input_ids.size(0)
|
||||
seq_length = input_ids.size(1)
|
||||
|
||||
self.embeddings = BERTEmbeddings(config)
|
||||
self.encoder = BERTEncoder(config)
|
||||
self.pooler = BERTPooler(config)
|
||||
|
@ -396,4 +391,6 @@ class BertModel(nn.Module):
|
|||
def forward(self, input_ids, token_type_ids, attention_mask):
|
||||
embedding_output = self.embeddings(input_ids, token_type_ids)
|
||||
all_encoder_layers = self.encoder(embedding_output, attention_mask)
|
||||
return all_encoder_layers
|
||||
sequence_output = all_encoder_layers[-1]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
return all_encoder_layers, pooled_output
|
||||
|
|
Loading…
Reference in New Issue