diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index 9e41e38eb1..5a77252cf4 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -242,12 +242,12 @@ class Phi3MLP(nn.Module):
         self.activation_fn = ACT2FN[config.hidden_act]
 
     def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
-        y = self.gate_up_proj(hidden_states)
+        up_states = self.gate_up_proj(hidden_states)
 
-        gate, y = y.chunk(2, dim=-1)
-        y = y * self.activation_fn(gate)
+        gate, up_states = up_states.chunk(2, dim=-1)
+        up_states = up_states * self.activation_fn(gate)
 
-        return self.down_proj(y)
+        return self.down_proj(up_states)
 
 
 # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi