update gemma
parent f3fe0b340a
commit 35576acfcd
@@ -1,12 +1,3 @@
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.

from transformers.models.llama.modeling_llama import *
import torch.nn as nn
from transformers.utils import ModelConverter



# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
@@ -23,39 +14,34 @@ from transformers.utils import ModelConverter
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Gemma model."""
from transformers.models.llama.modeling_llama import *
import torch.nn as nn
from transformers.utils import ModelConverter


import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache
from ...modeling_attn_mask_utils import (
    AttentionMaskConverter,
    _prepare_4d_causal_attention_mask,
)

from ...modeling_outputs import BaseModelOutputWithPast
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
    is_flash_attn_2_available,
    logging,
)
from ...utils.import_utils import is_torch_fx_available
from .configuration_gemma import GemmaConfig

from transformers.models.llama.modeling_llama import repeat_kv, rotate_half, apply_rotary_pos_emb
from transformers.models.llama.modeling_llama import repeat_kv, apply_rotary_pos_emb


logger = logging.get_logger(__name__)




class GemmaRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
@@ -234,23 +220,6 @@ class GemmaAttention(nn.Module):


GemmaConverter = ModelConverter(__file__)

class GemmaRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float())
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        output = output * (1.0 + self.weight.float())
        return output.type_as(x)

GemmaFlashAttention2 = GemmaConverter.register("GemmaFlashAttention2", LlamaFlashAttention2)
GemmaSdpaAttention = GemmaConverter.register("GemmaSdpaAttention", LlamaSdpaAttention)

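Note: `ModelConverter` is introduced by this branch (imported from `transformers.utils` above) and is not part of the released library. As a rough, assumed sketch of the registration semantics the hunk above relies on, `register` appears to derive a renamed Gemma class from its Llama counterpart:

# A minimal sketch, assuming ModelConverter.register simply derives a renamed
# subclass from the Llama implementation; the real branch-specific class may
# record extra metadata for code generation.
class ModelConverter:
    def __init__(self, source_file: str):
        self.source_file = source_file  # modeling file the converted classes belong to
        self.registered: dict[str, type] = {}

    def register(self, new_name: str, base_class: type) -> type:
        # The derived class inherits the Llama forward pass unchanged and
        # only swaps the public name (and module of origin).
        converted = type(new_name, (base_class,), {"__module__": self.source_file})
        self.registered[new_name] = converted
        return converted

Under these assumed semantics, `GemmaConverter.register("GemmaSdpaAttention", LlamaSdpaAttention)` returns a class that executes Llama's SDPA attention under the Gemma name, which matches how `GemmaFlashAttention2` and `GemmaSdpaAttention` are bound above.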
@@ -282,5 +251,4 @@ class GemmaModel(LlamaModel):
        return super().forward(None, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)


GemmaConverter.register("GemmaForCausalLM", LlamaForCausalLM)

GemmaConverter.register("GemmaForCausalLM", LlamaForCausalLM)
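In the hunk above, `GemmaModel.forward` passes `None` as `input_ids` and hands precomputed `inputs_embeds` to `LlamaModel.forward`. The lines that compute `inputs_embeds` fall outside this hunk; a plausible, assumed reconstruction, based on Gemma's documented scaling of token embeddings by sqrt(hidden_size), would be:

# Hypothetical reconstruction of the elided start of GemmaModel.forward; the
# sqrt(hidden_size) embedding scaling is Gemma's known behaviour, but its
# exact placement in this branch is an assumption.
if inputs_embeds is None:
    inputs_embeds = self.embed_tokens(input_ids)
# Compute the normalizer in the embedding dtype so half-precision runs match
# the reference checkpoints.
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=inputs_embeds.dtype)
inputs_embeds = inputs_embeds * normalizer
return super().forward(None, attention_mask, position_ids, past_key_values,
                       inputs_embeds, use_cache, output_attentions,
                       output_hidden_states, return_dict, cache_position)

Passing `None` for `input_ids` then prevents the Llama base class from re-embedding the tokens without the Gemma scaling.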
@@ -9,15 +9,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.

from transformers.models.llama.modeling_llama import *
import torch.nn as nn
from transformers.utils import ModelConverter



# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
@@ -34,39 +25,34 @@ from transformers.utils import ModelConverter
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Gemma model."""
from transformers.models.llama.modeling_llama import *
import torch.nn as nn
from transformers.utils import ModelConverter


import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache
from ...modeling_attn_mask_utils import (
    AttentionMaskConverter,
    _prepare_4d_causal_attention_mask,
)

from ...modeling_outputs import BaseModelOutputWithPast
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
    is_flash_attn_2_available,
    logging,
)
from ...utils.import_utils import is_torch_fx_available
from .configuration_gemma import GemmaConfig

from transformers.models.llama.modeling_llama import repeat_kv, rotate_half, apply_rotary_pos_emb
from transformers.models.llama.modeling_llama import repeat_kv, apply_rotary_pos_emb


logger = logging.get_logger(__name__)




class GemmaRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
@@ -244,23 +230,6 @@ class GemmaAttention(nn.Module):
        return attn_output, attn_weights, past_key_value



class GemmaRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float())
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        output = output * (1.0 + self.weight.float())
        return output.type_as(x)

class GemmaFlashAttention2(GemmaAttention):
    """
    Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays
@@ -1129,4 +1098,3 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
        )
        return reordered_past
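The comment repeated in both `GemmaRMSNorm.forward` definitions (Llama casts down before applying the weight, Gemma applies the weight in float32 and casts down last; see PR #29402 linked above) is easy to check numerically. A self-contained sketch, not part of the commit:

import torch

torch.manual_seed(0)
x = torch.randn(4, 8, dtype=torch.float16)
w = torch.full((8,), 0.1)  # Gemma stores its weight as an offset from 1.0 (zeros at init)

# Both models compute the RMS statistic in float32, as in GemmaRMSNorm._norm
normed = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6)

llama_style = normed.to(x.dtype) * (1.0 + w).to(x.dtype)  # downcast, then scale
gemma_style = (normed * (1.0 + w.float())).to(x.dtype)    # scale, then downcast

# Same mathematics, different rounding point; outputs can differ by about one
# half-precision ulp, which is why the two orderings are kept distinct.
print((llama_style - gemma_style).abs().max())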