Add 'with torch.no_grad()' to BertGeneration integration test forward passes (#14963)

Tavin Turner 2022-01-06 08:39:13 -07:00 committed by GitHub
parent d2183a46fb
commit f71fb5c36e
1 changed file with 4 additions and 2 deletions


@@ -307,7 +307,8 @@ class BertGenerationEncoderIntegrationTest(unittest.TestCase):
     def test_inference_no_head_absolute_embedding(self):
         model = BertGenerationEncoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
         input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]])
-        output = model(input_ids)[0]
+        with torch.no_grad():
+            output = model(input_ids)[0]
         expected_shape = torch.Size([1, 8, 1024])
         self.assertEqual(output.shape, expected_shape)
         expected_slice = torch.tensor(
@@ -322,7 +323,8 @@ class BertGenerationDecoderIntegrationTest(unittest.TestCase):
     def test_inference_no_head_absolute_embedding(self):
         model = BertGenerationDecoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
        input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]])
-        output = model(input_ids)[0]
+        with torch.no_grad():
+            output = model(input_ids)[0]
         expected_shape = torch.Size([1, 8, 50358])
         self.assertEqual(output.shape, expected_shape)
         expected_slice = torch.tensor(
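
For context, here is a minimal, self-contained sketch of the pattern this commit applies. The stand-in model below is illustrative (the real tests load pretrained BertGeneration checkpoints via from_pretrained); the point is that wrapping an inference-only forward pass in torch.no_grad() stops autograd from recording the computation graph, which reduces memory use in tests that never call backward().

import torch
from torch import nn

# Stand-in for the pretrained model used in the real integration tests.
model = nn.Embedding(num_embeddings=50358, embedding_dim=16)
model.eval()  # inference mode: disable dropout and similar layers

input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]])

# Without no_grad(), this forward pass would build a computation graph
# for a backward() call that a shape/value assertion test never makes.
with torch.no_grad():
    output = model(input_ids)

assert not output.requires_grad  # no graph was recorded
print(output.shape)  # torch.Size([1, 8, 16])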