Add safeguards for CUDA kernel load in Deformable DETR (#19037)
This commit is contained in:
parent
31be02f14b
commit
0e24548081
|
@ -41,7 +41,7 @@ from ...file_utils import (
|
||||||
)
|
)
|
||||||
from ...modeling_outputs import BaseModelOutput
|
from ...modeling_outputs import BaseModelOutput
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import logging
|
from ...utils import is_ninja_available, logging
|
||||||
from .configuration_deformable_detr import DeformableDetrConfig
|
from .configuration_deformable_detr import DeformableDetrConfig
|
||||||
from .load_custom import load_cuda_kernels
|
from .load_custom import load_cuda_kernels
|
||||||
|
|
||||||
|
@ -49,9 +49,13 @@ from .load_custom import load_cuda_kernels
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
# Move this to not compile only when importing, this needs to happen later, like in __init__.
|
# Move this to not compile only when importing, this needs to happen later, like in __init__.
|
||||||
if is_torch_cuda_available():
|
if is_torch_cuda_available() and is_ninja_available():
|
||||||
logger.info("Loading custom CUDA kernels...")
|
logger.info("Loading custom CUDA kernels...")
|
||||||
MultiScaleDeformableAttention = load_cuda_kernels()
|
try:
|
||||||
|
MultiScaleDeformableAttention = load_cuda_kernels()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
|
||||||
|
MultiScaleDeformableAttention = None
|
||||||
else:
|
else:
|
||||||
MultiScaleDeformableAttention = None
|
MultiScaleDeformableAttention = None
|
||||||
|
|
||||||
|
|
|
@ -98,6 +98,7 @@ from .import_utils import (
|
||||||
is_in_notebook,
|
is_in_notebook,
|
||||||
is_ipex_available,
|
is_ipex_available,
|
||||||
is_librosa_available,
|
is_librosa_available,
|
||||||
|
is_ninja_available,
|
||||||
is_onnx_available,
|
is_onnx_available,
|
||||||
is_pandas_available,
|
is_pandas_available,
|
||||||
is_phonemizer_available,
|
is_phonemizer_available,
|
||||||
|
|
|
@ -471,6 +471,10 @@ def is_apex_available():
|
||||||
return importlib.util.find_spec("apex") is not None
|
return importlib.util.find_spec("apex") is not None
|
||||||
|
|
||||||
|
|
||||||
|
def is_ninja_available():
|
||||||
|
return importlib.util.find_spec("ninja") is not None
|
||||||
|
|
||||||
|
|
||||||
def is_ipex_available():
|
def is_ipex_available():
|
||||||
def get_major_and_minor_from_version(full_version):
|
def get_major_and_minor_from_version(full_version):
|
||||||
return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
|
return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
|
||||||
|
|
Loading…
Reference in New Issue