This commit is contained in:
Arthur Zucker 2024-05-27 15:24:42 +02:00
parent 8256a73c81
commit 4ead65b86d
1 changed file with 3 additions and 2 deletions

View File

@ -23,6 +23,7 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedConfig
from transformers.models.llama.modeling_llama import (
LlamaForCausalLM,
LlamaForSequenceClassification,
@ -30,7 +31,7 @@ from transformers.models.llama.modeling_llama import (
apply_rotary_pos_emb,
repeat_kv,
)
from transformers import PreTrainedConfig
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...modeling_outputs import CausalLMOutputWithPast
@ -91,7 +92,7 @@ class GemmaConfig(PreTrainedConfig):
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):