Compare commits

...

158 Commits

Author SHA1 Message Date
Arthur Zucker 969cdbf6f7 `# noqa: F841` 2024-05-31 17:47:00 +02:00
Arthur Zucker f667a9ab96 don't push them 2024-05-31 17:46:08 +02:00
Arthur Zucker 63b1bc152d fixes 2024-05-31 17:41:15 +02:00
Arthur Zucker 3eb121c8ef fix some imports 2024-05-31 17:30:25 +02:00
Arthur Zucker d9e1bf4c07 fixup 2024-05-31 17:21:38 +02:00
Arthur Zucker 1839193afc single target assign fix 2024-05-31 17:20:23 +02:00
Arthur Zucker e782306164 autoformat 2024-05-31 11:00:03 +02:00
Arthur Zucker 03ac95c35f add converted dummies 2024-05-31 10:56:40 +02:00
Arthur Zucker 151cd71420 update 2024-05-31 10:46:30 +02:00
Arthur Zucker 2b96630895 update 2024-05-31 10:44:02 +02:00
Arthur Zucker 07c2aa99eb updated for function signatures 2024-05-31 10:14:53 +02:00
Arthur Zucker f124cf9c97 nits 2024-05-31 09:41:33 +02:00
Arthur Zucker d014449cdf update converter and add supper example 2024-05-31 09:38:49 +02:00
Arthur Zucker 0422b9c121 naming nit? 2024-05-30 18:33:39 +02:00
Arthur Zucker 54764f5bba fixes 2024-05-30 17:20:18 +02:00
Arthur Zucker 5797c42f50 less diffs and fix test 2024-05-30 17:15:14 +02:00
Arthur Zucker ecc0aaa3c5 everless diffs 2024-05-30 17:09:03 +02:00
Arthur Zucker fc3c9e7111 ruff format tests src utils --check 2024-05-30 16:50:41 +02:00
Arthur Zucker c27e85c2b3 dummy noy funny 2024-05-30 16:49:18 +02:00
Arthur Zucker d7355db6ae nit 2024-05-30 16:48:48 +02:00
Arthur Zucker e1b0262a9e fix 2024-05-30 16:47:21 +02:00
Arthur Zucker 065cd1afcb remove diff llama 2024-05-30 16:42:23 +02:00
Arthur Zucker fa8a86ccd2 Merge branch 'main' of github.com:huggingface/transformers into diff-converter 2024-05-30 16:39:18 +02:00
Arthur Zucker 2e7499239b fixup 2024-05-30 16:38:05 +02:00
Arthur Zucker 16b6aeda1c add a readme 2024-05-30 16:37:31 +02:00
Arthur Zucker 751c4dbdfd final state 2024-05-30 16:26:35 +02:00
Arthur Zucker 513b933b60 nits 2024-05-30 16:05:11 +02:00
Arthur Zucker 8a85473357 current state 2024-05-30 15:53:56 +02:00
Arthur Zucker 64422e5c06 OUUUUUUF 2024-05-30 15:51:36 +02:00
Arthur Zucker 6207b52f10 state 2024-05-30 10:53:35 +02:00
Arthur Zucker 98c0a91bf7 TODO 2024-05-29 14:09:02 +02:00
Arthur Zucker 5a1cccde71 nit 2024-05-29 13:24:01 +02:00
Arthur Zucker 9828ffc545 update examples 2024-05-29 13:22:10 +02:00
Arthur Zucker 331d8a494a ruff 2024-05-29 13:05:23 +02:00
Arthur Zucker e3e6ccac62 updates 2024-05-29 13:05:07 +02:00
Arthur Zucker 058b6fa71d fixup 2024-05-29 12:42:48 +02:00
Arthur Zucker 0f4e05fa8b even simpler header? 2024-05-29 12:40:05 +02:00
Arthur Zucker dcee16ec11 do cleanup some stuff 2024-05-29 11:34:08 +02:00
Arthur Zucker 85d2a505cb smaller header 2024-05-29 08:11:11 +02:00
Arthur Zucker 1fd611c752 Merge branch 'diff-converter' of github.com:huggingface/transformers into diff-converter 2024-05-29 08:11:02 +02:00
Arthur Zucker ac0dc69bb2 nits 2024-05-29 08:06:25 +02:00
Arthur ab3d4103aa Apply suggestions from code review (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>) 2024-05-29 08:06:01 +02:00
Arthur Zucker 42f640fba8 update 2024-05-29 07:46:15 +02:00
Arthur Zucker 11280292c3 nits 2024-05-28 17:53:38 +02:00
Arthur Zucker 1836a758f8 fixups 2024-05-28 16:31:09 +02:00
Arthur Zucker 0faa82da98 copy class statements :wink 2024-05-28 16:30:42 +02:00
Arthur Zucker 7ea9bcd3dc better merging strategy 2024-05-28 16:13:43 +02:00
Arthur Zucker f1e1decc92 update 2024-05-28 13:44:33 +02:00
Arthur Zucker 85bccc4176 update caution message 2024-05-28 13:41:24 +02:00
Arthur Zucker 43d8d71eae add warning 2024-05-28 13:38:48 +02:00
Arthur Zucker 28b5596f11 update 2024-05-28 11:48:06 +02:00
Arthur Zucker d1bc03bfd7 update 2024-05-28 11:43:32 +02:00
Arthur Zucker 9d62ba5f53 update 2024-05-28 11:41:47 +02:00
Arthur Zucker 6c423ceaaf correct merge 2024-05-28 10:56:57 +02:00
Arthur Zucker f0068b702c update 2024-05-28 10:55:29 +02:00
Arthur Zucker df19157424 Merge branch 'main' of github.com:huggingface/transformers into diff-converter 2024-05-28 10:51:07 +02:00
Arthur Zucker 6c486574bc nits 2024-05-28 10:50:26 +02:00
Arthur Zucker 494e6bac14 update 2024-05-28 10:45:04 +02:00
Arthur Zucker 54af8877cb nit 2024-05-28 10:33:06 +02:00
Arthur Zucker 099041413b up 2024-05-28 10:32:50 +02:00
Arthur Zucker 1ce5c1b5c7 fixup 2024-05-28 09:38:22 +02:00
Arthur Zucker d6ef9e81e5 use logger info instead of prints 2024-05-28 09:37:55 +02:00
Arthur Zucker 793f63879d order is now expected 2024-05-28 08:35:16 +02:00
Arthur Zucker 7898d3286e nits 2024-05-27 17:04:01 +02:00
Arthur Zucker 80363e3fb7 fixes 2024-05-27 17:02:39 +02:00
Arthur Zucker b888fcdd1d finally 2024-05-27 17:02:14 +02:00
Arthur Zucker 4ead65b86d nit 2024-05-27 15:24:42 +02:00
Arthur Zucker 8256a73c81 new status 2024-05-27 15:21:34 +02:00
Arthur Zucker 49656b30e8 current state 2024-05-27 15:15:13 +02:00
Arthur Zucker 40c5e6da5e current state 2024-05-27 15:10:59 +02:00
Arthur Zucker e62a5bb05b current changes 2024-05-27 14:22:36 +02:00
Arthur Zucker b0853cb588 update 2024-05-27 12:08:46 +02:00
Arthur Zucker 6fb42c2018 finished fining deps 2024-05-25 12:13:40 +02:00
Arthur Zucker 91f45f8ff2 No need for call 2024-05-25 10:05:28 +02:00
Arthur Zucker 585686ed08 updates 2024-05-25 08:28:50 +02:00
Arthur Zucker 337321e8b2 current updates 2024-05-24 17:56:34 +02:00
Arthur Zucker 2df4ec68d8 stash 2024-05-22 15:06:25 +02:00
Arthur Zucker 380b87f0cb less diff! 2024-05-21 12:13:04 +02:00
Arthur Zucker 3abd9f5828 update 2024-05-21 12:11:58 +02:00
Arthur Zucker adc3f920d5 cleanup 2024-05-21 12:06:05 +02:00
Arthur Zucker 10b5591175 updates 2024-05-21 11:47:23 +02:00
Arthur Zucker 0c7e43eb8a remove other diff files 2024-05-21 11:45:59 +02:00
Arthur Zucker 53a4ce871c revert unrelated 2024-05-21 11:43:21 +02:00
Arthur Zucker 6147d3adc2 update! 2024-05-20 16:51:42 +02:00
Arthur Zucker c8e64ede4e non zero exit 2024-05-20 16:48:25 +02:00
Arthur Zucker 43d7809079 okay actual state 2024-05-20 16:39:53 +02:00
Arthur Zucker fdc48d8a64 updates 2024-05-20 16:32:07 +02:00
Arthur Zucker 262c06b35c https://docs.astral.sh/ruff/rules/redefined-while-unused/ fixes the imports, bit needs later version of ruff 2024-05-20 16:06:49 +02:00
Arthur Zucker 29e3381349 updates 2024-05-19 11:09:37 +02:00
Arthur Zucker d5b10f75d1 update llama diff 2024-05-19 11:09:19 +02:00
Arthur Zucker 9dbb22a7c4 synch with main 2024-05-19 11:07:28 +02:00
Arthur Zucker 38286ada82 Merge branch 'main' of github.com:huggingface/transformers into diff-converter 2024-05-19 11:07:11 +02:00
Arthur Zucker 4e8a23e7f7 convert starcoder2 2024-05-19 11:04:26 +02:00
Arthur Zucker 0ced2bc0c8 okay 2024-05-19 10:08:30 +02:00
Arthur Zucker b036a2aff9 conversion of llama 2024-05-18 12:51:40 +02:00
Arthur Zucker 07a90cc324 nit 2024-05-18 12:49:27 +02:00
Arthur Zucker f8587d7a2f for now remove all imports from child. 2024-05-18 11:43:29 +02:00
Arthur Zucker c45466ef7f ah maybe not lol 2024-05-18 11:35:03 +02:00
Arthur Zucker 4aec18187b ruff deals pretty well with imports, let's leave it to him 2024-05-18 11:34:31 +02:00
Arthur Zucker 52b70fdc33 run ruff post script 2024-05-18 11:29:35 +02:00
Arthur Zucker 6c09d23e04 correctly remove duplicate code 2024-05-18 10:52:09 +02:00
Arthur Zucker 292e573321 fixup 2024-05-18 10:35:27 +02:00
Arthur Zucker 65a00cefba deal with duplicates 2024-05-18 10:35:16 +02:00
Arthur Zucker e3be54cf25 🤗 2024-05-18 10:27:32 +02:00
Arthur Zucker 39f696ee4f keep decorators? 2024-05-18 09:43:38 +02:00
Arthur Zucker 67471e6758 process inheritage 2024-05-18 09:38:47 +02:00
Arthur Zucker 075be8c1fc todos 2024-05-17 21:04:08 +02:00
Arthur Zucker e606c513c6 deal with assigns 2024-05-17 20:57:47 +02:00
Arthur Zucker 274ac8801d handle funtions 2024-05-17 20:51:34 +02:00
Arthur Zucker 768801cbac style 2024-05-17 20:38:43 +02:00
Arthur Zucker a5b87808e8 deal with comments 2024-05-17 20:21:29 +02:00
Arthur Zucker 24e072ee71 for now use gemma 2024-05-16 18:31:11 +02:00
Arthur Zucker 39ec61ac2e 🔥 2024-05-16 18:27:36 +02:00
Arthur Zucker f5ebef0deb doc 🚀 2024-05-16 18:24:27 +02:00
Arthur Zucker 6a5264d489 update 2024-05-16 18:17:12 +02:00
Arthur Zucker c804b4bc6d fixup 2024-05-16 18:13:46 +02:00
Arthur Zucker df9e78377b nit 2024-05-16 18:04:05 +02:00
Arthur Zucker c44f82750c clear diffs 2024-05-16 18:03:12 +02:00
Arthur Zucker fca954d6d4 nit 2024-05-16 18:02:20 +02:00
Arthur Zucker 8fe59a5089 ouiiii 2024-05-16 18:00:00 +02:00
Arthur Zucker c9fea750cb update 2024-05-16 12:40:07 +02:00
Arthur Zucker 7b79b4d4b1 current state 2024-05-16 11:19:07 +02:00
Arthur Zucker ce615ff9a5 current state 2024-05-16 11:16:23 +02:00
Arthur Zucker cdb8c6b19d oups 2024-05-15 17:26:18 +02:00
Arthur Zucker 709429a141 updates 2024-05-15 17:21:40 +02:00
Arthur Zucker 35576acfcd update gemma 2024-05-15 17:12:49 +02:00
Arthur Zucker f3fe0b340a updates 2024-05-15 17:10:14 +02:00
Arthur Zucker 3dedb93c45 revert changes done to llama 2024-05-15 16:57:52 +02:00
Arthur Zucker daebeeaf04 updates 2024-05-15 16:51:06 +02:00
Arthur Zucker 45f20f5641 updates 2024-05-15 16:49:18 +02:00
Arthur Zucker eaaf34f303 updates 2024-05-15 16:44:46 +02:00
Arthur Zucker d3ab98e5ae updates 2024-05-15 16:29:39 +02:00
Arthur Zucker d5c00047da updates 2024-05-15 16:07:05 +02:00
Arthur Zucker 8fe406fd17 fix some issues 2024-05-15 15:52:48 +02:00
Arthur Zucker 774a4af6de fix some issues 2024-05-14 14:48:20 +02:00
Arthur Zucker a47468a938 fix some issues 2024-05-14 10:30:48 +02:00
Arthur Zucker 580fbe19e2 update regex patterns 2024-05-14 08:20:45 +02:00
Arthur Zucker 0782ffd2c4 update regex patterns 2024-05-13 17:59:51 +02:00
Arthur Zucker 3a3510ab73 push the actual result 2024-05-10 17:46:50 +02:00
Arthur Zucker ca181ab402 update 2024-05-10 17:45:32 +02:00
Arthur Zucker 8752d35aa8 update 2024-05-10 17:16:52 +02:00
Arthur Zucker 2a654ec763 delete 2024-05-10 17:06:25 +02:00
Arthur Zucker 1aabcc1a73 give some breathing space to the code 2024-05-10 16:57:06 +02:00
Arthur Zucker 22ff159e50 updates with converted versions 2024-05-10 16:50:18 +02:00
Arthur Zucker 1632e0f4bd updates 2024-05-10 16:41:43 +02:00
Arthur Zucker e467d2fede fix rope nits 2024-05-10 14:56:19 +02:00
Arthur Zucker 7545c5f766 add diff file that is the same as the modeling_llama.py 2024-05-10 14:39:14 +02:00
Arthur Zucker 740e5bd35c Merge branch 'main' of github.com:huggingface/transformers into refactoring-new-version 2024-05-10 14:30:37 +02:00
Arthur Zucker 022727c480 nit 2024-04-18 13:47:56 +02:00
Arthur Zucker d68766aa7c persimmon 2024-04-12 19:12:15 +02:00
Arthur Zucker 92b6218e18 attempt diffs for 3 files 2024-04-12 19:03:13 +02:00
Arthur Zucker e08d8eb963 roadmap and nits 2024-04-12 09:31:33 +02:00
Arthur Zucker eb5c2e27e1 oups 2024-04-12 09:18:37 +02:00
Arthur Zucker 1fa297cf1f push the conversion file 2024-04-12 09:16:02 +02:00
Arthur Zucker 0bb0af9ac0 nit 2024-04-12 09:14:41 +02:00
Arthur Zucker bd59e58ca8 update 2024-04-12 08:42:30 +02:00
Arthur Zucker 564813d72e commit regex and result file 2024-04-11 18:10:21 +02:00
Arthur Zucker f02e2fb8cc current working example! 2024-04-11 11:57:31 +02:00
13 changed files with 1315 additions and 62 deletions

View File

@ -0,0 +1,20 @@
# Using the `diff_converter` linter
`pip install libcst` is required.
Run `sh examples/diff-conversion/convert_examples.sh` to get the converted outputs.
The diff converter is a new `linter` specific to `transformers`. It allows us to unpack inheritance in Python and convert a modular `diff` file like `diff_gemma.py` into a single file following the single model, single file policy.
Examples of possible usage are available in `examples/diff-conversion`, or see `diff_gemma` for a full model example.
To convert a single file, run:
`python utils/diff_model_converter.py --files_to_parse "/Users/arthurzucker/Work/transformers/examples/diff-conversion/diff_my_new_model2.py"`
## How it works
We use the `libcst` parser to produce an AST representation of the `diff_xxx.py` file. For any import made from `transformers.models.modeling_xxxx`, we parse the source code of that module and build a class dependency mapping, which allows us to unpack the dependencies the diff file relies on.
The code from the `diff` file and the class dependency mapping are "merged" to produce the single model, single file.
We use ruff to automatically remove any duplicate imports left over from the merge.
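The sketch below is illustrative only (not the converter itself, which lives in `utils/diff_model_converter.py`); assuming `libcst` is installed, it shows the first step: parsing a toy diff file and collecting each class together with the source of its base classes, which is the raw material for the dependency mapping.

```python
import libcst as cst

# A toy diff file: one class inheriting from a transformers base class.
SOURCE = '''\
from transformers.models.llama.modeling_llama import LlamaModel

class DummyModel(LlamaModel):
    pass
'''


class ClassCollector(cst.CSTVisitor):
    """Collect every class definition together with the code of its bases."""

    def __init__(self, module: cst.Module) -> None:
        super().__init__()
        self.module = module
        self.bases = {}

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        # `node.bases` is a sequence of cst.Arg; `.value` holds the base expression.
        self.bases[node.name.value] = [
            self.module.code_for_node(arg.value) for arg in node.bases
        ]


module = cst.parse_module(SOURCE)
collector = ClassCollector(module)
module.visit(collector)
print(collector.bases)  # {'DummyModel': ['LlamaModel']}
```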
## Why use libcst instead of the native AST?
The native AST is super powerful, but it does not preserve docstrings, comments, or code formatting, so we decided to go with `libcst`.
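A standalone sketch of that difference (assumes Python 3.9+ for `ast.unparse` and `libcst` installed); it only illustrates why a lossless parser matters when regenerating source files:

```python
import ast

import libcst as cst

src = "x = 1  # an important comment\n"

# The stdlib parser discards comments and formatting on a round-trip...
print(ast.unparse(ast.parse(src)))  # -> "x = 1"

# ...whereas libcst reproduces the source exactly, comment included.
print(cst.parse_module(src).code, end="")  # -> "x = 1  # an important comment"
```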

View File

@ -0,0 +1,10 @@
#!/bin/bash
# Iterate over each file in the current directory
for file in examples/diff-conversion/diff_*; do
# Check if it's a regular file
if [ -f "$file" ]; then
# Call the Python script with the file name as an argument
python utils/diff_model_converter.py --files_to_parse "$file"
fi
done

View File

@ -0,0 +1,44 @@
from math import log
from typing import List, Optional, Tuple, Union
import torch
from transformers import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel
def _pre_process_input(input_ids):
print(log(input_ids))
return input_ids
# example where we need some deps and some functions
class DummyModel(LlamaModel):
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
input_ids = _pre_process_input(input_ids)
return super().forward(
None,
attention_mask,
position_ids,
past_key_values,
inputs_embeds,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
cache_position,
)

View File

@ -0,0 +1,14 @@
from transformers.models.llama.configuration_llama import LlamaConfig
# Example where we only want to add a new config argument and a new arg doc
# here there is no `ARG` so we are going to take the parent doc
class MyNewModelConfig(LlamaConfig):
r"""
mlp_bias (`bool`, *optional*, defaults to `False`)
"""
def __init__(self, mlp_bias=True, new_param=0, **super_kwargs):
self.mlp_bias = mlp_bias
self.new_param = new_param
super().__init__(self, **super_kwargs)

View File

@ -0,0 +1,31 @@
from transformers.models.gemma.modeling_gemma import GemmaForSequenceClassification
from transformers.models.llama.configuration_llama import LlamaConfig
# Example where we only want to modify the docstring
class MyNewModel2Config(LlamaConfig):
r"""
This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate a Gemma
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Gemma-7B.
e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 256000):
Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`GemmaModel`]
```python
>>> from transformers import GemmaModel, GemmaConfig
>>> # Initializing a Gemma gemma-7b style configuration
>>> configuration = GemmaConfig()
>>> # Initializing a model from the gemma-7b style configuration
>>> model = GemmaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
# Example where alllllll the dependencies are fetched to just copy the entire class
class MyNewModel2ForSequenceClassification(GemmaForSequenceClassification):
pass

View File

@ -0,0 +1,30 @@
# Example where we only want to overwrite the defaults of an init
from transformers.models.gemma.configuration_gemma import GemmaConfig
class NewModelConfig(GemmaConfig):
def __init__(
self,
vocab_size=256030,
hidden_size=64,
intermediate_size=90,
num_hidden_layers=28,
num_attention_heads=16,
num_key_value_heads=16,
head_dim=256,
hidden_act="gelu_pytorch_tanh",
hidden_activation=None,
max_position_embeddings=1500,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
eos_token_id=1,
bos_token_id=2,
tie_word_embeddings=True,
rope_theta=10000.0,
attention_bias=False,
attention_dropout=0.0,
):
super().__init__(self)

View File

@ -0,0 +1,38 @@
from typing import List, Optional, Tuple, Union
import torch
from transformers import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel
# example where we need some deps and some functions
class SuperModel(LlamaModel):
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
out = super().forward(
input_ids,
attention_mask,
position_ids,
past_key_values,
inputs_embeds,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
cache_position,
)
out.logits *= 2**4
return out

View File

@ -1,5 +1,12 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,13 +19,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gemma model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
from transformers import PretrainedConfig
class GemmaConfig(PretrainedConfig):
@ -26,13 +29,9 @@ class GemmaConfig(PretrainedConfig):
This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Gemma-7B.
e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 256000):
Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
@ -83,16 +82,12 @@ class GemmaConfig(PretrainedConfig):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import GemmaModel, GemmaConfig
>>> # Initializing a Gemma gemma-7b style configuration
>>> configuration = GemmaConfig()
>>> # Initializing a model from the gemma-7b style configuration
>>> model = GemmaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""

View File

@ -0,0 +1,507 @@
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import PretrainedConfig
from transformers.models.llama.modeling_llama import (
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaForTokenClassification,
LlamaModel,
apply_rotary_pos_emb,
repeat_kv,
)
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...modeling_outputs import CausalLMOutputWithPast
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import logging
logger = logging.get_logger(__name__)
class GemmaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate a Gemma
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Gemma-7B.
e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 256000):
Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`GemmaModel`]
hidden_size (`int`, *optional*, defaults to 3072):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 24576):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 28):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 16):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
head_dim (`int`, *optional*, defaults to 256):
The attention head dimension.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The legacy activation function. It is overwritten by the `hidden_activation`.
hidden_activation (`str` or `function`, *optional*):
The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
max_position_embeddings (`int`, *optional*, defaults to 8192):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
eos_token_id (`int`, *optional*, defaults to 1):
End of stream token id.
bos_token_id (`int`, *optional*, defaults to 2):
Beginning of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import GemmaModel, GemmaConfig
>>> # Initializing a Gemma gemma-7b style configuration
>>> configuration = GemmaConfig()
>>> # Initializing a model from the gemma-7b style configuration
>>> model = GemmaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "gemma"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=256000,
hidden_size=3072,
intermediate_size=24576,
num_hidden_layers=28,
num_attention_heads=16,
num_key_value_heads=16,
head_dim=256,
hidden_act="gelu_pytorch_tanh",
hidden_activation=None,
max_position_embeddings=8192,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
eos_token_id=1,
bos_token_id=2,
tie_word_embeddings=True,
rope_theta=10000.0,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.hidden_activation = hidden_activation
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.zeros(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float())
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
output = output * (1.0 + self.weight.float())
return output.type_as(x)
ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)
class GemmaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
self.inv_freq.to(x.device)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class GemmaMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
if config.hidden_activation is None:
logger.warning_once(
"`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n"
"Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n"
"`config.hidden_activation` if you want to override this behaviour.\n"
"See https://github.com/huggingface/transformers/pull/29402 for more details."
)
config.hidden_activation = "gelu_pytorch_tanh"
hidden_activation = config.hidden_activation
self.act_fn = ACT2FN[hidden_activation]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class GemmaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.head_dim
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
if self.hidden_size % self.num_heads != 0:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
self.rotary_emb = GemmaRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class GemmaModel(LlamaModel):
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
return_legacy_cache = False # noqa: F841
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True # noqa: F841
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# embed positions
hidden_states = inputs_embeds
# normalized
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
# See https://github.com/huggingface/transformers/pull/29402
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
return super().forward(
causal_mask,
position_ids,
past_key_values,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
cache_position,
input_ids=None,
inputs_embeds=hidden_states,
)
# Example where we only modify the docstring and call super
class GemmaForCausalLM(LlamaForCausalLM):
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, GemmaForCausalLM
>>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
>>> prompt = "What is your favorite condiment?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is your favorite condiment?"
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class GemmaForSequenceClassification(LlamaForSequenceClassification):
pass
class GemmaForTokenClassification(LlamaForTokenClassification):
pass

View File

@ -1,3 +1,9 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
@ -13,8 +19,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Gemma model."""
import math
from typing import List, Optional, Tuple, Union
@ -26,10 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_causal_attention_mask,
)
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@ -37,7 +38,7 @@ from ...modeling_outputs import (
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
@ -46,7 +47,6 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ...utils.import_utils import is_torch_fx_available
from .configuration_gemma import GemmaConfig
@ -55,25 +55,14 @@ if is_flash_attn_2_available():
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
# It means that the function will not be traced through and simply appear as a node in the graph.
if is_torch_fx_available():
if not is_torch_greater_or_equal_than_1_13:
import torch.fx
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "GemmaConfig"
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
@ -108,7 +97,6 @@ class GemmaRotaryEmbedding(nn.Module):
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
@ -130,7 +118,35 @@ class GemmaRotaryEmbedding(nn.Module):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.llama.modeling_llama.rotate_half
class GemmaLinearScalingRotaryEmbedding(GemmaRotaryEmbedding):
"""GemmaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
def forward(self, x, position_ids):
# difference to the original RoPE: a scaling factor is applied to the position ids
position_ids = position_ids.float() / self.scaling_factor
cos, sin = super().forward(x, position_ids)
return cos, sin
class GemmaDynamicNTKScalingRotaryEmbedding(GemmaRotaryEmbedding):
"""GemmaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def forward(self, x, position_ids):
# difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_position_embeddings:
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (
base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
cos, sin = super().forward(x, position_ids)
return cos, sin
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
@ -138,7 +154,6 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@ -190,7 +205,6 @@ class GemmaMLP(nn.Module):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@ -206,7 +220,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
class GemmaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
# Ignore copy
def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
@ -303,7 +316,6 @@ class GemmaAttention(nn.Module):
return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Gemma
class GemmaFlashAttention2(GemmaAttention):
"""
Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays
@ -319,7 +331,6 @@ class GemmaFlashAttention2(GemmaAttention):
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
# Ignore copy
def forward(
self,
hidden_states: torch.Tensor,
@ -329,13 +340,13 @@ class GemmaFlashAttention2(GemmaAttention):
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
)
output_attentions = False
bsz, q_len, _ = hidden_states.size()
@ -351,8 +362,8 @@ class GemmaFlashAttention2(GemmaAttention):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
@ -397,7 +408,7 @@ class GemmaFlashAttention2(GemmaAttention):
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
@ -503,7 +514,6 @@ class GemmaFlashAttention2(GemmaAttention):
)
# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Gemma
class GemmaSdpaAttention(GemmaAttention):
"""
Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@ -511,7 +521,7 @@ class GemmaSdpaAttention(GemmaAttention):
SDPA API.
"""
# Ignore copy
# Adapted from GemmaAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
@ -548,8 +558,8 @@ class GemmaSdpaAttention(GemmaAttention):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
@ -584,7 +594,7 @@ class GemmaSdpaAttention(GemmaAttention):
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
@ -598,7 +608,6 @@ GEMMA_ATTENTION_CLASSES = {
}
# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMA,Llama->Gemma
class GemmaDecoderLayer(nn.Module):
def __init__(self, config: GemmaConfig, layer_idx: int):
super().__init__()
@ -692,9 +701,8 @@ class GemmaPreTrainedModel(PreTrainedModel):
config_class = GemmaConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"]
_no_split_modules = ["GemmaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values", "causal_mask"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
@ -713,6 +721,9 @@ class GemmaPreTrainedModel(PreTrainedModel):
module.weight.data[module.padding_idx].zero_()
_CONFIG_FOR_DOC = "GemmaConfig"
GEMMA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@ -821,7 +832,6 @@ class GemmaModel(GemmaPreTrainedModel):
self.embed_tokens = value
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
# Ignore copy
def forward(
self,
input_ids: torch.LongTensor = None,
@ -989,6 +999,8 @@ class GemmaModel(GemmaPreTrainedModel):
if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
causal_mask = torch.full(
@ -1020,7 +1032,6 @@ class GemmaModel(GemmaPreTrainedModel):
return causal_mask
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->GEMMA,Llama->Gemma,llama->gemma
class GemmaForCausalLM(GemmaPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
@ -1051,7 +1062,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
def get_decoder(self):
return self.model
# Ignore copy
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@ -1244,7 +1254,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
""",
GEMMA_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->GEMMA,Llama->Gemma
class GemmaForSequenceClassification(GemmaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
@ -1360,7 +1369,6 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
""",
GEMMA_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Gemma, LLAMA->GEMMA
class GemmaForTokenClassification(GemmaPreTrainedModel):
def __init__(self, config):
super().__init__(config)

View File

@ -17,8 +17,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch LLaMA model."""
import math
from typing import List, Optional, Tuple, Union

View File

@ -559,8 +559,11 @@ def get_indent(code: str) -> str:
return ""
def run_ruff(code):
command = ["ruff", "format", "-", "--config", "pyproject.toml", "--silent"]
def run_ruff(code, check=False):
if check:
command = ["ruff", "check", "-", "--fix", "--exit-zero"]
else:
command = ["ruff", "format", "-", "--config", "pyproject.toml", "--silent"]
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE)
stdout, _ = process.communicate(input=code.encode())
return stdout.decode()
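For context, the new `check` mode is what lets generated files be handed to `ruff check --fix`, so that rules such as `F811` (redefined-while-unused, referenced in the commit log above) can drop imports duplicated during the merge. A hypothetical usage sketch; the input string, the enabled lint rules, and running from the repository root are assumptions:

```python
# Hypothetical usage of the run_ruff helper shown above (ruff must be on PATH).
from check_copies import run_ruff

generated = (
    "from transformers.models.llama.modeling_llama import LlamaModel\n"
    "from transformers.models.llama.modeling_llama import LlamaModel\n"
    "\n\n"
    "class GemmaModel(LlamaModel):\n"
    "    pass\n"
)

cleaned = run_ruff(generated, check=True)  # ruff check --fix removes the duplicated import
print(run_ruff(cleaned))                   # ruff format then normalizes the layout
```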

View File

@ -0,0 +1,555 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import importlib
import re
from typing import Dict
import libcst as cst
from check_copies import run_ruff
from libcst import ClassDef, CSTTransformer, CSTVisitor
from libcst import matchers as m
from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider, ScopeProvider
from transformers import logging
logger = logging.get_logger(__name__)
AUTO_GENERATED_MESSAGE = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_diff_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the diff. If any change should be done, please apply the change to the
# diff.py file directly.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
"""
def get_module_source_from_name(module_name: str) -> str:
# Extract the source code from the module name
spec = importlib.util.find_spec(module_name)
if spec is None or spec.origin is None:
return f"Module {module_name} not found"
with open(spec.origin, "r") as file:
source_code = file.read()
return source_code
class ClassFinder(CSTVisitor):
"""A visitor class which analyses a module, creating a mapping of dependencies between classes and functions.
For example if the visited code has
```python3
def init_value(): return 1
class LlamaModel(PreTrainedModel):
def __init__(self):
super().__init__(self)
self.value = init_value()
```
then the `class_dependency_mapping` should be: `{"LlamaModel": ["PreTrainedModel", "init_value"], "init_value": []}`
The dependency mapping is updated via `visit_Name`, `visit_Arg` and `visit_Decorator`. This is very broad, but by
checking the parent node or the scope of a `cst.Name`, `cst.Arg` or `cst.Decorator`, we are able to map the
parent -> child dependency.
When visiting such nodes, we update the dependency of the parent node to take the visited node into account.
All `visit_XXX` methods correspond to the code executed when visiting a cst.Node of type XXX.
"""
METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)
def __init__(self, python_module: cst.Module):
# fmt: off
self.python_module: cst.Module = python_module # original cst.Module being visited
self.classes: Dict[str, cst.ClassDef] = {} # stores a mapping from classname to the cst.Node
self.imports = {} # stores all import statements
self.function_def = {} # stores global scope function definition
self.assignments = {} # LLAMA_DOCSTRING
self.class_dependency_mapping = {} # "LlamaModel":["LlamaDecoderLayer, "LlamaRMSNorm", "LlamaPreTrainedModel"], "LlamaDecoderLayer":["LlamaAttention","Llama"]
# fmt: on
def _update_class_dependency(self, name, value):
"""Update the dependency mapping for `name` with `value` by appending the previous
dependencies to the new `value`.
"""
dep = set(self.class_dependency_mapping.get(value, set()))
dep |= set(self.class_dependency_mapping.get(name, {})) | set({value})
self.class_dependency_mapping[name] = dep
def visit_ClassDef(self, node: ClassDef) -> None:
"""We don't have non global scope class defs in transformers. Here we add the inheritance dependencies"""
self.classes[node.name.value] = node
for k in node.bases: # deal with inheritance
base_name = self.python_module.code_for_node(k)
self._update_class_dependency(node.name.value, base_name)
def visit_SimpleStatementLine(self, node):
"""
Global Assigns like `GEMMA_INPUT_DOCSTRING = 'THIS IS THE INPUT' and all import statements
are extracted and saved in their corresponding dict. They are then used when updating dependency mappings.
"""
if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])) and m.matches(
self.get_metadata(cst.metadata.ParentNodeProvider, node), m.Module()
):
self.assignments[node.body[0].targets[0].target.value] = node
if m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])):
self.imports[node.body[0].names] = node
def visit_FunctionDef(self, node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
if m.matches(parent_node, m.Module()):
self.function_def[node.name.value] = node
def leave_If(self, node):
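        # Imports nested under an `if` (typically availability guards such as `if is_torch_available():`,
        # shown here only as an example) are recorded too so they can be re-emitted in the generated file.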
for stmt in node.body.body:
if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
self.imports[stmt.body[0].names] = node
def leave_Name(self, node):
if node.value in self.classes.keys() | self.assignments.keys() | self.function_def.keys():
parent = self.get_metadata(cst.metadata.ScopeProvider, node)
if not isinstance(parent, cst.metadata.scope_provider.GlobalScope):
self._update_class_dependency(parent._name_prefix.split(".")[0], node.value)
def leave_Arg(self, node):
if m.matches(node.value, m.Name()):
parent = self.get_metadata(ParentNodeProvider, node)
if m.matches(parent, m.ClassDef()) and parent.bases:
self._update_class_dependency(parent.name.value, node.value.value)
def leave_Dict(self, node):
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
if m.matches(parent, m.Assign(targets=[m.AssignTarget()])):
name = parent.targets[0].target.value
if name in self.assignments:
for k in node.elements:
dep_name = k.value.value
if dep_name in self.classes:
self._update_class_dependency(name, dep_name)
def leave_Decorator(self, node):
if hasattr(node.decorator, "args"):
for k in node.decorator.args:
if k.value.value in self.assignments:
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
scope = self.get_metadata(cst.metadata.ScopeProvider, node)
name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else parent.name.value
self._update_class_dependency(name, k.value.value)
def leave_Module(self, node):
"""When leaving the module, we store the position of each global scoped node (Assigns, function def and class def)
to allow sorting the dependencies based on their position in the code. We use the PositionProvider metadata wrapper for this.
"""
self.global_nodes = {**self.assignments, **self.classes, **self.function_def}
# now sort the class dependency_mapping based on the position of the nodes
self.class_start_line = {}
for id, node in self.global_nodes.items():
self.class_start_line[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line
class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
"""A transformer that replaces `old_name` with `new_name` in comments, string and any references.
It should take into account name like `MyNewModel`, or `my_new_model`. Without using the AUTO_MAPPING.
Supported renaming patterns:
- llama -> my_new_model and my_new_model -> llama
- Llama -> MyNewModel and MyNewModel -> Llama
- LLAMA -> MY_NEW_MODEL and MY_NEW_MODEL -> LLAMA
- LLaMa -> MyNewModel abd MyNewModel -> Llama
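    As a minimal sketch (names illustrative), the case-preserving replacement behaves like:
    ```python3
    t = ReplaceNameTransformer("llama", "gemma")
    t.preserve_case_replace("LLAMA_ATTENTION_CLASSES maps to LlamaAttention")
    # -> "GEMMA_ATTENTION_CLASSES maps to GemmaAttention"
    ```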
"""
def __init__(self, old_name, new_name):
super().__init__()
self.old_name = old_name
self.new_name = new_name
self.default_name = "".join(x.title() for x in new_name.split("_"))
self.patterns = {
old_name: new_name,
old_name.upper(): new_name.upper(),
"".join(x.title() for x in old_name.split("_")): self.default_name,
}
def preserve_case_replace(self, text):
# Create a regex pattern to match all variations
regex_pattern = "|".join(re.escape(key) for key in self.patterns.keys())
compiled_regex = re.compile(regex_pattern, re.IGNORECASE)
def replace(match):
word = match.group(0)
return self.patterns.get(word, self.default_name)
return compiled_regex.sub(replace, text)
@m.leave(m.Name() | m.SimpleString() | m.Comment())
def replace_name(self, original_node, updated_node):
update = self.preserve_case_replace(updated_node.value)
return updated_node.with_changes(value=update)
def find_classes_in_file(module: cst.Module, old_id="llama", new_id="gemma"):
"""Helper function to rename and then parse a source file using the ClassFinder"""
transformer = ReplaceNameTransformer(old_id, new_id)
new_module = module.visit(transformer)
wrapper = MetadataWrapper(new_module)
class_finder = ClassFinder(new_module)
wrapper.visit(class_finder)
return class_finder
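# Illustrative usage (module and names hypothetical):
#
#     llama_module = cst.parse_module(get_module_source_from_name("transformers.models.llama.modeling_llama"))
#     finder = find_classes_in_file(llama_module, "llama", "gemma")
#     finder.class_dependency_mapping["GemmaModel"]  # e.g. {"GemmaDecoderLayer", "GemmaRMSNorm", ...}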
DOCSTRING_NODE = m.SimpleStatementLine(
body=[
m.Expr(
value=m.SimpleString(
# match anything between """ """
value=m.MatchIfTrue(lambda value: re.search(r"\"\"\"[\s\S]*\"\"\"", value) is not None)
)
)
]
)
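# DOCSTRING_NODE matches a statement line whose only expression is a triple-quoted string,
# i.e. a docstring such as:  """Some documentation."""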
class SuperTransformer(cst.CSTTransformer):
METADATA_DEPENDENCIES = (ParentNodeProvider,)
def __init__(self, python_module: cst.Module, original_methods, updated_methods):
self.python_module = python_module
self.original_methods = original_methods
self.updated_methods = updated_methods
def update_body(self, existing_body, new_statements):
"""
Helper method to update the body by removing duplicates before adding new statements.
"""
deduplicated_new_body = []
existing_nodes = {
self.python_module.code_for_node(node).strip() for node in new_statements if isinstance(node, cst.CSTNode)
}
for stmt in existing_body:
if self.python_module.code_for_node(stmt).strip() not in existing_nodes:
if m.matches(stmt, DOCSTRING_NODE) and self.has_docstring:
continue
deduplicated_new_body.append(stmt)
                existing_nodes.add(self.python_module.code_for_node(stmt).strip())
else:
logger.info(f"\nFound duplicate {self.python_module.code_for_node(stmt)}")
return deduplicated_new_body
def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CSTNode:
"""Updates the body of the input `node`'s `func_name` function by replacing calls
to super().func_name() with the source code of the parent class' `func_name`.
It keeps everything that is defined before `super().func_name()`.
"""
new_body = []
self.has_docstring = False
for expr in node.body:
self.has_docstring = m.matches(node.body[0], DOCSTRING_NODE)
if m.matches(
expr,
m.SimpleStatementLine(
body=[
m.Return(
value=m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))
)
| m.Expr(
value=m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))
)
]
),
):
new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body))
else:
new_body.append(expr)
return node.with_changes(body=new_body)
    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.CSTNode:
if updated_node.name.value in self.updated_methods:
name = updated_node.name.value
new_body = self.replace_super_calls(updated_node.body, name)
return updated_node.with_changes(body=new_body, params=updated_node.params)
return updated_node
def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.CSTNode:
""" "When a return statement is reached, it is replaced with the unrolled super code"""
if m.matches(updated_node.value, m.Call(func=m.Attribute(attr=m.Name("super")))):
func_def = self.get_metadata(ParentNodeProvider, original_node)
            if m.matches(func_def, m.FunctionDef()) and func_def.name.value in self.original_methods:
updated_return_value = updated_node.value.with_changes(
args=[
cst.Arg(
value=cst.Call(func=cst.Name("super"), args=[cst.Arg(value=cst.Name(func_def.name.value))])
)
]
)
return updated_node.with_changes(value=updated_return_value)
return updated_node
def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef, class_name: str):
"""
    Given the `class_name`, the calls to `super()` inside `updated_node` are unrolled.
| ```python | | ```python
| class GemmaModel(LlamaModel): | | class GemmaModel(nn.Module):
| def __init__(self): | | def __init__(self):
Going from: | self.dropout = 0.2 | to: | self.dropout = 0.2
| super().__init__() | | super().__init__(config)
| ``` | | self.padding_idx = config.pad_token_id
| self.vocab_size = config.vocab_size
| self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
| self.layers = nn.ModuleList(
| [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
| )
| self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
| self.gradient_checkpointing = False
| # Initialize weights and apply final processing
| self.post_init()
| ```
"""
original_node = class_finder.classes[class_name]
original_methods = {f.name.value if hasattr(f, "name") else f: f for f in original_node.body.body}
updated_methods = {f.name.value if hasattr(f, "name") else f: f for f in updated_node.body.body}
end_meth = []
for name, func in original_methods.items():
if name in updated_methods and updated_methods[name] is not None:
new_params = updated_methods[name].params
# Replace the method in the replacement class, preserving decorators
kwarg_name = getattr(updated_methods[name].params, "star_kwarg", None)
if kwarg_name and kwarg_name.name.value == "super_kwargs":
parent_params = {k.name.value: k for k in func.params.params}
parent_params.update({k.name.value: k for k in new_params.params[1:]})
new_params = new_params.with_changes(
params=list(parent_params.values()), star_kwarg=func.params.star_kwarg
)
func = func.with_changes(body=updated_methods[name].body, params=new_params)
end_meth.append(func)
result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
temp_module = cst.Module(body=[result_node])
new_module = MetadataWrapper(temp_module)
new_replacement_class = new_module.visit(SuperTransformer(temp_module, original_methods, updated_methods))
new_replacement_body = new_replacement_class.body[0].body # get the indented block
return original_node.with_changes(body=new_replacement_body)
class DiffConverterTransformer(CSTTransformer):
METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)
def __init__(self, python_module, new_name):
super().__init__()
self.model_name = (
            new_name  # name of the model being defined. Should be in the format of `llama`, `layout_xlm` or `phi3`
)
# fmt: off
self.python_module = python_module # we store the original module to use `code_for_node`
self.transformers_imports = {} # maps the imports name like "from transformers.models.xxx" to the parsed AST module
self.imported_mapping = {} # stores the name of the imported classes, with their source {"LlamaModel":"transformers.model.llama.modeling_llama"}
self.visited_module = {} # modules visited like "transformers.models.llama.modeling_llama"
self.new_body = {} # store the new body, all global scope nodes should be added here
self.inserted_deps = [] # nodes inserted via super dependency
self.all_imports = [] # just stores all of the imports
self.global_scope_index = 0
# fmt: on
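        # Rough shape of `self.new_body` once the module has been visited (names and indices illustrative):
        #   {"GemmaModel": {"insert_idx": 100, "node": <cst.ClassDef>},
        #    "GemmaForCausalLM": {"insert_idx": 200, "node": <cst.ClassDef>}}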
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
"""When visiting imports from `transformers.models.xxx` we need to:
1. Get the original source code
2. Parse it into an AST Tree
        3. Add this import to `self.transformers_imports` as visited, so that it is not parsed twice
"""
import_statement = self.python_module.code_for_node(node.module)
if m.matches(node.module, m.Attribute()):
for imported_ in node.names:
_import = re.search(r"transformers\.models\..*\.(modeling|configuration)_.*", import_statement)
if _import:
source = _import.groups()[0]
if source == "modeling" and "Config" in self.python_module.code_for_node(imported_):
raise ValueError(
f"You are importing {self.python_module.code_for_node(imported_)} from the modeling file. Import from the `configuration_xxxx.py` file instead"
)
if import_statement not in self.transformers_imports:
source_code = get_module_source_from_name(import_statement)
tree = cst.parse_module(source_code)
self.transformers_imports[import_statement] = tree
imported_class = self.python_module.code_for_node(imported_.name)
self.imported_mapping[imported_class] = import_statement
def leave_FunctionDef(self, original_node, node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
if m.matches(parent_node, m.Module()):
self.global_scope_index += 100
self.new_body[node.name.value] = {"insert_idx": self.global_scope_index, "node": node}
return node
def leave_SimpleStatementLine(self, original_node, updated_node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
if m.matches(parent_node, m.Module()):
if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])):
if parent_node not in self.all_imports:
self.all_imports.append(updated_node)
return updated_node
elif m.matches(updated_node, m.SimpleStatementLine(body=[m.ImportFrom()])):
full_statement = self.python_module.code_for_node(updated_node.body[0].module)
if re.search(r"transformers\.models\..*\.(modeling|configuration)_.*", full_statement):
return cst.RemoveFromParent()
if parent_node not in self.all_imports:
self.all_imports.append(updated_node)
return updated_node
self.global_scope_index += 100
if m.matches(updated_node, m.SimpleStatementLine(body=[m.Assign()])):
# TODO This only works for single target assigns!
node_name = updated_node.body[0].targets[0].target.value
else:
node_name = self.python_module.code_for_node(updated_node.body[0])
self.new_body[node_name] = {
"insert_idx": self.global_scope_index,
"node": updated_node,
}
return updated_node
def leave_ClassDef(self, original_node, updated_node):
"""
        1. Filter the `base` classes of this class.
        If they are from `transformers.models.xx` then:
        - take the AST tree of the module they come from and parse it with a `ClassFinder`.
        - rename every instance of `old_name` (llama) to `new_name` (gemma)
        2. Insert the nodes which the inherited base depends on. This has to be done in
        the order of the dependencies. If one is already in the new_body (because it is defined in the diff file),
        it is removed from the new body and re-added in the correct order.
        3. Replace the calls to `super().xxxx`, merging in the parent code
"""
class_name = original_node.name.value
bases = [k.value.value for k in original_node.bases if k.value.value in self.imported_mapping]
self.global_scope_index += 100
for super_class in bases:
if super_class not in self.imported_mapping:
raise ImportError(
f"{super_class} was not imported using `from transformers.models.xxxxx.modeling_xxxx import {super_class}"
)
super_file_name = self.imported_mapping[super_class] # we need to get the parsed tree
model_name = re.search(r"_(\S*)", super_file_name)
if model_name:
model_name = model_name.groups()[0]
else:
raise ValueError(
f"Tried parsing the name of the imported package from {super_file_name}, could not extract the model name"
)
if super_file_name not in self.visited_module: # only extract classes once
class_finder = find_classes_in_file(
self.transformers_imports[super_file_name], model_name, self.model_name
)
self.visited_module[super_file_name] = class_finder
else: # we are re-using the previously parsed data
class_finder = self.visited_module[super_file_name]
list_dependencies = {
dep: class_finder.class_start_line.get(dep, 1000)
for dep in class_finder.class_dependency_mapping.get(class_name, [])
}
list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True)
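            # Hypothetical example: [("GemmaRMSNorm", 210), ("GemmaRotaryEmbedding", 95)] -- sorted by start line
            # in reverse so that, as `start_insert_idx` is decremented below, dependencies defined earlier in the
            # parent file get a smaller insert index and therefore appear first in the generated file.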
start_insert_idx = self.global_scope_index
for dependency, _ in list_dependencies:
node = class_finder.global_nodes.get(dependency, None)
if node is not None:
if dependency not in self.new_body:
start_insert_idx -= 1
self.new_body[dependency] = {"insert_idx": start_insert_idx, "node": node}
elif dependency not in self.inserted_deps:
                    # make sure the node is written after its dependencies
start_insert_idx = self.new_body[dependency]["insert_idx"] - 1
self.inserted_deps.append(dependency)
if len(list_dependencies) > 0:
updated_node = replace_call_to_super(class_finder, updated_node, class_name)
if "Config" in class_name:
self.config_body = [updated_node]
else:
self.new_body[class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
return updated_node
def leave_If(self, original_node, node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
if m.matches(parent_node, m.Module()):
full_statement = self.python_module.code_for_node(original_node.test)
if re.search(r"[\s\S]*is_.*available", full_statement):
self.all_imports.append(node)
elif full_statement not in self.new_body:
self.new_body[node] = {"insert_idx": self.global_scope_index, "node": node}
return node
    def leave_Module(self, original_node: cst.Module, node):
imports = {self.python_module.code_for_node(k): k for k in self.all_imports}
dependency_imports = {}
for visiter in self.visited_module.values():
dependency_imports.update({self.python_module.code_for_node(k): k for k in visiter.imports.values()})
if hasattr(self, "config_body"):
self.config_body = list(imports.values()) + self.config_body
dependency_imports.update(imports)
new_body = list(dependency_imports.values())
if len(self.new_body.keys()) > 0:
new_body += [k[1]["node"] for k in sorted(self.new_body.items(), key=lambda x: x[1]["insert_idx"])]
else:
new_body = []
return node.with_changes(body=[*new_body])
def convert_file(diff_file, cst_transformers=None):
model_name = re.search(r"diff_(.*)(?=\.py$)", diff_file).groups()[0]
# Parse the Python file
with open(diff_file, "r") as file:
code = file.read()
module = cst.parse_module(code)
wrapper = MetadataWrapper(module)
if cst_transformers is None:
cst_transformers = DiffConverterTransformer(module, model_name)
new_mod = wrapper.visit(cst_transformers)
ruffed_code = run_ruff(new_mod.code, True)
formatted_code = run_ruff(ruffed_code, False)
if len(formatted_code.strip()) > 0:
with open(diff_file.replace("diff_", "modeling_"), "w") as f:
f.write(AUTO_GENERATED_MESSAGE + formatted_code)
if hasattr(cst_transformers, "config_body"):
config_module = cst.Module(body=[*cst_transformers.config_body], header=new_mod.header)
with open(diff_file.replace("diff_", "configuration_"), "w") as f:
ruffed_code = run_ruff(config_module.code, True)
formatted_code = run_ruff(ruffed_code, False)
f.write(AUTO_GENERATED_MESSAGE + formatted_code)
# TODO optimize by re-using the class_finder
return cst_transformers
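# Example invocation (script path and file names are illustrative):
#
#     python utils/diff_model_converter.py --files_to_parse src/transformers/models/gemma/diff_gemma.py
#     python utils/diff_model_converter.py --files_to_parse all
#
# Each `diff_xxx.py` file is converted into a generated `modeling_xxx.py` (and `configuration_xxx.py`
# when the diff defines a Config class) in the same directory.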
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--files_to_parse",
default=["/Users/arthurzucker/Work/transformers/examples/diff-conversion/diff_my_new_model.py"],
nargs="+",
help="A list of `diff_xxxx` files that should be converted to single model file",
)
args = parser.parse_args()
if args.files_to_parse == ["all"]:
args.files_to_parse = glob.glob("src/transformers/models/**/diff_*.py", recursive=True)
for file_name in args.files_to_parse:
print(f"Converting {file_name} to a single model single file format")
module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
converter = convert_file(file_name)