Fix torch.ones usage in xlnet (#28471)
Fix xlnet torch.ones usage Co-authored-by: sungho-ham <sungho.ham@linecorp.com>
This commit is contained in:
parent
c45ef1c0d1
commit
edb314ae2b
|
@ -976,7 +976,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
|||
v [0 0 0 0 0 0 0 0 0] [1 1 1 1 0 0 0 0 0]
|
||||
|
||||
"""
|
||||
mask = torch.ones(qlen, qlen + mlen, self.device)
|
||||
mask = torch.ones((qlen, qlen + mlen), device=self.device)
|
||||
if self.same_length:
|
||||
mask_lo = mask[:, :qlen].tril(-1)
|
||||
mask.triu_(mlen + 1)
|
||||
|
|
Loading…
Reference in New Issue