Add 'with torch.no_grad()' to BertGeneration integration test forward passes (#14963)
This commit is contained in:
parent
d2183a46fb
commit
f71fb5c36e
|
@ -307,7 +307,8 @@ class BertGenerationEncoderIntegrationTest(unittest.TestCase):
|
|||
def test_inference_no_head_absolute_embedding(self):
|
||||
model = BertGenerationEncoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
|
||||
input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]])
|
||||
output = model(input_ids)[0]
|
||||
with torch.no_grad():
|
||||
output = model(input_ids)[0]
|
||||
expected_shape = torch.Size([1, 8, 1024])
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
expected_slice = torch.tensor(
|
||||
|
@ -322,7 +323,8 @@ class BertGenerationDecoderIntegrationTest(unittest.TestCase):
|
|||
def test_inference_no_head_absolute_embedding(self):
|
||||
model = BertGenerationDecoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
|
||||
input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]])
|
||||
output = model(input_ids)[0]
|
||||
with torch.no_grad():
|
||||
output = model(input_ids)[0]
|
||||
expected_shape = torch.Size([1, 8, 50358])
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
expected_slice = torch.tensor(
|
||||
|
|
Loading…
Reference in New Issue