add mlp bias for llama models (#30031)

* add bias

* fix quality
Mayank Mishra 2024-05-03 05:02:17 -04:00 committed by GitHub
parent a0e77a1f6b
commit 425e1a0426
3 changed files with 8 additions and 5 deletions

src/transformers/models/cohere/modeling_cohere.py

@@ -161,7 +161,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)


-# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
 class CohereMLP(nn.Module):
     def __init__(self, config):
         super().__init__()

src/transformers/models/llama/configuration_llama.py

@@ -94,10 +94,12 @@ class LlamaConfig(PretrainedConfig):
             these scaling strategies behave:
             https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
             experimental feature, subject to breaking API changes in future versions.
-        attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
+        attention_bias (`bool`, *optional*, defaults to `False`):
             Whether to use a bias in the query, key, value and output projection layers during self-attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
+        mlp_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.

    ```python
    >>> from transformers import LlamaModel, LlamaConfig
@@ -137,6 +139,7 @@ class LlamaConfig(PretrainedConfig):
         rope_scaling=None,
         attention_bias=False,
         attention_dropout=0.0,
+        mlp_bias=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -161,6 +164,7 @@ class LlamaConfig(PretrainedConfig):
         self._rope_scaling_validation()
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
+        self.mlp_bias = mlp_bias

         super().__init__(
             pad_token_id=pad_token_id,

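Not part of the diff, but as a quick sketch of the new knob: `mlp_bias` is just an extra boolean stored on `LlamaConfig`, defaulting to `False` so existing checkpoints keep their bias-free MLPs. The sizes below are arbitrary, purely illustrative values.

```python
from transformers import LlamaConfig

# Tiny illustrative configuration; the only new argument here is mlp_bias.
config = LlamaConfig(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    vocab_size=1000,
    mlp_bias=True,  # new flag added by this PR; defaults to False
)
print(config.mlp_bias)  # True
```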
src/transformers/models/llama/modeling_llama.py

@@ -214,9 +214,9 @@ class LlamaMLP(nn.Module):
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
         self.act_fn = ACT2FN[config.hidden_act]

     def forward(self, x):
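With the flag threaded through, `LlamaMLP` builds `gate_proj`, `up_proj` and `down_proj` with `bias=config.mlp_bias`, so the bias parameters exist only when the flag is on. A minimal sketch to check that wiring (again with small, arbitrary sizes that are not from the PR):

```python
from transformers import LlamaConfig, LlamaModel

# Arbitrary small sizes so the model instantiates quickly; only mlp_bias matters here.
config = LlamaConfig(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=1,
    num_attention_heads=4,
    vocab_size=1000,
    mlp_bias=True,
)
mlp = LlamaModel(config).layers[0].mlp

# Each projection is an nn.Linear created with bias=config.mlp_bias.
assert mlp.gate_proj.bias is not None
assert mlp.up_proj.bias is not None
assert mlp.down_proj.bias is not None

# With the default mlp_bias=False these attributes are None, matching the previous behavior.
```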