fix(phi3): Adds last suggestions to modeling file.

Gustavo de Rosa 2024-04-24 06:24:17 -07:00
parent 06cd06d29f
commit 2abcd4dec3
1 changed file with 26 additions and 63 deletions


@@ -129,26 +129,12 @@ class Phi3RotaryEmbedding(nn.Module):
 class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
-    def __init__(
-        self,
-        dim,
-        short_factor,
-        long_factor,
-        original_max_position_embeddings=2048,
-        max_position_embeddings=2048,
-        base=10000,
-        device=None,
-    ):
-        super().__init__(dim, max_position_embeddings, base, device)
+    def __init__(self, dim, config, device=None):
+        super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
 
-        self.short_factor = short_factor
-        self.long_factor = long_factor
-        self.original_max_position_embeddings = original_max_position_embeddings
-
-    def _calc_scaling_factor(self, scale):
-        if scale <= 1.0:
-            return 1.0
-        return math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
+        self.short_factor = config.rope_scaling["short_factor"]
+        self.long_factor = config.rope_scaling["long_factor"]
+        self.original_max_position_embeddings = config.original_max_position_embeddings
 
     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
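
Note: after this change the scaled-RoPE classes are built from the model config alone. A minimal usage sketch, assuming this commit's module layout in transformers; the config values below are illustrative, not from a released checkpoint, and the factor lists must have head_dim // 2 entries:

import torch
from transformers import Phi3Config
from transformers.models.phi3.modeling_phi3 import Phi3SuScaledRotaryEmbedding

# Illustrative values only; real checkpoints ship their own per-dimension factors.
config = Phi3Config(
    max_position_embeddings=131072,
    original_max_position_embeddings=4096,
    rope_theta=10000.0,
    rope_scaling={"type": "su", "short_factor": [1.0] * 48, "long_factor": [1.0] * 48},
)

rotary = Phi3SuScaledRotaryEmbedding(dim=96, config=config)  # dim = hidden_size // num_attention_heads
x = torch.randn(1, 8, 96)                                    # (batch, seq_len, head_dim)
cos, sin = rotary(x, torch.arange(8).unsqueeze(0))
print(cos.shape)  # torch.Size([1, 8, 96])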
@@ -171,36 +157,26 @@ class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
         device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            scaling_factor = self._calc_scaling_factor(
-                self.max_position_embeddings / self.original_max_position_embeddings
-            )
             emb = torch.cat((freqs, freqs), dim=-1)
+
+            scale = self.max_position_embeddings / self.original_max_position_embeddings
+            if scale <= 1.0:
+                scaling_factor = 1.0
+            else:
+                scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
+
             cos = emb.cos() * scaling_factor
             sin = emb.sin() * scaling_factor
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
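
As a sanity check on the inlined "su" factor, a quick computation under an illustrative 4096 -> 131072 extension (real lengths come from the checkpoint config):

import math

original_max_position_embeddings = 4096   # illustrative
max_position_embeddings = 131072          # illustrative, a 32x extension

scale = max_position_embeddings / original_max_position_embeddings   # 32.0
scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))
print(scaling_factor)  # ~1.190: cos/sin are gently amplified when the context is extended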
 class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
-    def __init__(
-        self,
-        dim,
-        short_factor,
-        long_factor,
-        original_max_position_embeddings=2048,
-        max_position_embeddings=2048,
-        base=10000,
-        device=None,
-    ):
-        super().__init__(dim, max_position_embeddings, base, device)
+    def __init__(self, dim, config, device=None):
+        super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
 
-        self.short_factor = short_factor
-        self.long_factor = long_factor
-        self.original_max_position_embeddings = original_max_position_embeddings
-
-    def _calc_scaling_factor(self, scale):
-        if scale <= 1.0:
-            return 1.0
-        return 0.1 * math.log(scale) + 1.0
+        self.short_factor = config.rope_scaling["short_factor"]
+        self.long_factor = config.rope_scaling["long_factor"]
+        self.original_max_position_embeddings = config.original_max_position_embeddings
 
     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
@@ -223,10 +199,14 @@ class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
         device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            scaling_factor = self._calc_scaling_factor(
-                self.max_position_embeddings / self.original_max_position_embeddings
-            )
             emb = torch.cat((freqs, freqs), dim=-1)
+
+            scale = self.max_position_embeddings / self.original_max_position_embeddings
+            if scale <= 1.0:
+                scaling_factor = 1.0
+            else:
+                scaling_factor = 0.1 * math.log(scale) + 1.0
+
             cos = emb.cos() * scaling_factor
             sin = emb.sin() * scaling_factor
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
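
The same check for the yarn branch, which inlines YaRN's logarithmic attenuation (0.1 * ln(scale) + 1.0) in place of the square-root form; lengths again illustrative:

import math

scale = 131072 / 4096                        # illustrative 32x extension
scaling_factor = 0.1 * math.log(scale) + 1.0
print(scaling_factor)                        # ~1.347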
@ -346,27 +326,10 @@ class Phi3Attention(nn.Module):
)
else:
scaling_type = self.config.rope_scaling["type"]
short_factor = self.config.rope_scaling["short_factor"]
long_factor = self.config.rope_scaling["long_factor"]
if scaling_type == "su":
self.rotary_emb = Phi3SuScaledRotaryEmbedding(
self.head_dim,
short_factor,
long_factor,
max_position_embeddings=self.max_position_embeddings,
original_max_position_embeddings=self.original_max_position_embeddings,
base=self.rope_theta,
)
self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config)
elif scaling_type == "yarn":
self.rotary_emb = Phi3YarnScaledRotaryEmbedding(
self.head_dim,
short_factor,
long_factor,
max_position_embeddings=self.max_position_embeddings,
original_max_position_embeddings=self.original_max_position_embeddings,
base=self.rope_theta,
)
self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
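
With both constructors now taking (head_dim, config), the dispatch can be exercised outside the module as well. A hedged sketch; build_rotary is a hypothetical helper mirroring the branch above, not part of the modeling file, and the config values are illustrative:

from transformers import Phi3Config
from transformers.models.phi3.modeling_phi3 import (
    Phi3SuScaledRotaryEmbedding,
    Phi3YarnScaledRotaryEmbedding,
)

def build_rotary(head_dim, config):
    # Hypothetical helper mirroring the dispatch in Phi3Attention.__init__ above.
    scaling_type = config.rope_scaling["type"]
    if scaling_type == "su":
        return Phi3SuScaledRotaryEmbedding(head_dim, config)
    if scaling_type == "yarn":
        return Phi3YarnScaledRotaryEmbedding(head_dim, config)
    raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

config = Phi3Config(
    max_position_embeddings=131072,          # illustrative
    original_max_position_embeddings=4096,   # illustrative
    rope_scaling={"type": "yarn", "short_factor": [1.0] * 48, "long_factor": [1.0] * 48},
)
rotary = build_rotary(96, config)  # -> Phi3YarnScaledRotaryEmbedding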