Better check for packages availability (#23163)

* Better check for packages availability

* amend _optimumneuron_available

* amend torch_version

* amend PIL detection and lint

* lint

* amend _faiss_available

* remove overloaded signatures of _is_package_available

* fix sklearn and decord detection

* remove unused checks

* revert
This commit is contained in:
Alessandro Pietro Bardelli 2023-05-11 19:52:22 +02:00 committed by GitHub
parent d51296d9c2
commit 83eda6435e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 168 additions and 273 deletions

View File

@ -72,6 +72,7 @@ from .utils import (
get_cached_models,
get_file_from_repo,
get_full_repo_name,
get_torch_version,
has_file,
http_user_agent,
is_apex_available,
@ -125,5 +126,4 @@ from .utils import (
to_numpy,
to_py_obj,
torch_only_method,
torch_version,
)

View File

@ -232,9 +232,9 @@ class OnnxConfig(ABC):
`bool`: Whether the installed version of PyTorch is compatible with the model.
"""
if is_torch_available():
from transformers.utils import torch_version
from transformers.utils import get_torch_version
return torch_version >= self.torch_onnx_minimum_version
return get_torch_version() >= self.torch_onnx_minimum_version
else:
return False

View File

@ -334,12 +334,12 @@ def export(
preprocessor = tokenizer
if is_torch_available():
from ..utils import torch_version
from ..utils import get_torch_version
if not config.is_torch_support_available:
logger.warning(
f"Unsupported PyTorch version for this model. Minimum required is {config.torch_onnx_minimum_version},"
f" got: {torch_version}"
f" got: {get_torch_version()}"
)
if is_torch_available() and issubclass(type(model), PreTrainedModel):

View File

@ -99,6 +99,7 @@ from .import_utils import (
_LazyModule,
ccl_version,
direct_transformers_import,
get_torch_version,
is_accelerate_available,
is_apex_available,
is_bitsandbytes_available,
@ -170,7 +171,6 @@ from .import_utils import (
is_vision_available,
requires_backends,
torch_only_method,
torch_version,
)

View File

@ -25,7 +25,6 @@ import warnings
from typing import Any, Callable, Dict, List, Optional, Type, Union
import torch
from packaging import version
from torch import nn
from torch.fx import Graph, GraphModule, Proxy, Tracer
from torch.fx._compatibility import compatibility
@ -54,8 +53,13 @@ from ..models.auto.modeling_auto import (
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
MODEL_MAPPING_NAMES,
)
from ..utils import ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, is_peft_available, is_torch_fx_available
from ..utils.versions import importlib_metadata
from ..utils import (
ENV_VARS_TRUE_VALUES,
TORCH_FX_REQUIRED_VERSION,
get_torch_version,
is_peft_available,
is_torch_fx_available,
)
if is_peft_available():
@ -737,9 +741,8 @@ class HFTracer(Tracer):
super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions)
if not is_torch_fx_available():
torch_version = version.parse(importlib_metadata.version("torch"))
raise ImportError(
f"Found an incompatible version of torch. Found version {torch_version}, but only version "
f"Found an incompatible version of torch. Found version {get_torch_version()}, but only version "
f"{TORCH_FX_REQUIRED_VERSION} is supported."
)

View File

@ -25,7 +25,7 @@ from collections import OrderedDict
from functools import lru_cache
from itertools import chain
from types import ModuleType
from typing import Any
from typing import Any, Tuple, Union
from packaging import version
@ -35,6 +35,24 @@ from .versions import importlib_metadata
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
# Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
package_exists = importlib.util.find_spec(pkg_name) is not None
package_version = "N/A"
if package_exists:
try:
package_version = importlib_metadata.version(pkg_name)
package_exists = True
except importlib_metadata.PackageNotFoundError:
package_exists = False
logger.debug(f"Detected {pkg_name} version {package_version}")
if return_version:
return package_exists, package_version
else:
return package_exists
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
@ -44,26 +62,80 @@ USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper()
# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
_apex_available = _is_package_available("apex")
_bitsandbytes_available = _is_package_available("bitsandbytes")
_bs4_available = _is_package_available("bs4")
_coloredlogs_available = _is_package_available("coloredlogs")
_datasets_available = _is_package_available("datasets")
_decord_available = importlib.util.find_spec("decord") is not None
_detectron2_available = _is_package_available("detectron2")
_faiss_available = _is_package_available("faiss") or _is_package_available("faiss-cpu")
_ftfy_available = _is_package_available("ftfy")
_ipex_available, _ipex_version = _is_package_available("intel_extension_for_pytorch", return_version=True)
_jieba_available = _is_package_available("jieba")
_kenlm_available = _is_package_available("kenlm")
_keras_nlp_available = _is_package_available("keras_nlp")
_librosa_available = _is_package_available("librosa")
_natten_available = _is_package_available("natten")
_ninja_available = _is_package_available("ninja")
_onnx_available = _is_package_available("onnx")
_openai_available = _is_package_available("openai")
_optimum_available = _is_package_available("optimum")
_optimumneuron_available = _optimum_available and _is_package_available("optimum.neuron")
_pandas_available = _is_package_available("pandas")
_peft_available = _is_package_available("peft")
_phonemizer_available = _is_package_available("phonemizer")
_psutil_available = _is_package_available("psutil")
_py3nvml_available = _is_package_available("py3nvml")
_pyctcdecode_available = _is_package_available("pyctcdecode")
_pytesseract_available = _is_package_available("pytesseract")
_pytorch_quantization_available = _is_package_available("pytorch_quantization")
_rjieba_available = _is_package_available("rjieba")
_sacremoses_available = _is_package_available("sacremoses")
_safetensors_available = _is_package_available("safetensors")
_scipy_available = _is_package_available("scipy")
_sentencepiece_available = _is_package_available("sentencepiece")
_sklearn_available = importlib.util.find_spec("sklearn") is not None
if _sklearn_available:
try:
importlib_metadata.version("scikit-learn")
except importlib_metadata.PackageNotFoundError:
_sklearn_available = False
_smdistributed_available = _is_package_available("smdistributed")
_soundfile_available = _is_package_available("soundfile")
_spacy_available = _is_package_available("spacy")
_sudachipy_available = _is_package_available("sudachipy")
_tensorflow_probability_available = _is_package_available("tensorflow_probability")
_tensorflow_text_available = _is_package_available("tensorflow_text")
_tf2onnx_available = _is_package_available("tf2onnx")
_timm_available = _is_package_available("timm")
_tokenizers_available = _is_package_available("tokenizers")
_torchaudio_available = _is_package_available("torchaudio")
_torchdistx_available = _is_package_available("torchdistx")
_torchvision_available = _is_package_available("torchvision")
_torch_version = "N/A"
_torch_available = False
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
_torch_available = importlib.util.find_spec("torch") is not None
if _torch_available:
try:
_torch_version = importlib_metadata.version("torch")
logger.info(f"PyTorch version {_torch_version} available.")
except importlib_metadata.PackageNotFoundError:
_torch_available = False
_torch_available, _torch_version = _is_package_available("torch", return_version=True)
else:
logger.info("Disabling PyTorch because USE_TF is set")
_torch_available = False
_tf_version = "N/A"
_tf_available = False
if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES:
_tf_available = True
else:
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
_tf_available = importlib.util.find_spec("tensorflow") is not None
_tf_available = _is_package_available("tensorflow")
if _tf_available:
candidates = (
"tensorflow",
@ -93,180 +165,10 @@ else:
f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum."
)
_tf_available = False
else:
logger.info(f"TensorFlow version {_tf_version} available.")
else:
logger.info("Disabling Tensorflow because USE_TORCH is set")
_tf_available = False
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
_flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
if _flax_available:
try:
_jax_version = importlib_metadata.version("jax")
_flax_version = importlib_metadata.version("flax")
logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
except importlib_metadata.PackageNotFoundError:
_flax_available = False
else:
_flax_available = False
_datasets_available = importlib.util.find_spec("datasets") is not None
try:
# Check we're not importing a "datasets" directory somewhere but the actual library by trying to grab the version
# AND checking it has an author field in the metadata that is HuggingFace.
_ = importlib_metadata.version("datasets")
_datasets_metadata = importlib_metadata.metadata("datasets")
if _datasets_metadata.get("author", "") != "HuggingFace Inc.":
_datasets_available = False
except importlib_metadata.PackageNotFoundError:
_datasets_available = False
_diffusers_available = importlib.util.find_spec("diffusers") is not None
try:
_diffusers_version = importlib_metadata.version("diffusers")
logger.debug(f"Successfully imported diffusers version {_diffusers_version}")
except importlib_metadata.PackageNotFoundError:
_diffusers_available = False
_detectron2_available = importlib.util.find_spec("detectron2") is not None
try:
_detectron2_version = importlib_metadata.version("detectron2")
logger.debug(f"Successfully imported detectron2 version {_detectron2_version}")
except importlib_metadata.PackageNotFoundError:
_detectron2_available = False
_faiss_available = importlib.util.find_spec("faiss") is not None
try:
_faiss_version = importlib_metadata.version("faiss")
logger.debug(f"Successfully imported faiss version {_faiss_version}")
except importlib_metadata.PackageNotFoundError:
try:
_faiss_version = importlib_metadata.version("faiss-cpu")
logger.debug(f"Successfully imported faiss version {_faiss_version}")
except importlib_metadata.PackageNotFoundError:
_faiss_available = False
_ftfy_available = importlib.util.find_spec("ftfy") is not None
try:
_ftfy_version = importlib_metadata.version("ftfy")
logger.debug(f"Successfully imported ftfy version {_ftfy_version}")
except importlib_metadata.PackageNotFoundError:
_ftfy_available = False
coloredlogs = importlib.util.find_spec("coloredlogs") is not None
try:
_coloredlogs_available = importlib_metadata.version("coloredlogs")
logger.debug(f"Successfully imported sympy version {_coloredlogs_available}")
except importlib_metadata.PackageNotFoundError:
_coloredlogs_available = False
sympy_available = importlib.util.find_spec("sympy") is not None
try:
_sympy_available = importlib_metadata.version("sympy")
logger.debug(f"Successfully imported sympy version {_sympy_available}")
except importlib_metadata.PackageNotFoundError:
_sympy_available = False
_tf2onnx_available = importlib.util.find_spec("tf2onnx") is not None
try:
_tf2onnx_version = importlib_metadata.version("tf2onnx")
logger.debug(f"Successfully imported tf2onnx version {_tf2onnx_version}")
except importlib_metadata.PackageNotFoundError:
_tf2onnx_available = False
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
try:
_onxx_version = importlib_metadata.version("onnx")
logger.debug(f"Successfully imported onnx version {_onxx_version}")
except importlib_metadata.PackageNotFoundError:
_onnx_available = False
_opencv_available = importlib.util.find_spec("cv2") is not None
_pytorch_quantization_available = importlib.util.find_spec("pytorch_quantization") is not None
try:
_pytorch_quantization_version = importlib_metadata.version("pytorch_quantization")
logger.debug(f"Successfully imported pytorch-quantization version {_pytorch_quantization_version}")
except importlib_metadata.PackageNotFoundError:
_pytorch_quantization_available = False
_soundfile_available = importlib.util.find_spec("soundfile") is not None
try:
_soundfile_version = importlib_metadata.version("soundfile")
logger.debug(f"Successfully imported soundfile version {_soundfile_version}")
except importlib_metadata.PackageNotFoundError:
_soundfile_available = False
_tensorflow_probability_available = importlib.util.find_spec("tensorflow_probability") is not None
try:
_tensorflow_probability_version = importlib_metadata.version("tensorflow_probability")
logger.debug(f"Successfully imported tensorflow-probability version {_tensorflow_probability_version}")
except importlib_metadata.PackageNotFoundError:
_tensorflow_probability_available = False
_timm_available = importlib.util.find_spec("timm") is not None
try:
_timm_version = importlib_metadata.version("timm")
logger.debug(f"Successfully imported timm version {_timm_version}")
except importlib_metadata.PackageNotFoundError:
_timm_available = False
_natten_available = importlib.util.find_spec("natten") is not None
try:
_natten_version = importlib_metadata.version("natten")
logger.debug(f"Successfully imported natten version {_natten_version}")
except importlib_metadata.PackageNotFoundError:
_natten_available = False
_torchaudio_available = importlib.util.find_spec("torchaudio") is not None
try:
_torchaudio_version = importlib_metadata.version("torchaudio")
logger.debug(f"Successfully imported torchaudio version {_torchaudio_version}")
except importlib_metadata.PackageNotFoundError:
_torchaudio_available = False
_phonemizer_available = importlib.util.find_spec("phonemizer") is not None
try:
_phonemizer_version = importlib_metadata.version("phonemizer")
logger.debug(f"Successfully imported phonemizer version {_phonemizer_version}")
except importlib_metadata.PackageNotFoundError:
_phonemizer_available = False
_pyctcdecode_available = importlib.util.find_spec("pyctcdecode") is not None
try:
_pyctcdecode_version = importlib_metadata.version("pyctcdecode")
logger.debug(f"Successfully imported pyctcdecode version {_pyctcdecode_version}")
except importlib_metadata.PackageNotFoundError:
_pyctcdecode_available = False
_librosa_available = importlib.util.find_spec("librosa") is not None
try:
_librosa_version = importlib_metadata.version("librosa")
logger.debug(f"Successfully imported librosa version {_librosa_version}")
except importlib_metadata.PackageNotFoundError:
_librosa_available = False
ccl_version = "N/A"
_is_ccl_available = (
importlib.util.find_spec("torch_ccl") is not None
@ -274,38 +176,46 @@ _is_ccl_available = (
)
try:
ccl_version = importlib_metadata.version("oneccl_bind_pt")
logger.debug(f"Successfully imported oneccl_bind_pt version {ccl_version}")
logger.debug(f"Detected oneccl_bind_pt version {ccl_version}")
except importlib_metadata.PackageNotFoundError:
_is_ccl_available = False
_decord_availale = importlib.util.find_spec("decord") is not None
try:
_decord_version = importlib_metadata.version("decord")
logger.debug(f"Successfully imported decord version {_decord_version}")
except importlib_metadata.PackageNotFoundError:
_decord_availale = False
_jieba_available = importlib.util.find_spec("jieba") is not None
try:
_jieba_version = importlib_metadata.version("jieba")
logger.debug(f"Successfully imported jieba version {_jieba_version}")
except importlib_metadata.PackageNotFoundError:
_jieba_available = False
_flax_available = False
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
_flax_available, _flax_version = _is_package_available("flax", return_version=True)
if _flax_available:
_jax_available, _jax_version = _is_package_available("jax", return_version=True)
if _jax_available:
logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
else:
_flax_available = _jax_available = False
_jax_version = _flax_version = "N/A"
# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
_torch_fx_available = False
if _torch_available:
torch_version = version.parse(_torch_version)
_torch_fx_available = (torch_version.major, torch_version.minor) >= (
TORCH_FX_REQUIRED_VERSION.major,
TORCH_FX_REQUIRED_VERSION.minor,
)
def is_kenlm_available():
return importlib.util.find_spec("kenlm") is not None
return _kenlm_available
def is_torch_available():
return _torch_available
def get_torch_version():
return _torch_version
def is_torchvision_available():
return importlib.util.find_spec("torchvision") is not None
return _torchvision_available
def is_pyctcdecode_available():
@ -404,26 +314,16 @@ def is_torch_tf32_available():
return True
torch_version = None
_torch_fx_available = False
if _torch_available:
torch_version = version.parse(importlib_metadata.version("torch"))
_torch_fx_available = (torch_version.major, torch_version.minor) >= (
TORCH_FX_REQUIRED_VERSION.major,
TORCH_FX_REQUIRED_VERSION.minor,
)
def is_torch_fx_available():
return _torch_fx_available
def is_peft_available():
return importlib.util.find_spec("peft") is not None
return _peft_available
def is_bs4_available():
return importlib.util.find_spec("bs4") is not None
return _bs4_available
def is_tf_available():
@ -443,7 +343,7 @@ def is_onnx_available():
def is_openai_available():
return importlib.util.find_spec("openai") is not None
return _openai_available
def is_flax_available():
@ -517,40 +417,36 @@ def is_detectron2_available():
def is_rjieba_available():
return importlib.util.find_spec("rjieba") is not None
return _rjieba_available
def is_psutil_available():
return importlib.util.find_spec("psutil") is not None
return _psutil_available
def is_py3nvml_available():
return importlib.util.find_spec("py3nvml") is not None
return _py3nvml_available
def is_sacremoses_available():
return importlib.util.find_spec("sacremoses") is not None
return _sacremoses_available
def is_apex_available():
return importlib.util.find_spec("apex") is not None
return _apex_available
def is_ninja_available():
return importlib.util.find_spec("ninja") is not None
return _ninja_available
def is_ipex_available():
def get_major_and_minor_from_version(full_version):
return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
if not is_torch_available() or importlib.util.find_spec("intel_extension_for_pytorch") is None:
return False
_ipex_version = "N/A"
try:
_ipex_version = importlib_metadata.version("intel_extension_for_pytorch")
except importlib_metadata.PackageNotFoundError:
if not is_torch_available() or not _ipex_available:
return False
torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
if torch_major_and_minor != ipex_major_and_minor:
@ -563,11 +459,11 @@ def is_ipex_available():
def is_bitsandbytes_available():
return importlib.util.find_spec("bitsandbytes") is not None
return _bitsandbytes_available
def is_torchdistx_available():
return importlib.util.find_spec("torchdistx") is not None
return _torchdistx_available
def is_faiss_available():
@ -575,17 +471,15 @@ def is_faiss_available():
def is_scipy_available():
return importlib.util.find_spec("scipy") is not None
return _scipy_available
def is_sklearn_available():
if importlib.util.find_spec("sklearn") is None:
return False
return is_scipy_available() and importlib.util.find_spec("sklearn.metrics")
return _sklearn_available
def is_sentencepiece_available():
return importlib.util.find_spec("sentencepiece") is not None
return _sentencepiece_available
def is_protobuf_available():
@ -595,56 +489,54 @@ def is_protobuf_available():
def is_accelerate_available(check_partial_state=False):
accelerate_available = importlib.util.find_spec("accelerate") is not None
if accelerate_available:
if check_partial_state:
return version.parse(importlib_metadata.version("accelerate")) >= version.parse("0.17.0")
else:
return True
else:
return False
if check_partial_state:
return _accelerate_available and version.parse(_accelerate_version) >= version.parse("0.17.0")
return _accelerate_available
def is_optimum_available():
return importlib.util.find_spec("optimum") is not None
return _optimum_available
def is_optimum_neuron_available():
return importlib.util.find_spec("optimum.neuron") is not None
return _optimumneuron_available
def is_safetensors_available():
if is_torch_available():
if version.parse(_torch_version) >= version.parse("1.10"):
return importlib.util.find_spec("safetensors") is not None
else:
return False
else:
return importlib.util.find_spec("safetensors") is not None
if is_torch_available() and version.parse(_torch_version) < version.parse("1.10"):
return False
return _safetensors_available
def is_tokenizers_available():
return importlib.util.find_spec("tokenizers") is not None
return _tokenizers_available
def is_vision_available():
return importlib.util.find_spec("PIL") is not None
_pil_available = importlib.util.find_spec("PIL") is not None
if _pil_available:
try:
package_version = importlib_metadata.version("Pillow")
except importlib_metadata.PackageNotFoundError:
return False
logger.debug(f"Detected PIL version {package_version}")
return _pil_available
def is_pytesseract_available():
return importlib.util.find_spec("pytesseract") is not None
return _pytesseract_available
def is_spacy_available():
return importlib.util.find_spec("spacy") is not None
return _spacy_available
def is_tensorflow_text_available():
return is_tf_available() and importlib.util.find_spec("tensorflow_text") is not None
return is_tf_available() and _tensorflow_text_available
def is_keras_nlp_available():
return is_tensorflow_text_available() and importlib.util.find_spec("keras_nlp") is not None
return is_tensorflow_text_available() and _keras_nlp_available
def is_in_notebook():
@ -674,7 +566,7 @@ def is_tensorflow_probability_available():
def is_pandas_available():
return importlib.util.find_spec("pandas") is not None
return _pandas_available
def is_sagemaker_dp_enabled():
@ -688,7 +580,7 @@ def is_sagemaker_dp_enabled():
except json.JSONDecodeError:
return False
# Lastly, check if the `smdistributed` module is present.
return importlib.util.find_spec("smdistributed") is not None
return _smdistributed_available
def is_sagemaker_mp_enabled():
@ -712,7 +604,7 @@ def is_sagemaker_mp_enabled():
except json.JSONDecodeError:
return False
# Lastly, check if the `smdistributed` module is present.
return importlib.util.find_spec("smdistributed") is not None
return _smdistributed_available
def is_training_run_on_sagemaker():
@ -762,11 +654,11 @@ def is_ccl_available():
def is_decord_available():
return _decord_availale
return _decord_available
def is_sudachi_available():
return importlib.util.find_spec("sudachipy") is not None
return _sudachipy_available
def is_jumanpp_available():

View File

@ -319,12 +319,12 @@ class OnnxExportTestCaseV2(TestCase):
onnx_config = onnx_config_class_constructor(model.config)
if is_torch_available():
from transformers.utils import torch_version
from transformers.utils import get_torch_version
if torch_version < onnx_config.torch_onnx_minimum_version:
if get_torch_version() < onnx_config.torch_onnx_minimum_version:
pytest.skip(
"Skipping due to incompatible PyTorch version. Minimum required is"
f" {onnx_config.torch_onnx_minimum_version}, got: {torch_version}"
f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}"
)
preprocessor = get_preprocessor(model_name)
@ -362,12 +362,12 @@ class OnnxExportTestCaseV2(TestCase):
onnx_config = onnx_config_class_constructor(model.config)
if is_torch_available():
from transformers.utils import torch_version
from transformers.utils import get_torch_version
if torch_version < onnx_config.torch_onnx_minimum_version:
if get_torch_version() < onnx_config.torch_onnx_minimum_version:
pytest.skip(
"Skipping due to incompatible PyTorch version. Minimum required is"
f" {onnx_config.torch_onnx_minimum_version}, got: {torch_version}"
f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}"
)
encoder_model = model.get_encoder()