Remove redundant backend checks in training_args.py (#30999)

* Remove backend checks in training_args.py

* Explicitly initialize the device

---------

Co-authored-by: tonghengwen <tonghengwen@cambricon.com>
Hengwen Tong 2024-05-28 17:52:47 +08:00 committed by GitHub
parent dd4654eab7
commit 537deb7869
1 changed file with 4 additions and 33 deletions

@@ -67,7 +67,7 @@ if is_torch_available():
     import torch
     import torch.distributed as dist
 
-    from .pytorch_utils import is_torch_greater_or_equal_than_2_0, is_torch_greater_or_equal_than_2_3
+    from .pytorch_utils import is_torch_greater_or_equal_than_2_0
 
 if is_accelerate_available():
     from accelerate.state import AcceleratorState, PartialState
@@ -1677,38 +1677,9 @@ class TrainingArguments:
             )
             self.accelerator_config.split_batches = self.split_batches
 
-        if (
-            self.framework == "pt"
-            and is_torch_available()
-            and (self.device.type == "cpu" and not is_torch_greater_or_equal_than_2_3)
-            and (self.device.type != "cuda")
-            and (self.device.type != "mlu")
-            and (self.device.type != "npu")
-            and (self.device.type != "xpu")
-            and (get_xla_device_type(self.device) not in ["GPU", "CUDA"])
-            and (self.fp16 or self.fp16_full_eval)
-        ):
-            raise ValueError(
-                "FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
-                " (`--fp16_full_eval`) can only be used on CUDA or MLU devices or NPU devices or certain XPU devices (with IPEX)."
-            )
-
-        if (
-            self.framework == "pt"
-            and is_torch_available()
-            and (self.device.type != "cuda")
-            and (self.device.type != "mlu")
-            and (self.device.type != "npu")
-            and (self.device.type != "xpu")
-            and (get_xla_device_type(self.device) not in ["GPU", "CUDA"])
-            and (get_xla_device_type(self.device) != "TPU")
-            and (self.device.type != "cpu")
-            and (self.bf16 or self.bf16_full_eval)
-        ):
-            raise ValueError(
-                "BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
-                " (`--bf16_full_eval`) can only be used on CUDA, XPU (with IPEX), NPU, MLU or CPU/TPU/NeuronCore devices."
-            )
+        # Initialize device before we proceed
+        if self.framework == "pt" and is_torch_available():
+            self.device
 
         if self.torchdynamo is not None:
             warnings.warn(
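
Note on the added lines: the bare `self.device` statement is enough because `TrainingArguments.device` is a property backed by a cached setup routine, so merely reading the attribute triggers backend/device initialization once and caches the result for later accesses. The sketch below illustrates that lazy-initialization pattern under simplified assumptions; the class name `TrainingArgumentsSketch` and the CUDA/CPU fallback logic are illustrative only, not the actual transformers implementation.

# Minimal sketch of the lazy device-initialization pattern that the
# bare `self.device` reference relies on. Names loosely mirror the real
# TrainingArguments (`device` property, `_setup_devices`), but the body
# is a simplified assumption for illustration.
from functools import cached_property

import torch


class TrainingArgumentsSketch:
    @cached_property
    def _setup_devices(self) -> "torch.device":
        # Runs once; subsequent reads return the cached result.
        # The real implementation also configures distributed state here.
        if torch.cuda.is_available():
            return torch.device("cuda:0")
        return torch.device("cpu")

    @property
    def device(self) -> "torch.device":
        # Reading the property triggers (and caches) device setup,
        # which is why `self.device` on its own line has an effect.
        return self._setup_devices


args = TrainingArgumentsSketch()
args.device  # first access performs the actual initialization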