Extend import utils to cover "editable" torch versions (#29000)

* Extend import utils to cover "editable" torch versions

* Re-add type hint

* Remove whitespaces

* Double quote strings

* Update comment

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>

* Restore package_exists

* Revert "Restore package_exists"

This reverts commit 66fd2cd5c3.

---------

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
bhack 2024-03-15 13:34:48 +01:00 committed by GitHub
parent 56b64bf1a5
commit f62407f788
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 20 additions and 4 deletions

View File

@ -39,16 +39,32 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# TODO: This doesn't work for all packages (`bs4`, `faiss`, etc.) Talk to Sylvain to see how to do with it better.
def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
# Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
# Check if the package spec exists and grab its version to avoid importing a local directory
package_exists = importlib.util.find_spec(pkg_name) is not None
package_version = "N/A"
if package_exists:
try:
# Primary method to get the package version
package_version = importlib.metadata.version(pkg_name)
package_exists = True
except importlib.metadata.PackageNotFoundError:
package_exists = False
logger.debug(f"Detected {pkg_name} version {package_version}")
# Fallback method: Only for "torch" and versions containing "dev"
if pkg_name == "torch":
try:
package = importlib.import_module(pkg_name)
temp_version = getattr(package, "__version__", "N/A")
# Check if the version contains "dev"
if "dev" in temp_version:
package_version = temp_version
package_exists = True
else:
package_exists = False
except ImportError:
# If the package can't be imported, it's not available
package_exists = False
else:
# For packages other than "torch", don't attempt the fallback and set as not available
package_exists = False
logger.debug(f"Detected {pkg_name} version: {package_version}")
if return_version:
return package_exists, package_version
else: