ah maybe not lol
This commit is contained in:
parent
4aec18187b
commit
c45466ef7f
|
@ -21,6 +21,8 @@ from typing import List, Optional, Tuple, Union
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
|
@ -30,7 +32,6 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter
|
|||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
QuestionAnsweringModelOutput,
|
||||
SequenceClassifierOutputWithPast,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
|
@ -44,8 +45,6 @@ from ...utils import (
|
|||
replace_return_docstrings,
|
||||
)
|
||||
from .configuration_gemma import GemmaConfig
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
|
@ -224,13 +223,6 @@ import torch.utils.checkpoint
|
|||
from torch import nn
|
||||
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaDecoderLayer,
|
||||
LlamaFlashAttention2,
|
||||
LlamaForCausalLM,
|
||||
LlamaForSequenceClassification,
|
||||
LlamaModel,
|
||||
LlamaPreTrainedModel,
|
||||
LlamaSdpaAttention,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
|
|
|
@ -196,6 +196,7 @@ class DiffConverterTransformer(CSTTransformer):
|
|||
|
||||
return node.with_changes(body=[*new_body, *node.body])
|
||||
|
||||
|
||||
class SuperTransformer(cst.CSTTransformer):
|
||||
METADATA_DEPENDENCIES = (ParentNodeProvider,)
|
||||
|
||||
|
@ -266,8 +267,10 @@ class SuperTransformer(cst.CSTTransformer):
|
|||
return updated_node.with_changes(value=updated_return_value)
|
||||
return updated_node
|
||||
|
||||
|
||||
from check_copies import run_ruff
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Parse the Python file
|
||||
with open("/Users/arthurzucker/Work/transformers/src/transformers/models/gemma/diff_gemma.py", "r") as file:
|
||||
|
|
Loading…
Reference in New Issue