fix RoPE t range issue for fp16 (#26602)
parent ea52ed9dc8
commit 87499420bf
@@ -108,7 +108,7 @@ class FalconRotaryEmbedding(nn.Module):
     def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.seq_len_cached = seq_len
-        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
+        t = torch.arange(seq_len, device=device).to(dtype)
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1).to(device)
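The removed line materialized `t` in `self.inv_freq.dtype`, which becomes fp16 once the model is cast with `model.half()` (buffers are cast along with parameters, `persistent=False` or not). fp16 represents consecutive integers exactly only up to 2048, so position ids in long sequences collapse before the frequencies are even computed. A minimal standalone sketch of the failure mode, using illustrative lengths rather than the Falcon code path:

import torch

# fp16 has a 10-bit mantissa, so integers are exact only up to 2**11 = 2048;
# beyond that, neighbouring positions round onto the same representable value.
t_old = torch.arange(4096, dtype=torch.float16)   # old behaviour when inv_freq is fp16
print(torch.unique(t_old).numel())                # < 4096: duplicated position ids

# The fixed pattern computes the range exactly in int64 and casts afterwards,
# confining any precision loss to the caller-chosen target dtype:
t_new = torch.arange(4096).to(torch.float32)
print(torch.unique(t_new).numel())                # 4096: every position distinct

# fp16 also tops out at 65504, so very long ranges overflow outright:
print(torch.tensor(66000.0, dtype=torch.float16)) # tensor(inf, dtype=torch.float16)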
@@ -171,7 +171,7 @@ class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
     def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.seq_len_cached = seq_len
-        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
+        t = torch.arange(seq_len, device=device).to(dtype)
         # This line is the only difference from FalconRotaryEmbedding._set_cos_sin_cache
         t = t / self.scaling_factor
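`FalconLinearScalingRotaryEmbedding` applies the same fix and then performs its one extra step: dividing the now exactly computed positions by `scaling_factor`, which maps an extended context back into the position range the model was trained on. A toy sketch with made-up numbers (the scaling factor and lengths are hypothetical):

import torch

scaling_factor = 4.0                              # hypothetical value for illustration
trained_len, extended_len = 2048, 8192

t = torch.arange(extended_len).to(torch.float32)  # exact range first, cast second
t = t / scaling_factor                            # 0.0 .. 2047.75, inside the trained range
assert t.max() < trained_len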
@@ -208,7 +208,7 @@ class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
         inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float().to(device) / self.head_dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)

-        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
+        t = torch.arange(seq_len, device=device).to(dtype)
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1).to(device)
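Putting the hunks together, a self-contained sketch of how the cos/sin cache is built once the fixed line is in place (illustrative `head_dim`, `base`, and `seq_len`; the real module additionally handles device placement and casting of the cached tensors):

import torch

head_dim, base, seq_len = 64, 10000, 4096         # illustrative values

# Per-pair inverse frequencies, as computed in the hunk above:
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))

t = torch.arange(seq_len).to(torch.float32)       # the fixed line: exact range, then cast
freqs = torch.einsum("i,j->ij", t, inv_freq)      # (seq_len, head_dim // 2) rotation angles
emb = torch.cat((freqs, freqs), dim=-1)           # (seq_len, head_dim)
cos_cached, sin_cached = emb.cos(), emb.sin()     # cached for later forward passes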