probably ok weights convertion script

This commit is contained in:
thomwolf 2018-11-01 19:12:31 +01:00
parent ab0e8932a8
commit 960ef4df3b
1 changed files with 7 additions and 2 deletions

View File

@ -9,6 +9,7 @@ import re
import argparse
import tensorflow as tf
import torch
import numpy as np
from modeling_pytorch import BertConfig, BertModel
@ -55,7 +56,11 @@ def convert():
for name, array in zip(names, arrays):
name = name[5:] # skip "bert/"
print("Loading {}".format(name))
name = name.split('/')
if name[0] in ['redictions', 'eq_relationship']:
print("Skipping")
continue
pointer = model
for m_name in name:
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
@ -71,8 +76,8 @@ def convert():
pointer = pointer[num]
if m_name[-11:] == '_embeddings':
pointer = getattr(pointer, 'weight')
# elif m_name == 'kernel':
# pointer = getattr(pointer, 'weight')
elif m_name == 'kernel':
array = np.transpose(array)
try:
assert pointer.shape == array.shape
except AssertionError as e: