Extend import utils to cover "editable" torch versions (#29000)
* Extend import utils to cover "editable" torch versions
* Re-add type hint
* Remove whitespaces
* Double quote strings
* Update comment
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
* Restore package_exists
* Revert "Restore package_exists"
This reverts commit 66fd2cd5c3
.
---------
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
parent
56b64bf1a5
commit
f62407f788
|
@ -39,16 +39,32 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|||
|
||||
# TODO: This doesn't work for all packages (`bs4`, `faiss`, etc.) Talk to Sylvain to see how to do with it better.
|
||||
def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
|
||||
# Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
|
||||
# Check if the package spec exists and grab its version to avoid importing a local directory
|
||||
package_exists = importlib.util.find_spec(pkg_name) is not None
|
||||
package_version = "N/A"
|
||||
if package_exists:
|
||||
try:
|
||||
# Primary method to get the package version
|
||||
package_version = importlib.metadata.version(pkg_name)
|
||||
package_exists = True
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
package_exists = False
|
||||
logger.debug(f"Detected {pkg_name} version {package_version}")
|
||||
# Fallback method: Only for "torch" and versions containing "dev"
|
||||
if pkg_name == "torch":
|
||||
try:
|
||||
package = importlib.import_module(pkg_name)
|
||||
temp_version = getattr(package, "__version__", "N/A")
|
||||
# Check if the version contains "dev"
|
||||
if "dev" in temp_version:
|
||||
package_version = temp_version
|
||||
package_exists = True
|
||||
else:
|
||||
package_exists = False
|
||||
except ImportError:
|
||||
# If the package can't be imported, it's not available
|
||||
package_exists = False
|
||||
else:
|
||||
# For packages other than "torch", don't attempt the fallback and set as not available
|
||||
package_exists = False
|
||||
logger.debug(f"Detected {pkg_name} version: {package_version}")
|
||||
if return_version:
|
||||
return package_exists, package_version
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue