TransformerXL can't be exported to TorchScript because of control-flow. Exception added to tests.
This commit is contained in:
parent
971c24687f
commit
4703148f0c
|
@ -198,14 +198,17 @@ def _create_and_check_for_hidden_states(tester, model_classes, config, inputs_di
|
|||
[tester.seq_length, tester.hidden_size])
|
||||
|
||||
|
||||
def create_and_check_commons(tester, config, inputs_dict, test_pruning=True):
|
||||
def create_and_check_commons(tester, config, inputs_dict, test_pruning=True, test_torchscript=True):
|
||||
_create_and_check_initialization(tester, tester.all_model_classes, config, inputs_dict)
|
||||
_create_and_check_torchscript(tester, tester.all_model_classes, config, inputs_dict)
|
||||
_create_and_check_torchscript_output_attentions(tester, tester.all_model_classes, config, inputs_dict)
|
||||
_create_and_check_torchscript_output_hidden_state(tester, tester.all_model_classes, config, inputs_dict)
|
||||
_create_and_check_for_attentions(tester, tester.all_model_classes, config, inputs_dict)
|
||||
_create_and_check_for_headmasking(tester, tester.all_model_classes, config, inputs_dict)
|
||||
_create_and_check_for_hidden_states(tester, tester.all_model_classes, config, inputs_dict)
|
||||
|
||||
if test_torchscript:
|
||||
_create_and_check_torchscript(tester, tester.all_model_classes, config, inputs_dict)
|
||||
_create_and_check_torchscript_output_attentions(tester, tester.all_model_classes, config, inputs_dict)
|
||||
_create_and_check_torchscript_output_hidden_state(tester, tester.all_model_classes, config, inputs_dict)
|
||||
|
||||
if test_pruning:
|
||||
_create_and_check_for_head_pruning(tester, tester.all_model_classes, config, inputs_dict)
|
||||
|
||||
|
|
|
@ -173,7 +173,7 @@ class TransfoXLModelTest(unittest.TestCase):
|
|||
|
||||
def create_and_check_transfo_xl_commons(self, config, input_ids_1, input_ids_2, lm_labels):
|
||||
inputs_dict = {'input_ids': input_ids_1}
|
||||
create_and_check_commons(self, config, inputs_dict, test_pruning=False)
|
||||
create_and_check_commons(self, config, inputs_dict, test_pruning=False, test_torchscript=False)
|
||||
|
||||
def test_default(self):
|
||||
self.run_tester(TransfoXLModelTest.TransfoXLModelTester(self))
|
||||
|
|
Loading…
Reference in New Issue