TransformerXL can't be exported to TorchScript because of control-flow. Exception added to tests.

This commit is contained in:
LysandreJik 2019-07-03 14:50:23 -04:00
parent 971c24687f
commit 4703148f0c
2 changed files with 8 additions and 5 deletions

View File

@ -198,14 +198,17 @@ def _create_and_check_for_hidden_states(tester, model_classes, config, inputs_di
[tester.seq_length, tester.hidden_size])
def create_and_check_commons(tester, config, inputs_dict, test_pruning=True):
def create_and_check_commons(tester, config, inputs_dict, test_pruning=True, test_torchscript=True):
_create_and_check_initialization(tester, tester.all_model_classes, config, inputs_dict)
_create_and_check_torchscript(tester, tester.all_model_classes, config, inputs_dict)
_create_and_check_torchscript_output_attentions(tester, tester.all_model_classes, config, inputs_dict)
_create_and_check_torchscript_output_hidden_state(tester, tester.all_model_classes, config, inputs_dict)
_create_and_check_for_attentions(tester, tester.all_model_classes, config, inputs_dict)
_create_and_check_for_headmasking(tester, tester.all_model_classes, config, inputs_dict)
_create_and_check_for_hidden_states(tester, tester.all_model_classes, config, inputs_dict)
if test_torchscript:
_create_and_check_torchscript(tester, tester.all_model_classes, config, inputs_dict)
_create_and_check_torchscript_output_attentions(tester, tester.all_model_classes, config, inputs_dict)
_create_and_check_torchscript_output_hidden_state(tester, tester.all_model_classes, config, inputs_dict)
if test_pruning:
_create_and_check_for_head_pruning(tester, tester.all_model_classes, config, inputs_dict)

View File

@ -173,7 +173,7 @@ class TransfoXLModelTest(unittest.TestCase):
def create_and_check_transfo_xl_commons(self, config, input_ids_1, input_ids_2, lm_labels):
inputs_dict = {'input_ids': input_ids_1}
create_and_check_commons(self, config, inputs_dict, test_pruning=False)
create_and_check_commons(self, config, inputs_dict, test_pruning=False, test_torchscript=False)
def test_default(self):
self.run_tester(TransfoXLModelTest.TransfoXLModelTester(self))