https://docs.astral.sh/ruff/rules/redefined-while-unused/ fixes the imports, but needs a later version of ruff
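For context, a minimal sketch of the pattern that rule (F811, redefined-while-unused) flags here: the same flash_attn names are imported a second time before the first binding is ever used. The names mirror the diff below, so running the sketch assumes flash-attn is installed; the file name and the ruff invocation are illustrative, not part of this commit.

# sketch_f811.py -- hypothetical reproduction of the duplicated-import pattern
import torch
from flash_attn import flash_attn_func, flash_attn_varlen_func  # first binding, not yet used

from torch import nn

# F811 (redefined-while-unused): re-imports names whose earlier bindings were never used
from flash_attn import flash_attn_func, flash_attn_varlen_func

print(torch.__version__, nn.Module, flash_attn_func, flash_attn_varlen_func)  # keep the names used

A recent ruff can report and, with --fix, remove the redundant import (e.g. ruff check --select F811 --fix sketch_f811.py); older releases presumably only flag it, which would be the "later version of ruff" the message refers to.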

Arthur Zucker 2024-05-20 16:06:49 +02:00
parent 29e3381349
commit 262c06b35c
1 changed file with 2 additions and 18 deletions


@@ -20,6 +20,8 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -29,7 +31,6 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
@@ -43,8 +44,6 @@ from ...utils import (
    replace_return_docstrings,
)
from .configuration_gemma import GemmaConfig
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
if is_flash_attn_2_available():
@@ -215,31 +214,16 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
""" PyTorch Gemma model."""
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaFlashAttention2,
    LlamaForCausalLM,
    LlamaForSequenceClassification,
    LlamaModel,
    LlamaPreTrainedModel,
    LlamaSdpaAttention,
    apply_rotary_pos_emb,
    repeat_kv,
)
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...modeling_outputs import CausalLMOutputWithPast
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import logging
from .configuration_gemma import GemmaConfig
logger = logging.get_logger(__name__)