Add safeguards for CUDA kernel load in Deformable DETR (#19037)

This commit is contained in:
Sylvain Gugger 2022-09-14 13:28:40 -04:00 committed by GitHub
parent 31be02f14b
commit 0e24548081
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 12 additions and 3 deletions

View File

@ -41,7 +41,7 @@ from ...file_utils import (
) )
from ...modeling_outputs import BaseModelOutput from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import logging from ...utils import is_ninja_available, logging
from .configuration_deformable_detr import DeformableDetrConfig from .configuration_deformable_detr import DeformableDetrConfig
from .load_custom import load_cuda_kernels from .load_custom import load_cuda_kernels
@ -49,9 +49,13 @@ from .load_custom import load_cuda_kernels
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
# Move this to not compile only when importing, this needs to happen later, like in __init__. # Move this to not compile only when importing, this needs to happen later, like in __init__.
if is_torch_cuda_available(): if is_torch_cuda_available() and is_ninja_available():
logger.info("Loading custom CUDA kernels...") logger.info("Loading custom CUDA kernels...")
MultiScaleDeformableAttention = load_cuda_kernels() try:
MultiScaleDeformableAttention = load_cuda_kernels()
except Exception as e:
logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
MultiScaleDeformableAttention = None
else: else:
MultiScaleDeformableAttention = None MultiScaleDeformableAttention = None

View File

@ -98,6 +98,7 @@ from .import_utils import (
is_in_notebook, is_in_notebook,
is_ipex_available, is_ipex_available,
is_librosa_available, is_librosa_available,
is_ninja_available,
is_onnx_available, is_onnx_available,
is_pandas_available, is_pandas_available,
is_phonemizer_available, is_phonemizer_available,

View File

@ -471,6 +471,10 @@ def is_apex_available():
return importlib.util.find_spec("apex") is not None return importlib.util.find_spec("apex") is not None
def is_ninja_available():
return importlib.util.find_spec("ninja") is not None
def is_ipex_available(): def is_ipex_available():
def get_major_and_minor_from_version(full_version): def get_major_and_minor_from_version(full_version):
return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor) return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)