fix(phi3): Splits inv_freq calculation into two lines.
This commit is contained in:
parent
2abcd4dec3
commit
aeb6ae7ebe
|
@@ -144,10 +144,9 @@ class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
|
||||||
else:
|
else:
|
||||||
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
|
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
|
||||||
|
|
||||||
self.inv_freq = 1.0 / (
|
inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
|
||||||
ext_factors
|
self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
|
||||||
* self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
|
|
||||||
)
|
|
||||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||||
position_ids_expanded = position_ids[:, None, :].float()
|
position_ids_expanded = position_ids[:, None, :].float()
|
||||||
|
|
||||||
|
@@ -186,10 +185,9 @@ class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
|
||||||
else:
|
else:
|
||||||
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
|
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
|
||||||
|
|
||||||
self.inv_freq = 1.0 / (
|
inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
|
||||||
ext_factors
|
self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
|
||||||
* self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
|
|
||||||
)
|
|
||||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||||
position_ids_expanded = position_ids[:, None, :].float()
|
position_ids_expanded = position_ids[:, None, :].float()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue