Fix torch.ones usage in xlnet (#28471)

Fix xlnet torch.ones usage

Co-authored-by: sungho-ham <sungho.ham@linecorp.com>
This commit is contained in:
sungho-ham 2024-01-12 23:31:00 +09:00 committed by GitHub
parent c45ef1c0d1
commit edb314ae2b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -976,7 +976,7 @@ class XLNetModel(XLNetPreTrainedModel):
v [0 0 0 0 0 0 0 0 0] [1 1 1 1 0 0 0 0 0]
"""
mask = torch.ones(qlen, qlen + mlen, self.device)
mask = torch.ones((qlen, qlen + mlen), device=self.device)
if self.same_length:
mask_lo = mask[:, :qlen].tril(-1)
mask.triu_(mlen + 1)