probably ok weights convertion script
This commit is contained in:
parent
ab0e8932a8
commit
960ef4df3b
|
@ -9,6 +9,7 @@ import re
|
|||
import argparse
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from modeling_pytorch import BertConfig, BertModel
|
||||
|
||||
|
@ -55,7 +56,11 @@ def convert():
|
|||
|
||||
for name, array in zip(names, arrays):
|
||||
name = name[5:] # skip "bert/"
|
||||
print("Loading {}".format(name))
|
||||
name = name.split('/')
|
||||
if name[0] in ['redictions', 'eq_relationship']:
|
||||
print("Skipping")
|
||||
continue
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
|
||||
|
@ -71,8 +76,8 @@ def convert():
|
|||
pointer = pointer[num]
|
||||
if m_name[-11:] == '_embeddings':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
# elif m_name == 'kernel':
|
||||
# pointer = getattr(pointer, 'weight')
|
||||
elif m_name == 'kernel':
|
||||
array = np.transpose(array)
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
except AssertionError as e:
|
||||
|
|
Loading…
Reference in New Issue