fix test
This commit is contained in:
parent
0f9fc4fbde
commit
1f5d9513d8
|
@ -161,7 +161,8 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
|
|||
"outputs": outputs.numpy(),
|
||||
}
|
||||
|
||||
model.config.mem_len = 0
|
||||
config.mem_len = 0
|
||||
model = TFXLNetModel(config)
|
||||
no_mems_outputs = model(inputs)
|
||||
self.parent.assertEqual(len(no_mems_outputs), 1)
|
||||
|
||||
|
|
|
@ -150,7 +150,9 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
|||
"outputs": outputs,
|
||||
}
|
||||
|
||||
model.config.mem_len = 0
|
||||
config.mem_len = 0
|
||||
model = XLNetModel(config)
|
||||
model.eval()
|
||||
no_mems_outputs = model(input_ids_1)
|
||||
self.parent.assertEqual(len(no_mems_outputs), 1)
|
||||
|
||||
|
|
Loading…
Reference in New Issue