diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index ceed806e46..027afc938a 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -39,16 +39,32 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name # TODO: This doesn't work for all packages (`bs4`, `faiss`, etc.) Talk to Sylvain to see how to do with it better. def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]: - # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version + # Check if the package spec exists and grab its version to avoid importing a local directory package_exists = importlib.util.find_spec(pkg_name) is not None package_version = "N/A" if package_exists: try: + # Primary method to get the package version package_version = importlib.metadata.version(pkg_name) - package_exists = True except importlib.metadata.PackageNotFoundError: - package_exists = False - logger.debug(f"Detected {pkg_name} version {package_version}") + # Fallback method: Only for "torch" and versions containing "dev" + if pkg_name == "torch": + try: + package = importlib.import_module(pkg_name) + temp_version = getattr(package, "__version__", "N/A") + # Check if the version contains "dev" + if "dev" in temp_version: + package_version = temp_version + package_exists = True + else: + package_exists = False + except ImportError: + # If the package can't be imported, it's not available + package_exists = False + else: + # For packages other than "torch", don't attempt the fallback and set as not available + package_exists = False + logger.debug(f"Detected {pkg_name} version: {package_version}") if return_version: return package_exists, package_version else: