fix(phi3): Adds last suggestions to modeling file.
parent 06cd06d29f
commit 2abcd4dec3
@@ -129,26 +129,12 @@ class Phi3RotaryEmbedding(nn.Module):
 
 
 class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
-    def __init__(
-        self,
-        dim,
-        short_factor,
-        long_factor,
-        original_max_position_embeddings=2048,
-        max_position_embeddings=2048,
-        base=10000,
-        device=None,
-    ):
-        super().__init__(dim, max_position_embeddings, base, device)
+    def __init__(self, dim, config, device=None):
+        super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
 
-        self.short_factor = short_factor
-        self.long_factor = long_factor
-        self.original_max_position_embeddings = original_max_position_embeddings
-
-    def _calc_scaling_factor(self, scale):
-        if scale <= 1.0:
-            return 1.0
-        return math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
+        self.short_factor = config.rope_scaling["short_factor"]
+        self.long_factor = config.rope_scaling["long_factor"]
+        self.original_max_position_embeddings = config.original_max_position_embeddings
 
     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
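For context on the new signature: both scaled rotary classes now read everything from the model config instead of taking individual arguments. A minimal sketch of how the constructor would be exercised follows; the head dimension, context lengths, and factor values are hypothetical, and the class is assumed importable from the modeling file.

# Hypothetical smoke test of the config-driven constructor (values are
# illustrative, not from a real Phi-3 checkpoint).
from types import SimpleNamespace

head_dim = 96  # assumed head dimension
config = SimpleNamespace(
    max_position_embeddings=131072,         # assumed extended context length
    original_max_position_embeddings=4096,  # assumed pretraining context length
    rope_theta=10000.0,
    rope_scaling={
        "type": "su",
        "short_factor": [1.0] * (head_dim // 2),  # placeholder per-frequency factors
        "long_factor": [1.0] * (head_dim // 2),
    },
)

# Assumes Phi3SuScaledRotaryEmbedding is importable from the modeling file.
rotary_emb = Phi3SuScaledRotaryEmbedding(head_dim, config)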
@@ -171,36 +157,26 @@ class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
         device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            scaling_factor = self._calc_scaling_factor(
-                self.max_position_embeddings / self.original_max_position_embeddings
-            )
             emb = torch.cat((freqs, freqs), dim=-1)
+
+            scale = self.max_position_embeddings / self.original_max_position_embeddings
+            if scale <= 1.0:
+                scaling_factor = 1.0
+            else:
+                scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
+
             cos = emb.cos() * scaling_factor
             sin = emb.sin() * scaling_factor
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
 class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
-    def __init__(
-        self,
-        dim,
-        short_factor,
-        long_factor,
-        original_max_position_embeddings=2048,
-        max_position_embeddings=2048,
-        base=10000,
-        device=None,
-    ):
-        super().__init__(dim, max_position_embeddings, base, device)
+    def __init__(self, dim, config, device=None):
+        super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
 
-        self.short_factor = short_factor
-        self.long_factor = long_factor
-        self.original_max_position_embeddings = original_max_position_embeddings
-
-    def _calc_scaling_factor(self, scale):
-        if scale <= 1.0:
-            return 1.0
-        return 0.1 * math.log(scale) + 1.0
+        self.short_factor = config.rope_scaling["short_factor"]
+        self.long_factor = config.rope_scaling["long_factor"]
+        self.original_max_position_embeddings = config.original_max_position_embeddings
 
     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
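The _calc_scaling_factor helper is gone; the "su" attention scaling is now inlined in forward. As a sanity check of that expression, a small standalone computation with hypothetical context lengths:

import math

# Hypothetical context lengths, for illustration only.
original_max_position_embeddings = 4096
max_position_embeddings = 131072

scale = max_position_embeddings / original_max_position_embeddings  # 32.0
if scale <= 1.0:
    scaling_factor = 1.0
else:
    # Same expression the Su-scaled forward pass now inlines.
    scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))

print(round(scaling_factor, 4))  # 1.1902 for these assumed lengths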
@@ -223,10 +199,14 @@ class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
         device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            scaling_factor = self._calc_scaling_factor(
-                self.max_position_embeddings / self.original_max_position_embeddings
-            )
             emb = torch.cat((freqs, freqs), dim=-1)
+
+            scale = self.max_position_embeddings / self.original_max_position_embeddings
+            if scale <= 1.0:
+                scaling_factor = 1.0
+            else:
+                scaling_factor = 0.1 * math.log(scale) + 1.0
+
             cos = emb.cos() * scaling_factor
             sin = emb.sin() * scaling_factor
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
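The YaRN variant follows the same pattern with its own attention-temperature term. The same standalone check, again with hypothetical lengths:

import math

# Hypothetical context lengths, for illustration only.
original_max_position_embeddings = 4096
max_position_embeddings = 131072

scale = max_position_embeddings / original_max_position_embeddings  # 32.0
# YaRN-style term inlined in the forward pass above.
scaling_factor = 1.0 if scale <= 1.0 else 0.1 * math.log(scale) + 1.0
print(round(scaling_factor, 4))  # 1.3466 for these assumed lengths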
@@ -346,27 +326,10 @@ class Phi3Attention(nn.Module):
             )
         else:
             scaling_type = self.config.rope_scaling["type"]
-            short_factor = self.config.rope_scaling["short_factor"]
-            long_factor = self.config.rope_scaling["long_factor"]
-
             if scaling_type == "su":
-                self.rotary_emb = Phi3SuScaledRotaryEmbedding(
-                    self.head_dim,
-                    short_factor,
-                    long_factor,
-                    max_position_embeddings=self.max_position_embeddings,
-                    original_max_position_embeddings=self.original_max_position_embeddings,
-                    base=self.rope_theta,
-                )
+                self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config)
             elif scaling_type == "yarn":
-                self.rotary_emb = Phi3YarnScaledRotaryEmbedding(
-                    self.head_dim,
-                    short_factor,
-                    long_factor,
-                    max_position_embeddings=self.max_position_embeddings,
-                    original_max_position_embeddings=self.original_max_position_embeddings,
-                    base=self.rope_theta,
-                )
+                self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config)
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
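With the config-driven constructors, the selection in Phi3Attention.__init__ reduces to a small dispatch on rope_scaling["type"]. A standalone sketch of that logic; build_rotary_emb is a hypothetical helper, and the Phi3*RotaryEmbedding classes are assumed importable from the modeling file:

# Hypothetical helper mirroring the simplified branch above.
def build_rotary_emb(head_dim, config):
    if config.rope_scaling is None:
        return Phi3RotaryEmbedding(
            head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )
    scaling_type = config.rope_scaling["type"]
    if scaling_type == "su":
        return Phi3SuScaledRotaryEmbedding(head_dim, config)
    if scaling_type == "yarn":
        return Phi3YarnScaledRotaryEmbedding(head_dim, config)
    raise ValueError(f"Unknown RoPE scaling type {scaling_type}")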