This commit is contained in:
Julien Chaumond 2020-02-26 21:48:49 +00:00
parent f5516805c2
commit b370cc7e99
1 changed files with 4 additions and 2 deletions

View File

@ -205,7 +205,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
model.eval()
# create attention mask
attn_mask = torch.ones(input_ids.shape).long()
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
half_seq_length = self.seq_length // 2
attn_mask[:, half_seq_length:] = 0
@ -222,7 +222,9 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
# append to next input_ids and attn_mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
attn_mask = torch.cat([attn_mask, torch.ones((attn_mask.shape[0], 1)).long()], dim=1)
attn_mask = torch.cat(
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], dim=1
)
# get two different outputs
output_from_no_past, _ = model(next_input_ids, attention_mask=attn_mask)