fix data type (#7513)
parent 62f5ae68ec
commit bd2621583b
@@ -238,13 +238,20 @@ class ModuleUtilsMixin:
                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                # in case past_key_values are used we need to add a prefix ones mask to the causal mask
-               # causal and attention masks must have same type with pytorch version < 1.3
-               causal_mask = causal_mask.to(attention_mask.dtype)
-
                if causal_mask.shape[1] < attention_mask.shape[1]:
                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                    causal_mask = torch.cat(
-                       [torch.ones((batch_size, seq_length, prefix_seq_len), device=device), causal_mask], axis=-1
+                       [
+                           torch.ones(
+                               (batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype
+                           ),
+                           causal_mask,
+                       ],
+                       axis=-1,
                    )
+
+               # causal and attention masks must have same type with pytorch version < 1.3
+               causal_mask = causal_mask.to(attention_mask.dtype)
+
                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
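Why the explicit dtype matters, as a minimal standalone sketch (not taken from the PR itself; the concrete shapes and the float16 target dtype below are illustrative assumptions): the causal mask produced by the `<=` comparison is boolean, while `torch.ones` defaults to float32. Under the old ordering the causal mask was cast to `attention_mask.dtype` before the concatenation, so with a half-precision attention mask `torch.cat` received mixed floating dtypes, which can either raise or silently promote to float32 depending on the PyTorch version. The fix builds the prefix ones with `dtype=causal_mask.dtype` and performs the single cast after concatenation.

import torch

batch_size, seq_length, prefix_seq_len = 2, 4, 3
device = "cpu"

# Boolean causal mask, exactly as built by the `<=` comparison in the diff.
seq_ids = torch.arange(seq_length, device=device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]

# Old ordering: cast first, then concatenate with default-dtype ones.
# The two operands end up with different floating dtypes.
half_mask = causal_mask.to(torch.float16)  # stands in for attention_mask.dtype under fp16
default_ones = torch.ones((batch_size, seq_length, prefix_seq_len), device=device)
print(half_mask.dtype, default_ones.dtype)  # torch.float16 torch.float32

# Fixed ordering: prefix ones match causal_mask's (bool) dtype,
# and the cast to attention_mask.dtype happens once, at the end.
prefix_ones = torch.ones(
    (batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype
)
causal_mask = torch.cat([prefix_ones, causal_mask], axis=-1)
causal_mask = causal_mask.to(torch.float16)
print(causal_mask.dtype, causal_mask.shape)  # torch.float16 torch.Size([2, 4, 7])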