fix RoPE t range issue for fp16 (#26602)

This commit is contained in:
rui-ren 2023-10-06 04:04:54 -07:00 committed by GitHub
parent ea52ed9dc8
commit 87499420bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 3 additions and 3 deletions

View File

@ -108,7 +108,7 @@ class FalconRotaryEmbedding(nn.Module):
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.seq_len_cached = seq_len
t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
t = torch.arange(seq_len, device=device).to(dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(device)
@ -171,7 +171,7 @@ class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.seq_len_cached = seq_len
t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
t = torch.arange(seq_len, device=device).to(dtype)
# This line is the only difference from FalconRotaryEmbedding._set_cos_sin_cache
t = t / self.scaling_factor
@ -208,7 +208,7 @@ class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float().to(device) / self.head_dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
t = torch.arange(seq_len, device=device).to(dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(device)