From 4703148f0cdf5308d707e95ce285d01bf4e8ccfd Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Wed, 3 Jul 2019 14:50:23 -0400 Subject: [PATCH] TransformerXL can't be exported to TorchScript because of control-flow. Exception added to tests. --- pytorch_pretrained_bert/tests/model_tests_commons.py | 11 +++++++---- .../tests/modeling_transfo_xl_test.py | 2 +- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/pytorch_pretrained_bert/tests/model_tests_commons.py b/pytorch_pretrained_bert/tests/model_tests_commons.py index 75c0ae19fd..0afda5f2ce 100644 --- a/pytorch_pretrained_bert/tests/model_tests_commons.py +++ b/pytorch_pretrained_bert/tests/model_tests_commons.py @@ -198,14 +198,17 @@ def _create_and_check_for_hidden_states(tester, model_classes, config, inputs_di [tester.seq_length, tester.hidden_size]) -def create_and_check_commons(tester, config, inputs_dict, test_pruning=True): +def create_and_check_commons(tester, config, inputs_dict, test_pruning=True, test_torchscript=True): _create_and_check_initialization(tester, tester.all_model_classes, config, inputs_dict) - _create_and_check_torchscript(tester, tester.all_model_classes, config, inputs_dict) - _create_and_check_torchscript_output_attentions(tester, tester.all_model_classes, config, inputs_dict) - _create_and_check_torchscript_output_hidden_state(tester, tester.all_model_classes, config, inputs_dict) _create_and_check_for_attentions(tester, tester.all_model_classes, config, inputs_dict) _create_and_check_for_headmasking(tester, tester.all_model_classes, config, inputs_dict) _create_and_check_for_hidden_states(tester, tester.all_model_classes, config, inputs_dict) + + if test_torchscript: + _create_and_check_torchscript(tester, tester.all_model_classes, config, inputs_dict) + _create_and_check_torchscript_output_attentions(tester, tester.all_model_classes, config, inputs_dict) + _create_and_check_torchscript_output_hidden_state(tester, tester.all_model_classes, config, inputs_dict) + if test_pruning: _create_and_check_for_head_pruning(tester, tester.all_model_classes, config, inputs_dict) diff --git a/pytorch_pretrained_bert/tests/modeling_transfo_xl_test.py b/pytorch_pretrained_bert/tests/modeling_transfo_xl_test.py index 8b46b6d755..caeb25b412 100644 --- a/pytorch_pretrained_bert/tests/modeling_transfo_xl_test.py +++ b/pytorch_pretrained_bert/tests/modeling_transfo_xl_test.py @@ -173,7 +173,7 @@ class TransfoXLModelTest(unittest.TestCase): def create_and_check_transfo_xl_commons(self, config, input_ids_1, input_ids_2, lm_labels): inputs_dict = {'input_ids': input_ids_1} - create_and_check_commons(self, config, inputs_dict, test_pruning=False) + create_and_check_commons(self, config, inputs_dict, test_pruning=False, test_torchscript=False) def test_default(self): self.run_tester(TransfoXLModelTest.TransfoXLModelTester(self))