diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 6936c1b342..472daa17ed 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -56,15 +56,6 @@ if is_flash_attn_2_available():
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 
-if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-"""PyTorch Gemma model."""
-
-
-import torch
-import torch.utils.checkpoint
-
 logger = logging.get_logger(__name__)
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 761c8a6222..eacff7dd09 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -23,8 +23,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""PyTorch LLaMA model."""
-
 import math
 from typing import List, Optional, Tuple, Union
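For reference, a minimal sketch of how the flash-attention import guard in `modeling_gemma.py` reads once the duplicated block and the stray module docstring are removed. This is reconstructed from the hunk context above, not copied from the patched file; the `transformers.utils` import path for `is_flash_attn_2_available` is an assumption about the surrounding module, which normally has it in scope already.

```python
# Sketch of the de-duplicated guard after applying the Gemma hunk above.
# Assumes the standard transformers utility import; in modeling_gemma.py
# itself this helper is already imported, so only the guard remains.
from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    # flash-attn kernels are optional: import them only when the package
    # is installed, exactly once, with no duplicate guard block.
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
```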