ah maybe not lol

Arthur Zucker 2024-05-18 11:35:03 +02:00
parent 4aec18187b
commit c45466ef7f
2 changed files with 5 additions and 10 deletions

[Changed file 1 of 2]

@@ -21,6 +21,8 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from flash_attn import flash_attn_func, flash_attn_varlen_func
+from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -30,7 +32,6 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
-    QuestionAnsweringModelOutput,
     SequenceClassifierOutputWithPast,
 )
 from ...modeling_utils import PreTrainedModel
@@ -44,8 +45,6 @@ from ...utils import (
     replace_return_docstrings,
 )
 from .configuration_gemma import GemmaConfig
-from flash_attn import flash_attn_func, flash_attn_varlen_func
-from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 
 if is_flash_attn_2_available():
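
Note: the two import hunks above shuffle the same pair of flash_attn lines, and either placement imports flash_attn unconditionally. The usual convention in transformers modeling files is to keep these imports behind the availability check, as in this minimal sketch (using transformers' is_flash_attn_2_available helper; the placement shown here is illustrative, not this commit's final state):

from transformers.utils import is_flash_attn_2_available

# Only import flash_attn when the package is actually installed,
# so the modeling file still imports cleanly without it.
if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa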
@@ -224,13 +223,6 @@ import torch.utils.checkpoint
 from torch import nn
 
 from transformers.models.llama.modeling_llama import (
-    LlamaDecoderLayer,
-    LlamaFlashAttention2,
-    LlamaForCausalLM,
-    LlamaForSequenceClassification,
-    LlamaModel,
-    LlamaPreTrainedModel,
-    LlamaSdpaAttention,
     apply_rotary_pos_emb,
     repeat_kv,
 )
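
For context, diff_gemma.py (whose template this hunk trims) builds Gemma by importing and subclassing the Llama implementations. A hypothetical illustration of that diff-file pattern, with an invented class body:

from transformers.models.llama.modeling_llama import LlamaForCausalLM

# Hypothetical sketch: a "diff" file re-declares a model by inheriting
# another model's classes and overriding only what actually differs.
class GemmaForCausalLM(LlamaForCausalLM):
    pass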

[Changed file 2 of 2]

@@ -196,6 +196,7 @@ class DiffConverterTransformer(CSTTransformer):
         return node.with_changes(body=[*new_body, *node.body])
+
 
 class SuperTransformer(cst.CSTTransformer):
     METADATA_DEPENDENCIES = (ParentNodeProvider,)
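
SuperTransformer declares METADATA_DEPENDENCIES so libcst resolves parent links before traversal. A minimal, self-contained sketch of that pattern (the transformer name and sample source are invented for illustration):

import libcst as cst
from libcst.metadata import MetadataWrapper, ParentNodeProvider

class ParentAwareTransformer(cst.CSTTransformer):
    # Declaring the dependency makes libcst compute parent links before
    # traversal; they are queryable via self.get_metadata(...).
    METADATA_DEPENDENCIES = (ParentNodeProvider,)

    def leave_Return(self, original_node, updated_node):
        # Look up the node that syntactically contains this return.
        parent = self.get_metadata(ParentNodeProvider, original_node)
        print(type(parent).__name__)  # e.g. SimpleStatementLine
        return updated_node

module = cst.parse_module("def f():\n    return 1\n")
# MetadataWrapper resolves the declared metadata, then runs the transformer.
new_module = MetadataWrapper(module).visit(ParentAwareTransformer())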
@@ -266,8 +267,10 @@ class SuperTransformer(cst.CSTTransformer):
             return updated_node.with_changes(value=updated_return_value)
         return updated_node
 
+
 from check_copies import run_ruff
 
+
 if __name__ == "__main__":
     # Parse the Python file
     with open("/Users/arthurzucker/Work/transformers/src/transformers/models/gemma/diff_gemma.py", "r") as file:
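
The __main__ block reads the hardcoded diff_gemma.py path and, given the check_copies.run_ruff import, presumably formats the generated code afterwards. A rough, runnable sketch of such a libcst driver, with a stand-in no-op transformer instead of the commit's actual converter:

import libcst as cst
from libcst.metadata import MetadataWrapper

class NoOpConverter(cst.CSTTransformer):
    """Stand-in for the real converter pass; returns the tree unchanged."""

# Read the diff file, rewrite it through the transformer, emit the result.
with open("diff_gemma.py", "r") as file:  # path is illustrative
    source = file.read()

module = cst.parse_module(source)
new_module = MetadataWrapper(module).visit(NoOpConverter())
print(new_module.code)  # the real script would run ruff over this output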