final state

Arthur Zucker 2024-05-30 16:26:35 +02:00
parent 513b933b60
commit 751c4dbdfd
4 changed files with 13 additions and 15 deletions

View File

@@ -13,8 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Gemma model."""
import math
from typing import List, Optional, Tuple, Union

View File

@@ -19,7 +19,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
@@ -54,7 +53,6 @@ from .configuration_gemma import GemmaConfig
if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
"""PyTorch Gemma model."""
logger = logging.get_logger(__name__)
@@ -728,7 +726,6 @@ class GemmaPreTrainedModel(PreTrainedModel):
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
_CONFIG_FOR_DOC = "GemmaConfig"

View File

@@ -54,6 +54,7 @@ from ...utils import (
)
from .configuration_llama import LlamaConfig
if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
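
The hunk above brings the Llama modeling file in line with the soft-dependency pattern used in the Gemma files: the flash-attn kernels are only imported when is_flash_attn_2_available() reports the package is usable, so the module still imports on machines without flash-attn. A minimal sketch of that pattern (the None fallbacks are an illustrative assumption, not part of the diff):

from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    # Only import the real kernels when flash-attn 2.x is installed and usable.
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
else:
    # Illustrative fallback: calls fail only if the kernels are actually used.
    flash_attn_func = flash_attn_varlen_func = None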

View File

@@ -349,17 +349,19 @@ class DiffConverterTransformer(CSTTransformer):
        return node

    def leave_SimpleStatementLine(self, original_node, updated_node):
        if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])):
            self.all_imports.append(updated_node)
            return updated_node
        elif m.matches(updated_node, m.SimpleStatementLine(body=[m.ImportFrom()])):
            full_statement = self.python_module.code_for_node(updated_node.body[0].module)
            if re.search(r"transformers\.models\..*\.(modeling|configuration)_.*", full_statement):
                return cst.RemoveFromParent()
            self.all_imports.append(updated_node)
            return updated_node
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
        if m.matches(parent_node, m.Module()):
            if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])):
                if parent_node not in self.all_imports:
                    self.all_imports.append(updated_node)
                return updated_node
            elif m.matches(updated_node, m.SimpleStatementLine(body=[m.ImportFrom()])):
                full_statement = self.python_module.code_for_node(updated_node.body[0].module)
                if re.search(r"transformers\.models\..*\.(modeling|configuration)_.*", full_statement):
                    return cst.RemoveFromParent()
                if parent_node not in self.all_imports:
                    self.all_imports.append(updated_node)
                return updated_node
        self.global_scope_index += 100
        self.new_body[self.python_module.code_for_node(updated_node.body[0])] = {
            "insert_idx": self.global_scope_index,
@@ -451,7 +453,7 @@ def convert_file(diff_file, cst_transformers=None):
    if cst_transformers is None:
        cst_transformers = DiffConverterTransformer(module)
    new_mod = wrapper.visit(cst_transformers)
    ruffed_code = new_mod.code # run_ruff(new_mod.code, True)
    ruffed_code = run_ruff(new_mod.code, True)
    if len(ruffed_code) > 0:
        with open(diff_file.replace("diff_", "modeling_"), "w") as f:
            f.write(AUTO_GENERATED_MESSAGE + ruffed_code)
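
The second change above switches convert_file back to running Ruff over the generated code instead of writing new_mod.code verbatim. run_ruff itself is not part of this diff; as a rough stand-in for what such a helper can look like, the sketch below pipes a code string through Ruff's stdin interface (the function name and flags are assumptions, not the converter's actual implementation):

import subprocess


def run_ruff_sketch(code: str, check: bool = False) -> str:
    """Format (and optionally autofix) a code string by piping it through Ruff."""
    if check:
        # With stdin input, `ruff check --fix` writes the fixed source to stdout.
        cmd = ["ruff", "check", "--fix", "--exit-zero", "-"]
    else:
        cmd = ["ruff", "format", "-"]
    result = subprocess.run(cmd, input=code, capture_output=True, text=True)
    return result.stdout

Whatever the real helper does, the hunk shows its output being prefixed with AUTO_GENERATED_MESSAGE and written to the matching modeling_ file, so the generated module is what ends up formatted on disk.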