add tests on TF2.0 & PT checkpoint => model convertion functions
This commit is contained in:
parent
0993586758
commit
898ce064f8
|
@ -14,6 +14,7 @@
|
|||
# limitations under the License.
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
|
@ -118,7 +119,7 @@ class TFCommonTestCases:
|
|||
tf_model = model_class(config)
|
||||
pt_model = pt_model_class(config)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa (architecture similar)
|
||||
# Check we can load pt model in tf and vice-versa with model => model functions
|
||||
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
|
||||
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
|
||||
|
||||
|
@ -132,6 +133,26 @@ class TFCommonTestCases:
|
|||
max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
|
||||
self.assertLessEqual(max_diff, 2e-2)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
pt_checkpoint_path = os.path.join(tmpdirname, 'pt_model.bin')
|
||||
torch.save(pt_model.state_dict(), pt_checkpoint_path)
|
||||
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
|
||||
|
||||
tf_checkpoint_path = os.path.join(tmpdirname, 'tf_model.h5')
|
||||
tf_model.save_weights(tf_checkpoint_path)
|
||||
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
|
||||
|
||||
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
|
||||
pt_model.eval()
|
||||
pt_inputs_dict = dict((name, torch.from_numpy(key.numpy()).to(torch.long))
|
||||
for name, key in inputs_dict.items())
|
||||
with torch.no_grad():
|
||||
pto = pt_model(**pt_inputs_dict)
|
||||
tfo = tf_model(inputs_dict)
|
||||
max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
|
||||
self.assertLessEqual(max_diff, 2e-2)
|
||||
|
||||
def test_compile_tf_model(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
|
Loading…
Reference in New Issue