XLMWithLMHead fixed - standardize conversion
parent 646711e1e2
commit 969d3ae95e
@@ -57,7 +57,7 @@ def load_bert_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
     inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
     tf_inputs = tf.constant(inputs_list)
     tfo = tf_model(tf_inputs, training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
+    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
 
 
 def gelu(x):
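This BERT hunk, and the GPT-2 and XLNet hunks below, are the same one-line change: the dummy inputs used to build the network are now forwarded to the shared converter instead of being dropped. Schematically, every per-model loader in this commit reduces to the following three steps (a sketch of the pattern, not a library function):

    import tensorflow as tf

    def load_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
        # 1. Dummy batch: three sequences of five token ids, zero-padded.
        inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
        tf_inputs = tf.constant(inputs_list)
        # 2. One forward pass so Keras creates every variable of the model.
        tfo = tf_model(tf_inputs, training=False)
        # 3. Delegate to the shared converter, forwarding the dummy inputs
        #    so it can re-run the network after the weights are assigned.
        return load_pytorch_checkpoint_in_tf2_model(
            tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)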
@@ -46,7 +46,7 @@ def load_gpt2_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
     inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
     tf_inputs = tf.constant(inputs_list)
     tfo = tf_model(tf_inputs, training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
+    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
 
 
 def gelu(x):
@@ -19,34 +19,34 @@ from __future__ import (absolute_import, division, print_function,
                         unicode_literals)
 
 import logging
+import os
 
-from pytorch_transformers import is_tf_available, is_torch_available
-
 logger = logging.getLogger(__name__)
 
 
-def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path):
+def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None):
     """ Load pytorch checkpoints in a TF 2.0 model
         Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
             - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
             - '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
     """
-    if not is_tf_available() or not is_torch_available():
+    try:
+        import tensorflow as tf
+        import torch
+    except ImportError as e:
         logger.error("Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
             "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
-        raise ImportError
-
-    import torch
+        raise e
 
     pt_path = os.path.abspath(pytorch_checkpoint_path)
     logger.info("Loading PyTorch weights from {}".format(pt_path))
 
     pt_state_dict = torch.load(pt_path, map_location='cpu')
 
-    return load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict)
+    return load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs)
 
 
-def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict):
+def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None):
     """ Load pytorch state_dict in a TF 2.0 model.
        Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
            - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
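To make the two scope-name conventions concrete, here is a rough sketch of how a TF 2.0 variable name could map to a PyTorch state_dict key under these rules (a simplified illustration, not the actual implementation in this file):

    import re

    def tf_scope_to_pt_key(tf_name):
        # Simplified sketch of the docstring conventions above.
        name = tf_name.split(':')[0]           # drop the ':0' tensor suffix
        levels = []
        for level in name.split('/'):          # TF scopes -> attribute levels
            levels.extend(level.split('_._'))  # '_._' adds a level (nn.ModuleList)
        # '$1___$2' -> '$2'; an empty '$2' removes the level entirely
        levels = [re.sub(r'^.*___', '', lvl) for lvl in levels]
        return '.'.join(lvl for lvl in levels if lvl)

    print(tf_scope_to_pt_key('transformer/layer_._0/q_lin/kernel:0'))
    # -> 'transformer.layer.0.q_lin.kernel'
    print(tf_scope_to_pt_key('transformer___/q_lin/kernel:0'))
    # -> 'q_lin.kernel' (a trailing '___' makes the scope vanish)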
@@ -102,7 +102,8 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict):
 
     K.batch_set_value(weight_value_tuples)
 
-    tfo = tf_model(tf_inputs, training=False)  # Make sure restore ops are run
+    if tf_inputs is not None:
+        tfo = tf_model(tf_inputs, training=False)  # Make sure restore ops are run
 
     logger.info("Weights or buffers not loaded from PyTorch model: {}".format(all_pytorch_weights))
 
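The new guard matters because the dummy forward pass does double duty: in TF 2.0 a Keras layer only creates its variables on its first call, and the call after K.batch_set_value makes sure any restore ops have actually run. A standalone toy example of the build-on-first-call behaviour (plain Keras, independent of this library):

    import tensorflow as tf

    dense = tf.keras.layers.Dense(4)
    print(len(dense.weights))                  # 0: no variables exist yet
    _ = dense(tf.constant([[1.0, 2.0, 3.0]]))  # first call builds the layer
    print(len(dense.weights))                  # 2: kernel and bias now settable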
@@ -50,11 +50,9 @@ def load_xlm_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
     inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
     attns_list = [[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]
     langs_list = [[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]
-    tf_inputs = tf.constant(inputs_list)
-    tf_attns = tf.constant(attns_list)
-    tf_langs = tf.constant(langs_list)
-    tfo = tf_model([tf_inputs, tf_attns, tf_langs], training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
+    tf_inputs = [tf.constant(inputs_list), tf.constant(attns_list), tf.constant(langs_list)]
+    tfo = tf_model(tf_inputs, training=False)
+    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
 
 
 def create_sinusoidal_embeddings(n_pos, dim, out):
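XLM is the only model here whose dummy forward pass needs more than input ids (an attention mask and language ids as well), so its dummy inputs become a single list of tensors; Keras accepts such a list as the one inputs argument, which lets the same tf_inputs= keyword carry all three tensors through the shared converter. A toy stand-in for the list-of-tensors calling convention (hypothetical model, not TFXLMMainLayer):

    import tensorflow as tf

    class ToyMultiInput(tf.keras.Model):
        # Hypothetical stand-in for a multi-input model such as XLM.
        def __init__(self):
            super(ToyMultiInput, self).__init__()
            self.embed = tf.keras.layers.Embedding(10, 8)

        def call(self, inputs, training=False):
            input_ids, attention_mask, langs = inputs    # unpack the list
            hidden = self.embed(input_ids)               # langs unused in the toy
            mask = tf.cast(attention_mask[:, :, tf.newaxis], hidden.dtype)
            return tf.reduce_sum(hidden * mask, axis=1)  # masked sum-pooling

    tf_inputs = [tf.constant([[7, 6, 0, 0, 1]]),   # input ids
                 tf.constant([[1, 1, 0, 0, 1]]),   # attention mask
                 tf.constant([[1, 1, 0, 0, 1]])]   # language ids
    tfo = ToyMultiInput()(tf_inputs, training=False)  # one call builds all weights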
@@ -614,7 +612,7 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
     """
     def __init__(self, config, *inputs, **kwargs):
         super(TFXLMWithLMHeadModel, self).__init__(config, *inputs, **kwargs)
-        self.transformer = TFXLMMainLayer(config, name='transformer___')
+        self.transformer = TFXLMMainLayer(config, name='transformer')
         self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name='pred_layer_._proj')
 
 
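The converter derives PyTorch keys from TF scope names, so a sub-layer's name argument is load-bearing: under the '___' convention above, 'transformer___' would be stripped out of the derived key, whereas the standardized mapping wants a literal 'transformer' level that lines up with the PyTorch attribute self.transformer. A quick way to see how name propagates into variable names (toy layer; exact names can vary across TF versions):

    import tensorflow as tf

    class Wrapper(tf.keras.Model):
        def __init__(self):
            super(Wrapper, self).__init__()
            self.transformer = tf.keras.layers.Dense(2, name='transformer')

        def call(self, inputs, training=False):
            return self.transformer(inputs)

    model = Wrapper()
    model(tf.zeros((1, 3)))
    print([w.name for w in model.weights])
    # typically ['transformer/kernel:0', 'transformer/bias:0']: the
    # 'transformer' scope level is what gets matched to the PyTorch side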
@@ -45,7 +45,7 @@ def load_xlnet_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
     inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
     tf_inputs = tf.constant(inputs_list)
     tfo = tf_model(tf_inputs, training=False)  # build the network
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
+    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
 
 
 def gelu(x):
@@ -563,10 +563,10 @@ class XLMPredLayer(nn.Module):
         """
         outputs = ()
         if self.asm is False:
-            scores = self.proj(x).view(-1, self.n_words)
+            scores = self.proj(x)
             outputs = (scores,) + outputs
             if y is not None:
-                loss = F.cross_entropy(scores, y, reduction='elementwise_mean')
+                loss = F.cross_entropy(scores.view(-1, self.n_words), y, reduction='elementwise_mean')
                 outputs = (loss,) + outputs
         else:
             scores = self.proj.log_prob(x)
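The PyTorch-side fix is a shape contract: the returned scores now keep whatever shape self.proj produces, and the flatten to (-1, n_words) happens only inside the loss call, since F.cross_entropy expects (N, C) logits. A toy illustration with invented shapes (default reduction here, instead of the since-deprecated 'elementwise_mean'):

    import torch
    import torch.nn.functional as F

    batch, seq_len, n_words = 3, 5, 11
    scores = torch.randn(batch, seq_len, n_words)   # as self.proj(x) might return
    y = torch.randint(n_words, (batch * seq_len,))  # one flat target per token

    loss = F.cross_entropy(scores.view(-1, n_words), y)
    print(scores.shape)   # torch.Size([3, 5, 11]): callers see the full shape
    print(loss.shape)     # torch.Size([]): scalar loss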