https://docs.astral.sh/ruff/rules/redefined-while-unused/ fixes the imports, bit needs later version of ruff
This commit is contained in:
parent
29e3381349
commit
262c06b35c
|
@ -20,6 +20,8 @@ from typing import List, Optional, Tuple, Union
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
|
@ -29,7 +31,6 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter
|
|||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
QuestionAnsweringModelOutput,
|
||||
SequenceClassifierOutputWithPast,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
|
@ -43,8 +44,6 @@ from ...utils import (
|
|||
replace_return_docstrings,
|
||||
)
|
||||
from .configuration_gemma import GemmaConfig
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
|
@ -215,31 +214,16 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|||
|
||||
""" PyTorch Gemma model."""
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaDecoderLayer,
|
||||
LlamaFlashAttention2,
|
||||
LlamaForCausalLM,
|
||||
LlamaForSequenceClassification,
|
||||
LlamaModel,
|
||||
LlamaPreTrainedModel,
|
||||
LlamaSdpaAttention,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache
|
||||
from ...modeling_outputs import CausalLMOutputWithPast
|
||||
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
from ...utils import logging
|
||||
from .configuration_gemma import GemmaConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
|
Loading…
Reference in New Issue