From 537deb7869f7711ec67aad9459b9f78ad5df5161 Mon Sep 17 00:00:00 2001
From: Hengwen Tong
Date: Tue, 28 May 2024 17:52:47 +0800
Subject: [PATCH] Remove redundant backend checks in training_args.py (#30999)

* Remove backend checks in training_args.py

* Explicitly initialize the device

---------

Co-authored-by: tonghengwen
---
 src/transformers/training_args.py | 37 ++++---------------------------------
 1 file changed, 4 insertions(+), 33 deletions(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 7c5d4b1c73..a97139a07b 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -67,7 +67,7 @@ if is_torch_available():
     import torch
     import torch.distributed as dist
 
-    from .pytorch_utils import is_torch_greater_or_equal_than_2_0, is_torch_greater_or_equal_than_2_3
+    from .pytorch_utils import is_torch_greater_or_equal_than_2_0
 
 if is_accelerate_available():
     from accelerate.state import AcceleratorState, PartialState
@@ -1677,38 +1677,9 @@ class TrainingArguments:
             )
             self.accelerator_config.split_batches = self.split_batches
 
-        if (
-            self.framework == "pt"
-            and is_torch_available()
-            and (self.device.type == "cpu" and not is_torch_greater_or_equal_than_2_3)
-            and (self.device.type != "cuda")
-            and (self.device.type != "mlu")
-            and (self.device.type != "npu")
-            and (self.device.type != "xpu")
-            and (get_xla_device_type(self.device) not in ["GPU", "CUDA"])
-            and (self.fp16 or self.fp16_full_eval)
-        ):
-            raise ValueError(
-                "FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
-                " (`--fp16_full_eval`) can only be used on CUDA or MLU devices or NPU devices or certain XPU devices (with IPEX)."
-            )
-
-        if (
-            self.framework == "pt"
-            and is_torch_available()
-            and (self.device.type != "cuda")
-            and (self.device.type != "mlu")
-            and (self.device.type != "npu")
-            and (self.device.type != "xpu")
-            and (get_xla_device_type(self.device) not in ["GPU", "CUDA"])
-            and (get_xla_device_type(self.device) != "TPU")
-            and (self.device.type != "cpu")
-            and (self.bf16 or self.bf16_full_eval)
-        ):
-            raise ValueError(
-                "BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
-                " (`--bf16_full_eval`) can only be used on CUDA, XPU (with IPEX), NPU, MLU or CPU/TPU/NeuronCore devices."
-            )
+        # Initialize device before we proceed
+        if self.framework == "pt" and is_torch_available():
+            self.device
 
         if self.torchdynamo is not None:
             warnings.warn(