Enable fp16 on CPU (#30459)

* Check removing flag for torch
* LLM oops
* Getting there...
* More discoveries
* Change
* Clean up and prettify
* Logic check
* Not
This commit is contained in:
parent
d1d94d798f
commit
5c57463bde
|
@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
|
|||
|
||||
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
|
||||
|
||||
+is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")
|
||||
is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
|
||||
is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
|
||||
is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
|
||||
|
|
|
@@ -69,7 +69,11 @@ from .models.auto.modeling_auto import (
|
|||
MODEL_MAPPING_NAMES,
|
||||
)
|
||||
from .optimization import Adafactor, get_scheduler
|
||||
-from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
+from .pytorch_utils import (
+    ALL_LAYERNORM_LAYERS,
+    is_torch_greater_or_equal_than_1_13,
+    is_torch_greater_or_equal_than_2_3,
+)
|
||||
from .tokenization_utils_base import PreTrainedTokenizerBase
|
||||
from .trainer_callback import (
|
||||
CallbackHandler,
|
||||
|
@@ -620,7 +624,8 @@ class Trainer:
|
|||
if (args.fp16 or args.bf16) and args.half_precision_backend == "auto":
|
||||
if args.device == torch.device("cpu"):
|
||||
 if args.fp16:
-    raise ValueError("Tried to use `fp16` but it is not supported on cpu")
+    if not is_torch_greater_or_equal_than_2_3:
+        raise ValueError("Tried to use `fp16` but it is not supported on cpu")
|
||||
else:
|
||||
args.half_precision_backend = "cpu_amp"
|
||||
logger.info(f"Using {args.half_precision_backend} half precision backend")
|
||||
|
|
|
@@ -67,7 +67,7 @@ if is_torch_available():
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
-from .pytorch_utils import is_torch_greater_or_equal_than_2_0
+from .pytorch_utils import is_torch_greater_or_equal_than_2_0, is_torch_greater_or_equal_than_2_3
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.state import AcceleratorState, PartialState
|
||||
|
@@ -1618,6 +1618,7 @@ class TrainingArguments:
|
|||
if (
|
||||
self.framework == "pt"
|
||||
and is_torch_available()
|
||||
+and (self.device.type == "cpu" and not is_torch_greater_or_equal_than_2_3)
|
||||
and (self.device.type != "cuda")
|
||||
and (self.device.type != "mlu")
|
||||
and (self.device.type != "npu")
|
||||
|
|
Loading…
Reference in New Issue