add tests on TF2.0 & PT checkpoint => model convertion functions

This commit is contained in:
thomwolf 2019-10-15 10:04:19 +02:00
parent 0993586758
commit 898ce064f8
1 changed files with 22 additions and 1 deletions

View File

@ -14,6 +14,7 @@
# limitations under the License.
from __future__ import absolute_import, division, print_function
import os
import copy
import json
import logging
@ -118,7 +119,7 @@ class TFCommonTestCases:
tf_model = model_class(config)
pt_model = pt_model_class(config)
# Check we can load pt model in tf and vice-versa (architecture similar)
# Check we can load pt model in tf and vice-versa with model => model functions
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
@ -132,6 +133,26 @@ class TFCommonTestCases:
max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
self.assertLessEqual(max_diff, 2e-2)
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
with TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, 'pt_model.bin')
torch.save(pt_model.state_dict(), pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
tf_checkpoint_path = os.path.join(tmpdirname, 'tf_model.h5')
tf_model.save_weights(tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
pt_model.eval()
pt_inputs_dict = dict((name, torch.from_numpy(key.numpy()).to(torch.long))
for name, key in inputs_dict.items())
with torch.no_grad():
pto = pt_model(**pt_inputs_dict)
tfo = tf_model(inputs_dict)
max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
self.assertLessEqual(max_diff, 2e-2)
def test_compile_tf_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()