Add 'with torch.no_grad()' to integration test forward pass (#14820)
This commit is contained in:
parent
b37cf7dee4
commit
0940e9b242
|
@ -587,7 +587,8 @@ class BertModelIntegrationTest(unittest.TestCase):
|
|||
model = BertModel.from_pretrained("bert-base-uncased")
|
||||
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
|
||||
attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
|
||||
output = model(input_ids, attention_mask=attention_mask)[0]
|
||||
with torch.no_grad():
|
||||
output = model(input_ids, attention_mask=attention_mask)[0]
|
||||
expected_shape = torch.Size((1, 11, 768))
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
expected_slice = torch.tensor([[[0.4249, 0.1008, 0.7531], [0.3771, 0.1188, 0.7467], [0.4152, 0.1098, 0.7108]]])
|
||||
|
@ -599,7 +600,8 @@ class BertModelIntegrationTest(unittest.TestCase):
|
|||
model = BertModel.from_pretrained("zhiheng-huang/bert-base-uncased-embedding-relative-key")
|
||||
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
|
||||
attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
|
||||
output = model(input_ids, attention_mask=attention_mask)[0]
|
||||
with torch.no_grad():
|
||||
output = model(input_ids, attention_mask=attention_mask)[0]
|
||||
expected_shape = torch.Size((1, 11, 768))
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
expected_slice = torch.tensor(
|
||||
|
@ -613,7 +615,8 @@ class BertModelIntegrationTest(unittest.TestCase):
|
|||
model = BertModel.from_pretrained("zhiheng-huang/bert-base-uncased-embedding-relative-key-query")
|
||||
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
|
||||
attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
|
||||
output = model(input_ids, attention_mask=attention_mask)[0]
|
||||
with torch.no_grad():
|
||||
output = model(input_ids, attention_mask=attention_mask)[0]
|
||||
expected_shape = torch.Size((1, 11, 768))
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
expected_slice = torch.tensor(
|
||||
|
|
Loading…
Reference in New Issue