fix(phi3): Uses up_states instead of y in Phi3MLP.

Gustavo de Rosa 2024-04-23 13:40:45 -07:00
parent d5aed89bd3
commit 3a24a1d4d2
1 changed file with 4 additions and 4 deletions

@@ -242,12 +242,12 @@ class Phi3MLP(nn.Module):
         self.activation_fn = ACT2FN[config.hidden_act]
 
     def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
-        y = self.gate_up_proj(hidden_states)
+        up_states = self.gate_up_proj(hidden_states)
 
-        gate, y = y.chunk(2, dim=-1)
-        y = y * self.activation_fn(gate)
+        gate, up_states = up_states.chunk(2, dim=-1)
+        up_states = up_states * self.activation_fn(gate)
 
-        return self.down_proj(y)
+        return self.down_proj(up_states)
 
 
 # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
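For context, here is a minimal self-contained sketch of how the Phi3MLP forward pass reads after this rename. The hidden_size of 3072, intermediate_size of 8192, and SiLU activation are illustrative stand-ins; the actual module takes these from the model config (e.g. ACT2FN[config.hidden_act]).

import torch
import torch.nn as nn


class Phi3MLPSketch(nn.Module):
    # Sketch of the gated MLP after the change; sizes and activation are
    # assumptions, the real Phi3MLP reads them from the model config.
    def __init__(self, hidden_size: int = 3072, intermediate_size: int = 8192):
        super().__init__()
        # A single fused projection produces both the gate and the up branch.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.activation_fn = nn.SiLU()

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        up_states = self.gate_up_proj(hidden_states)

        # Split the fused projection into the gate half and the up half.
        gate, up_states = up_states.chunk(2, dim=-1)
        # Gate the up branch with the activated gate (SwiGLU-style).
        up_states = up_states * self.activation_fn(gate)

        return self.down_proj(up_states)


# Usage: a batch of 2 sequences, 4 tokens each, hidden size 3072.
mlp = Phi3MLPSketch()
out = mlp(torch.randn(2, 4, 3072))
print(out.shape)  # torch.Size([2, 4, 3072])

The rename is purely cosmetic: up_states describes what the fused gate_up_proj output (and later the gated up branch) actually is, whereas the previous y obscured it; the computation itself is unchanged.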