Remove redundant backend checks in training_args.py (#30999)
* Remove backend checks in training_args.py * Expilicit initialize the device --------- Co-authored-by: tonghengwen <tonghengwen@cambricon.com>
This commit is contained in:
parent
dd4654eab7
commit
537deb7869
|
@ -67,7 +67,7 @@ if is_torch_available():
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from .pytorch_utils import is_torch_greater_or_equal_than_2_0, is_torch_greater_or_equal_than_2_3
|
||||
from .pytorch_utils import is_torch_greater_or_equal_than_2_0
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.state import AcceleratorState, PartialState
|
||||
|
@ -1677,38 +1677,9 @@ class TrainingArguments:
|
|||
)
|
||||
self.accelerator_config.split_batches = self.split_batches
|
||||
|
||||
if (
|
||||
self.framework == "pt"
|
||||
and is_torch_available()
|
||||
and (self.device.type == "cpu" and not is_torch_greater_or_equal_than_2_3)
|
||||
and (self.device.type != "cuda")
|
||||
and (self.device.type != "mlu")
|
||||
and (self.device.type != "npu")
|
||||
and (self.device.type != "xpu")
|
||||
and (get_xla_device_type(self.device) not in ["GPU", "CUDA"])
|
||||
and (self.fp16 or self.fp16_full_eval)
|
||||
):
|
||||
raise ValueError(
|
||||
"FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
|
||||
" (`--fp16_full_eval`) can only be used on CUDA or MLU devices or NPU devices or certain XPU devices (with IPEX)."
|
||||
)
|
||||
|
||||
if (
|
||||
self.framework == "pt"
|
||||
and is_torch_available()
|
||||
and (self.device.type != "cuda")
|
||||
and (self.device.type != "mlu")
|
||||
and (self.device.type != "npu")
|
||||
and (self.device.type != "xpu")
|
||||
and (get_xla_device_type(self.device) not in ["GPU", "CUDA"])
|
||||
and (get_xla_device_type(self.device) != "TPU")
|
||||
and (self.device.type != "cpu")
|
||||
and (self.bf16 or self.bf16_full_eval)
|
||||
):
|
||||
raise ValueError(
|
||||
"BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
|
||||
" (`--bf16_full_eval`) can only be used on CUDA, XPU (with IPEX), NPU, MLU or CPU/TPU/NeuronCore devices."
|
||||
)
|
||||
# Initialize device before we proceed
|
||||
if self.framework == "pt" and is_torch_available():
|
||||
self.device
|
||||
|
||||
if self.torchdynamo is not None:
|
||||
warnings.warn(
|
||||
|
|
Loading…
Reference in New Issue