ah maybe not lol
This commit is contained in: parent 4aec18187b · commit c45466ef7f
@@ -21,6 +21,8 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from flash_attn import flash_attn_func, flash_attn_varlen_func
+from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

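For reference, the two names added above come from the flash-attn package. A minimal call sketch for flash_attn_func follows; the shapes, dtype, and causal flag are illustrative assumptions, and flash-attn only runs on a CUDA device with fp16/bf16 tensors:

import torch
from flash_attn import flash_attn_func

# Illustrative shapes: (batch, seq_len, num_heads, head_dim), fp16 on GPU.
batch, seq_len, num_heads, head_dim = 2, 128, 8, 64
q = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)

# Output keeps the (batch, seq_len, num_heads, head_dim) layout;
# causal=True applies the usual autoregressive mask.
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)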
@@ -30,7 +32,6 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
-    QuestionAnsweringModelOutput,
     SequenceClassifierOutputWithPast,
 )
 from ...modeling_utils import PreTrainedModel
@@ -44,8 +45,6 @@ from ...utils import (
     replace_return_docstrings,
 )
 from .configuration_gemma import GemmaConfig
-from flash_attn import flash_attn_func, flash_attn_varlen_func
-from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa


 if is_flash_attn_2_available():
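The hunk above removes the unconditional flash_attn imports that sat next to the is_flash_attn_2_available() check. For context, a minimal sketch of the import-guard pattern that check is normally paired with in transformers; the else-branch fallback is an assumption, not code from this diff:

from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    # Only import flash-attn symbols when the package is installed and usable.
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
else:
    # Callers are expected to check availability before touching these names.
    flash_attn_func = flash_attn_varlen_func = None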
@@ -224,13 +223,6 @@ import torch.utils.checkpoint
 from torch import nn

 from transformers.models.llama.modeling_llama import (
-    LlamaDecoderLayer,
-    LlamaFlashAttention2,
-    LlamaForCausalLM,
-    LlamaForSequenceClassification,
-    LlamaModel,
-    LlamaPreTrainedModel,
-    LlamaSdpaAttention,
     apply_rotary_pos_emb,
     repeat_kv,
 )
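After this change the llama import block keeps only the functional helpers apply_rotary_pos_emb and repeat_kv. As a reminder of what the latter does, here is a rough re-implementation of repeat_kv based on the upstream Llama code; the shapes in the comments are the usual GQA layout, not something taken from this diff:

import torch


def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand grouped key/value heads so they line up with the query heads:
    # (batch, num_kv_heads, seq_len, head_dim) -> (batch, num_kv_heads * n_rep, seq_len, head_dim)
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)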
@@ -196,6 +196,7 @@ class DiffConverterTransformer(CSTTransformer):

         return node.with_changes(body=[*new_body, *node.body])

+
 class SuperTransformer(cst.CSTTransformer):
     METADATA_DEPENDENCIES = (ParentNodeProvider,)

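SuperTransformer is a libcst CSTTransformer that declares METADATA_DEPENDENCIES = (ParentNodeProvider,). A small self-contained sketch of that pattern, using a toy transformer rather than the converter itself: the module has to be visited through a MetadataWrapper for get_metadata to resolve, and nodes are rewritten with with_changes().

import libcst as cst
from libcst.metadata import MetadataWrapper, ParentNodeProvider


class RenameFoo(cst.CSTTransformer):
    METADATA_DEPENDENCIES = (ParentNodeProvider,)

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        # Metadata resolves because the module is visited via MetadataWrapper below.
        parent = self.get_metadata(ParentNodeProvider, original_node)
        if original_node.name.value == "foo" and isinstance(parent, cst.Module):
            return updated_node.with_changes(name=cst.Name("bar"))
        return updated_node


if __name__ == "__main__":
    wrapper = MetadataWrapper(cst.parse_module("def foo():\n    return 1\n"))
    print(wrapper.visit(RenameFoo()).code)  # the top-level foo is renamed to bar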
@@ -266,8 +267,10 @@ class SuperTransformer(cst.CSTTransformer):
             return updated_node.with_changes(value=updated_return_value)
         return updated_node

+
 from check_copies import run_ruff

+
 if __name__ == "__main__":
     # Parse the Python file
     with open("/Users/arthurzucker/Work/transformers/src/transformers/models/gemma/diff_gemma.py", "r") as file:
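The __main__ block reads diff_gemma.py and, per the check_copies import above, formats the generated code with run_ruff. The helper below is a hypothetical stand-in for that function, assuming it simply shells out to ruff; the real implementation in transformers may differ.

import subprocess


def run_ruff_sketch(code: str) -> str:
    # Format a code string by piping it through `ruff format -` (stdin -> stdout).
    result = subprocess.run(
        ["ruff", "format", "-"],
        input=code,
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout


# Example: print(run_ruff_sketch("x=1"))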