fix(phi3): Splits inv_freq calculation in two lines.

This commit is contained in:
Gustavo de Rosa 2024-04-24 06:29:16 -07:00
parent 2abcd4dec3
commit aeb6ae7ebe
1 changed files with 6 additions and 8 deletions

View File

@ -144,10 +144,9 @@ class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
else: else:
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
self.inv_freq = 1.0 / ( inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
ext_factors self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
* self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float() position_ids_expanded = position_ids[:, None, :].float()
@ -186,10 +185,9 @@ class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
else: else:
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
self.inv_freq = 1.0 / ( inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
ext_factors self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
* self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float() position_ids_expanded = position_ids[:, None, :].float()