[gpu] Fixup fdd61b1992
parent f5516805c2
commit b370cc7e99
@@ -205,7 +205,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
         model.eval()
 
         # create attention mask
-        attn_mask = torch.ones(input_ids.shape).long()
+        attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
         half_seq_length = self.seq_length // 2
         attn_mask[:, half_seq_length:] = 0
 
@@ -222,7 +222,9 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
 
         # append to next input_ids and attn_mask
         next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
-        attn_mask = torch.cat([attn_mask, torch.ones((attn_mask.shape[0], 1)).long()], dim=1)
+        attn_mask = torch.cat(
+            [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], dim=1
+        )
 
         # get two different outputs
         output_from_no_past, _ = model(next_input_ids, attention_mask=attn_mask)
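Note: both hunks apply the same fix, creating tensors with dtype and device set at construction instead of calling .long() on a CPU-default tensor. Below is a minimal standalone sketch (not part of the commit) of the failure mode this avoids; the torch_device fallback here is a hypothetical stand-in for the test suite's own device resolution.

import torch

# Hypothetical stand-in for the test suite's `torch_device` helper.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

input_ids = torch.randint(0, 100, (2, 8), device=torch_device)

# Before the fix: torch.ones(...) allocates on the CPU by default, then casts.
# On a GPU run, mixing this with `input_ids` (e.g. in torch.cat or a model
# forward) raises a device-mismatch RuntimeError.
attn_mask_cpu = torch.ones(input_ids.shape).long()

# After the fix: dtype and device are fixed at construction, so the mask
# lives on the same device as the model inputs.
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)

# Appending a column of ones for the newly generated token, as in the test:
attn_mask = torch.cat(
    [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], dim=1
)
assert attn_mask.shape == (2, 9) and attn_mask.device == input_ids.device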