Guard XLA version imports (#30167)

This commit is contained in:
Zach Mueller 2024-04-11 04:49:16 -04:00 committed by GitHub
parent fbdb978eb5
commit e50be9a058
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 2 deletions

View File

@ -136,6 +136,7 @@ from .utils import (
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
XLA_FSDPV2_MIN_VERSION,
PushInProgress,
PushToHubMixin,
can_return_loss,
@ -179,8 +180,14 @@ if is_datasets_available():
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from torch_xla import __version__ as XLA_VERSION
IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
if IS_XLA_FSDPV2_POST_2_2:
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
else:
IS_XLA_FSDPV2_POST_2_2 = False
if is_sagemaker_mp_enabled():
@ -664,6 +671,8 @@ class Trainer:
self.is_fsdp_xla_v2_enabled = args.fsdp_config.get("xla_fsdp_v2", False)
if self.is_fsdp_xla_v2_enabled:
if not IS_XLA_FSDPV2_POST_2_2:
raise ValueError("FSDPv2 requires `torch_xla` 2.2 or higher.")
# Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper.
# Tensor axis is just a placeholder where it will not be used in FSDPv2.
num_devices = xr.global_runtime_device_count()

View File

@ -98,6 +98,7 @@ from .import_utils import (
USE_JAX,
USE_TF,
USE_TORCH,
XLA_FSDPV2_MIN_VERSION,
DummyObject,
OptionalDependencyNotAvailable,
_LazyModule,

View File

@ -89,6 +89,7 @@ TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
ACCELERATE_MIN_VERSION = "0.21.0"
FSDP_MIN_VERSION = "1.12.0"
XLA_FSDPV2_MIN_VERSION = "2.2.0"
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)