From c5d532e5f61c2807970a7edb9a34fe7ec6dc96b8 Mon Sep 17 00:00:00 2001
From: thomwolf
Date: Thu, 1 Nov 2018 17:40:05 +0100
Subject: [PATCH] added conversion script

---
 convert_tf_checkpoint.py | 81 ++++++++++++++++++++++++++++++++++++++++
 modeling_pytorch.py      | 39 +++++++++++++++----------
 2 files changed, 99 insertions(+), 21 deletions(-)
 create mode 100644 convert_tf_checkpoint.py

diff --git a/convert_tf_checkpoint.py b/convert_tf_checkpoint.py
new file mode 100644
index 0000000000..bd8ddd754f
--- /dev/null
+++ b/convert_tf_checkpoint.py
@@ -0,0 +1,81 @@
+# coding=utf-8
+"""Convert BERT checkpoint."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+import argparse
+import tensorflow as tf
+import torch
+
+from .modeling_pytorch import BertConfig, BertModel
+
+parser = argparse.ArgumentParser()
+
+## Required parameters
+parser.add_argument("--tf_checkpoint_path",
+                    default = None,
+                    type = str,
+                    required = True,
+                    help = "Path to the TensorFlow checkpoint.")
+parser.add_argument("--bert_config_file",
+                    default = None,
+                    type = str,
+                    required = True,
+                    help = "The config json file corresponding to the pre-trained BERT model. \n"
+                           "This specifies the model architecture.")
+parser.add_argument("--pytorch_dump_path",
+                    default = None,
+                    type = str,
+                    required = True,
+                    help = "Path to the output PyTorch model.")
+
+args = parser.parse_args()
+
+def convert():
+    # Load weights from TF model
+    path = args.tf_checkpoint_path
+    print("Converting TensorFlow checkpoint from {}".format(path))
+
+    init_vars = tf.train.list_variables(path)
+    names = []
+    arrays = []
+    for name, shape in init_vars:
+        print("Loading {} with shape {}".format(name, shape))
+        array = tf.train.load_variable(path, name)
+        print("Numpy array shape {}".format(array.shape))
+        names.append(name)
+        arrays.append(array)
+
+    # Initialise the PyTorch model and fill in the weights
+    config = BertConfig.from_json_file(args.bert_config_file)
+    model = BertModel(config)
+    for name, array in zip(names, arrays):
+        name = name[5:]  # skip "bert/"
+        assert name[-2:] == ":0"
+        name = name[:-2]
+        name = name.split('/')
+        pointer = model
+        for m_name in name:
+            if re.fullmatch(r'[A-Za-z]+\d+', m_name):
+                l = re.split(r'(\d+)', m_name)
+            else:
+                l = [m_name]
+            pointer = getattr(pointer, l[0])
+            if len(l) >= 2:
+                num = int(l[1])
+                pointer = pointer[num]
+        try:
+            assert pointer.shape == array.shape
+        except AssertionError as e:
+            e.args += (pointer.shape, array.shape)
+            raise
+        pointer.data = torch.from_numpy(array)
+
+    # Save the PyTorch model
+    torch.save(model.state_dict(), args.pytorch_dump_path)
+
+if __name__ == "__main__":
+    convert()
diff --git a/modeling_pytorch.py b/modeling_pytorch.py
index f6ca69665c..f3ae8ce77f 100644
--- a/modeling_pytorch.py
+++ b/modeling_pytorch.py
@@ -119,7 +119,7 @@ class BERTLayerNorm(nn.Module):
         self.variance_epsilon = variance_epsilon

     def forward(self, x):
-        # TODO check it's identical to TF implementation in details
+        # TODO check it's identical to TF implementation in details (epsilon and axes)
         u = x.mean(-1, keepdim=True)
         s = (x - u).pow(2).mean(-1, keepdim=True)
         x = (x - u) / torch.sqrt(s + self.variance_epsilon)
@@ -128,9 +128,7 @@ class BERTLayerNorm(nn.Module):
         #   inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)

 class BERTEmbeddings(nn.Module):
-    def __init__(self, embedding_size, vocab_size,
-                 token_type_vocab_size, max_position_embeddings,
-                 config):
+    def __init__(self, config):
         super(BERTEmbeddings, self).__init__()

         self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size)
@@ -323,27 +321,32 @@ class BERTEncoder(nn.Module):
         Return:
             float Tensor of shape [batch_size, seq_length, hidden_size]
         """
+        all_encoder_layers = []
         for layer_module in self.layer:
             hidden_states = layer_module(hidden_states, attention_mask)
-        return hidden_states
+            all_encoder_layers.append(hidden_states)
+        return all_encoder_layers


 class BERTPooler(nn.Module):
     def __init__(self, config):
         super(BERTPooler, self).__init__()
-        layer = BERTLayer(n_ctx, cfg, scale=True)
-        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()

-    def forward(self, hidden_states, attention_mask):
+    def forward(self, hidden_states):
         """
         Args:
             hidden_states: float Tensor of shape [batch_size, seq_length, hidden_size]
         Return:
-            float Tensor of shape [batch_size, seq_length, hidden_size]
+            float Tensor of shape [batch_size, hidden_size]
         """
-        for layer_module in self.layer:
-            hidden_states = layer_module(hidden_states, attention_mask)
-        return hidden_states
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token. We assume that this has been pre-trained
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output


 class BertModel(nn.Module):
@@ -381,14 +384,6 @@ class BertModel(nn.Module):
            is invalid.
         """
         super(BertModel).__init__()
-        config = copy.deepcopy(config)
-        if not is_training:
-            config.hidden_dropout_prob = 0.0
-            config.attention_probs_dropout_prob = 0.0
-
-        batch_size = input_ids.size(0)
-        seq_length = input_ids.size(1)
-
         self.embeddings = BERTEmbeddings(config)
         self.encoder = BERTEncoder(config)
         self.pooler = BERTPooler(config)
@@ -396,4 +391,6 @@ class BertModel(nn.Module):
     def forward(self, input_ids, token_type_ids, attention_mask):
         embedding_output = self.embeddings(input_ids, token_type_ids)
         all_encoder_layers = self.encoder(embedding_output, attention_mask)
-        return all_encoder_layers
+        sequence_output = all_encoder_layers[-1]
+        pooled_output = self.pooler(sequence_output)
+        return all_encoder_layers, pooled_output
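
Note (not part of the patch): the weight-loading loop in convert_tf_checkpoint.py walks each '/'-separated segment of a checkpoint variable name down the PyTorch module tree with getattr(), and a trailing digit in a segment becomes an index into an nn.ModuleList. A minimal, self-contained sketch of that mapping follows; TinyLayer, TinyEncoder, and the variable name are hypothetical stand-ins, not code from this repository.

import re

import torch.nn as nn


class TinyLayer(nn.Module):
    def __init__(self):
        super(TinyLayer, self).__init__()
        self.dense = nn.Linear(4, 4)


class TinyEncoder(nn.Module):
    def __init__(self):
        super(TinyEncoder, self).__init__()
        self.layer = nn.ModuleList([TinyLayer() for _ in range(2)])


model = TinyEncoder()
name = "layer1/dense/weight"   # hypothetical checkpoint variable name, prefix/suffix already stripped
pointer = model
for m_name in name.split('/'):
    # "layer1" splits into ['layer', '1', '']: l[0] is the attribute, l[1] the list index
    if re.fullmatch(r'[A-Za-z]+\d+', m_name):
        l = re.split(r'(\d+)', m_name)
    else:
        l = [m_name]
    pointer = getattr(pointer, l[0])
    if len(l) >= 2:
        pointer = pointer[int(l[1])]
print(pointer.shape)   # torch.Size([4, 4]), i.e. model.layer[1].dense.weight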
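
Note (not part of the patch): a hedged end-to-end sketch of how the converted checkpoint might be used. The paths are placeholders, modeling_pytorch is assumed to be importable from the working directory, and the dummy forward pass assumes the parts of modeling_pytorch.py not shown in this diff initialise the model and handle a [batch_size, seq_length] attention mask as expected.

# Step 1 (shell): convert the TF checkpoint into a PyTorch state dict. Depending on the
# layout, the script may need to run as a module so its relative import resolves.
#   python convert_tf_checkpoint.py \
#       --tf_checkpoint_path /path/to/bert_model.ckpt \
#       --bert_config_file   /path/to/bert_config.json \
#       --pytorch_dump_path  /path/to/pytorch_model.bin

# Step 2 (Python): load the dump into the model defined in modeling_pytorch.py.
import torch

from modeling_pytorch import BertConfig, BertModel

config = BertConfig.from_json_file("/path/to/bert_config.json")
model = BertModel(config)
model.load_state_dict(torch.load("/path/to/pytorch_model.bin"))
model.eval()

# With this patch, BertModel.forward returns the hidden states of every encoder layer
# plus the pooled representation of the first token.
input_ids = torch.zeros(1, 8, dtype=torch.long)        # dummy batch of 8 token ids
token_type_ids = torch.zeros(1, 8, dtype=torch.long)   # single segment
attention_mask = torch.ones(1, 8, dtype=torch.long)    # no padding
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, attention_mask)
print(len(all_encoder_layers), pooled_output.shape)    # num_hidden_layers, [1, hidden_size]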