commit 969d3ae95e (parent 646711e1e2)
Author: thomwolf
Date:   2019-09-11 15:47:33 +02:00

    XLMWithLMHead fixed - standardize conversion

6 changed files with 20 additions and 21 deletions

pytorch_transformers/modeling_tf_bert.py

@@ -57,7 +57,7 @@ def load_bert_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
     inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
     tf_inputs = tf.constant(inputs_list)
     tfo = tf_model(tf_inputs, training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
+    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)


 def gelu(x):

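The hunk above shows the pattern this commit standardizes across models: a TF 2.0 Keras model creates its variables lazily on the first forward pass, so the converter runs a dummy batch through the network before it has any weights to assign, and now also forwards that same `tf_inputs` to the loader so the pass can be repeated after restoring. A minimal standalone sketch of the lazy-build behaviour (toy model, not the library's code):

    import tensorflow as tf

    # Toy stand-in for a TF 2.0 model: variables do not exist until the
    # first forward pass, so weights can only be assigned afterwards.
    model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
    assert model.weights == []            # nothing to load into yet

    dummy = tf.constant([[1.0, 2.0, 3.0]])
    _ = model(dummy, training=False)      # builds kernel and bias
    assert len(model.weights) == 2        # now a loader could assign them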
pytorch_transformers/modeling_tf_gpt2.py

@@ -46,7 +46,7 @@ def load_gpt2_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
     inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
     tf_inputs = tf.constant(inputs_list)
     tfo = tf_model(tf_inputs, training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
+    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)


 def gelu(x):

pytorch_transformers/modeling_tf_pytorch_utils.py

@@ -19,34 +19,34 @@ from __future__ import (absolute_import, division, print_function,
                         unicode_literals)

 import logging

-from pytorch_transformers import is_tf_available, is_torch_available
+import os

 logger = logging.getLogger(__name__)


-def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path):
+def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None):
     """ Load pytorch checkpoints in a TF 2.0 model
         Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
             - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
             - '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
     """
-    if not is_tf_available() or not is_torch_available():
+    try:
+        import tensorflow as tf
+        import torch
+    except ImportError as e:
         logger.error("Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
             "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
-        raise ImportError
-
-    import torch
+        raise e

     pt_path = os.path.abspath(pytorch_checkpoint_path)
     logger.info("Loading PyTorch weights from {}".format(pt_path))
     pt_state_dict = torch.load(pt_path, map_location='cpu')

-    return load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict)
+    return load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs)


-def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict):
+def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None):
     """ Load pytorch state_dict in a TF 2.0 model.
         Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
             - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
@@ -102,7 +102,8 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict):
     K.batch_set_value(weight_value_tuples)

-    tfo = tf_model(tf_inputs, training=False)  # Make sure restore ops are run
+    if tf_inputs is not None:
+        tfo = tf_model(tf_inputs, training=False)  # Make sure restore ops are run

     logger.info("Weights or buffers not loaded from PyTorch model: {}".format(all_pytorch_weights))

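The two docstring conventions above drive every name match in this commit (a sketch of the rules as stated, not the library's exact implementation): '$1___$2' collapses to '$2', so an empty '$2' deletes a whole scope level, and '_._' opens a new attribute level, mirroring an entry of a PyTorch nn.ModuleList. This is also why the XLM layer below can be renamed from 'transformer___' to 'transformer'.

    import re

    def tf_name_to_pt_name(tf_name):
        """Illustrative sketch of the two conventions quoted in the
        docstring above; not the library's exact implementation."""
        name = tf_name.replace(':0', '')                  # drop TF variable suffix
        name = re.sub(r'([^/]*)___([^/]*)', r'\2', name)  # '$1___$2' -> '$2'
        name = name.replace('_._', '/')                   # '_._' -> new level
        name = re.sub(r'/+', '/', name).strip('/')        # drop empty levels
        return name.replace('/', '.')

    # An empty '$2' removes the level (the old XLM naming relied on this):
    print(tf_name_to_pt_name('transformer___/pred_layer_._proj/bias:0'))
    # -> 'pred_layer.proj.bias'
    # '_._' re-creates the index of an nn.ModuleList entry:
    print(tf_name_to_pt_name('transformer/layer_._0/bias:0'))
    # -> 'transformer.layer.0.bias'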
pytorch_transformers/modeling_tf_xlm.py

@@ -50,11 +50,9 @@ def load_xlm_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
     inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
     attns_list = [[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]
     langs_list = [[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]
-    tf_inputs = tf.constant(inputs_list)
-    tf_attns = tf.constant(attns_list)
-    tf_langs = tf.constant(langs_list)
-    tfo = tf_model([tf_inputs, tf_attns, tf_langs], training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
+    tf_inputs = [tf.constant(inputs_list), tf.constant(attns_list), tf.constant(langs_list)]
+    tfo = tf_model(tf_inputs, training=False)
+    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)


 def create_sinusoidal_embeddings(n_pos, dim, out):
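Unlike BERT or GPT-2, XLM's forward pass takes several tensors (token ids, attention mask, language ids), so the dummy build pass must supply all of them; bundling them into one `tf_inputs` list lets a single value travel through `load_pytorch_checkpoint_in_tf2_model` unchanged. A standalone sketch with a hypothetical two-input model in place of the real TFXLMMainLayer:

    import tensorflow as tf

    # Hypothetical two-input model: every input has to be present in the
    # dummy pass, so the tensors travel together as one list.
    class TwoInputModel(tf.keras.Model):
        def __init__(self):
            super(TwoInputModel, self).__init__()
            self.embed = tf.keras.layers.Embedding(10, 4)

        def call(self, inputs, training=False):
            ids, mask = inputs                        # unpack the input list
            return self.embed(ids) * tf.expand_dims(mask, -1)

    model = TwoInputModel()
    tf_inputs = [tf.constant([[7, 6, 0, 0, 1]]),              # token ids
                 tf.constant([[1.0, 1.0, 0.0, 0.0, 1.0]])]    # attention mask
    _ = model(tf_inputs, training=False)    # one call builds every variable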
@@ -614,7 +612,7 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
     """
     def __init__(self, config, *inputs, **kwargs):
         super(TFXLMWithLMHeadModel, self).__init__(config, *inputs, **kwargs)
-        self.transformer = TFXLMMainLayer(config, name='transformer___')
+        self.transformer = TFXLMMainLayer(config, name='transformer')
         self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name='pred_layer_._proj')

pytorch_transformers/modeling_tf_xlnet.py

@@ -45,7 +45,7 @@ def load_xlnet_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
     inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
     tf_inputs = tf.constant(inputs_list)
     tfo = tf_model(tf_inputs, training=False)  # build the network
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
+    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)


 def gelu(x):

pytorch_transformers/modeling_xlm.py

@@ -563,10 +563,10 @@ class XLMPredLayer(nn.Module):
         """
         outputs = ()
         if self.asm is False:
-            scores = self.proj(x).view(-1, self.n_words)
+            scores = self.proj(x)
             outputs = (scores,) + outputs
             if y is not None:
-                loss = F.cross_entropy(scores, y, reduction='elementwise_mean')
+                loss = F.cross_entropy(scores.view(-1, self.n_words), y, reduction='elementwise_mean')
                 outputs = (loss,) + outputs
         else:
             scores = self.proj.log_prob(x)
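The PyTorch-side change keeps `scores` in whatever shape `self.proj` produces and flattens only for the loss, since `F.cross_entropy` wants 2-D logits and 1-D targets; callers therefore see per-position scores instead of a pre-flattened matrix. A shape-only sketch with illustrative sizes (the default 'mean' reduction stands in for the deprecated 'elementwise_mean' spelling used above):

    import torch
    import torch.nn.functional as F

    # Illustrative sizes; `scores` plays the role of self.proj(x).
    batch, seq_len, n_words = 2, 5, 11
    scores = torch.randn(batch, seq_len, n_words)    # returned to callers as-is
    y = torch.randint(n_words, (batch * seq_len,))   # flattened target ids

    # Only the loss sees the flattened view: 2-D logits, 1-D targets.
    loss = F.cross_entropy(scores.view(-1, n_words), y)
    print(scores.shape, loss.shape)   # torch.Size([2, 5, 11]) torch.Size([])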