Guard XLA version imports (#30167)
This commit is contained in:
parent
fbdb978eb5
commit
e50be9a058
|
@ -136,6 +136,7 @@ from .utils import (
|
|||
SAFE_WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
XLA_FSDPV2_MIN_VERSION,
|
||||
PushInProgress,
|
||||
PushToHubMixin,
|
||||
can_return_loss,
|
||||
|
@ -179,8 +180,14 @@ if is_datasets_available():
|
|||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.debug.metrics as met
|
||||
import torch_xla.distributed.spmd as xs
|
||||
import torch_xla.runtime as xr
|
||||
from torch_xla import __version__ as XLA_VERSION
|
||||
|
||||
IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
|
||||
if IS_XLA_FSDPV2_POST_2_2:
|
||||
import torch_xla.distributed.spmd as xs
|
||||
import torch_xla.runtime as xr
|
||||
else:
|
||||
IS_XLA_FSDPV2_POST_2_2 = False
|
||||
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
|
@ -664,6 +671,8 @@ class Trainer:
|
|||
|
||||
self.is_fsdp_xla_v2_enabled = args.fsdp_config.get("xla_fsdp_v2", False)
|
||||
if self.is_fsdp_xla_v2_enabled:
|
||||
if not IS_XLA_FSDPV2_POST_2_2:
|
||||
raise ValueError("FSDPv2 requires `torch_xla` 2.2 or higher.")
|
||||
# Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper.
|
||||
# Tensor axis is just a placeholder where it will not be used in FSDPv2.
|
||||
num_devices = xr.global_runtime_device_count()
|
||||
|
|
|
@ -98,6 +98,7 @@ from .import_utils import (
|
|||
USE_JAX,
|
||||
USE_TF,
|
||||
USE_TORCH,
|
||||
XLA_FSDPV2_MIN_VERSION,
|
||||
DummyObject,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
|
|
|
@ -89,6 +89,7 @@ TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
|
|||
|
||||
ACCELERATE_MIN_VERSION = "0.21.0"
|
||||
FSDP_MIN_VERSION = "1.12.0"
|
||||
XLA_FSDPV2_MIN_VERSION = "2.2.0"
|
||||
|
||||
|
||||
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
|
||||
|
|
Loading…
Reference in New Issue