update gemma

Arthur Zucker 2024-05-15 17:12:49 +02:00
parent f3fe0b340a
commit 35576acfcd
2 changed files with 15 additions and 79 deletions

View File

@@ -1,12 +1,3 @@
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
from transformers.models.llama.modeling_llama import *
import torch.nn as nn
from transformers.utils import ModelConverter
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
@@ -23,39 +14,34 @@ from transformers.utils import ModelConverter
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Gemma model."""
from transformers.models.llama.modeling_llama import *
import torch.nn as nn
from transformers.utils import ModelConverter
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_causal_attention_mask,
)
from ...modeling_outputs import BaseModelOutputWithPast
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
is_flash_attn_2_available,
logging,
)
from ...utils.import_utils import is_torch_fx_available
from .configuration_gemma import GemmaConfig
from transformers.models.llama.modeling_llama import repeat_kv, rotate_half, apply_rotary_pos_emb
from transformers.models.llama.modeling_llama import repeat_kv, apply_rotary_pos_emb
logger = logging.get_logger(__name__)
class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
@@ -234,23 +220,6 @@ class GemmaAttention(nn.Module):
GemmaConverter = ModelConverter(__file__)
class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.zeros(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float())
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
output = output * (1.0 + self.weight.float())
return output.type_as(x)
GemmaFlashAttention2 = GemmaConverter.register("GemmaFlashAttention2", LlamaFlashAttention2)
GemmaSdpaAttention = GemmaConverter.register("GemmaSdpaAttention", LlamaSdpaAttention)
@@ -282,5 +251,4 @@ class GemmaModel(LlamaModel):
return super().forward(None, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
GemmaConverter.register("GemmaForCausalLM", LlamaForCausalLM)
GemmaConverter.register("GemmaForCausalLM", LlamaForCausalLM)
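The file above is the converter-style definition of Gemma in terms of Llama. As a reading aid, here is a hedged reconstruction of how its pieces appear to fit together, using only the calls that actually occur in the diff (ModelConverter(__file__) and .register(name, cls)); the ModelConverter import path exists only in this commit's tree, and any behavior beyond these calls is an assumption, not confirmed API.

# Hedged sketch of how the converter file seems to be organized in this commit.
# ModelConverter and its register() usage are copied from the diff; everything
# else (how registered classes are later materialized) is an assumption.
from transformers.models.llama.modeling_llama import (
    LlamaFlashAttention2,
    LlamaForCausalLM,
    LlamaModel,
    LlamaSdpaAttention,
)
from transformers.utils import ModelConverter  # present only in this commit's tree

GemmaConverter = ModelConverter(__file__)

# Modules that are identical to Llama are only re-registered under Gemma names.
GemmaFlashAttention2 = GemmaConverter.register("GemmaFlashAttention2", LlamaFlashAttention2)
GemmaSdpaAttention = GemmaConverter.register("GemmaSdpaAttention", LlamaSdpaAttention)
GemmaConverter.register("GemmaForCausalLM", LlamaForCausalLM)

# Modules that differ from Llama (GemmaRMSNorm, the GemmaModel.forward override)
# are written out in full in the same file, as shown in the hunks above.
class GemmaModel(LlamaModel):
    ...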

View File

@@ -9,15 +9,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
from transformers.models.llama.modeling_llama import *
import torch.nn as nn
from transformers.utils import ModelConverter
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
@@ -34,39 +25,34 @@ from transformers.utils import ModelConverter
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Gemma model."""
from transformers.models.llama.modeling_llama import *
import torch.nn as nn
from transformers.utils import ModelConverter
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_causal_attention_mask,
)
from ...modeling_outputs import BaseModelOutputWithPast
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
is_flash_attn_2_available,
logging,
)
from ...utils.import_utils import is_torch_fx_available
from .configuration_gemma import GemmaConfig
from transformers.models.llama.modeling_llama import repeat_kv, rotate_half, apply_rotary_pos_emb
from transformers.models.llama.modeling_llama import repeat_kv, apply_rotary_pos_emb
logger = logging.get_logger(__name__)
class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
@@ -244,23 +230,6 @@ class GemmaAttention(nn.Module):
return attn_output, attn_weights, past_key_value
class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.zeros(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float())
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
output = output * (1.0 + self.weight.float())
return output.type_as(x)
class GemmaFlashAttention2(GemmaAttention):
"""
Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays
@@ -1129,4 +1098,3 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
)
return reordered_past
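Both files carry the same GemmaRMSNorm, whose only difference from Llama's RMSNorm is when the result is cast back to the input dtype (the comment in the diff points to https://github.com/huggingface/transformers/pull/29402). A small standalone sketch of the two orderings follows; the function names are illustrative, not from the commit.

import torch

def llama_style_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Llama: normalize in float32, cast back to the input dtype, then scale.
    normed = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    return weight * normed.to(x.dtype)

def gemma_style_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Gemma: normalize and scale in float32 (with the 1 + weight parametrization,
    # matching the zero-initialized weight in the diff), then cast the product back.
    normed = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    return ((1.0 + weight.float()) * normed).type_as(x)

# In half precision the two orderings can round differently even for equivalent weights:
x = torch.randn(2, 8, dtype=torch.float16)
w = torch.zeros(8, dtype=torch.float16)  # Gemma-style zero init; Llama's ones-init equals 1.0 + w
print((llama_style_rmsnorm(x, 1.0 + w) - gemma_style_rmsnorm(x, w)).abs().max())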