fix(phi3): Use up_states instead of y in Phi3MLP.forward.
This commit is contained in:
parent
d5aed89bd3
commit
3a24a1d4d2
|
@ -242,12 +242,12 @@ class Phi3MLP(nn.Module):
|
|||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
    """Gated MLP forward pass: fused up/gate projection, activation gating, down projection.

    Args:
        hidden_states: Input activations of shape (..., hidden_size).

    Returns:
        Tensor of the same leading shape with hidden_size features,
        computed as down_proj(up * act(gate)).
    """
    # gate_up_proj computes both the gate and the up projection in a single
    # matmul; its output width is 2 * intermediate_size.
    up_states = self.gate_up_proj(hidden_states)

    # First half is the gate, second half is the up projection.
    gate, up_states = up_states.chunk(2, dim=-1)
    up_states = up_states * self.activation_fn(gate)

    return self.down_proj(up_states)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
|
||||
|
|
Loading…
Reference in New Issue