Enable fp16 on CPU (#30459)

* Check removing flag for torch

* LLM oops

* Getting there...

* More discoveries

* Change

* Clean up and prettify

* Logic check

* Not
Zach Mueller 2024-04-24 15:38:52 -04:00 committed by GitHub
parent d1d94d798f
commit 5c57463bde
3 changed files with 10 additions and 3 deletions
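Net effect: `fp16=True` is now accepted on a CPU-only setup when the installed torch is 2.3 or newer; on older torch the previous ValueError is still raised. A minimal usage sketch, not part of this diff (`use_cpu` is the existing TrainingArguments flag; `model` and `train_dataset` are placeholders):

# Minimal sketch of the newly allowed configuration; requires torch >= 2.3.
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="out",
    use_cpu=True,  # train on CPU
    fp16=True,     # previously rejected on CPU, now gated on torch >= 2.3
)
# trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
# trainer.train()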

src/transformers/pytorch_utils.py

@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
 parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
+is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")
 is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
 is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
 is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")

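The new flag follows the same pattern as the existing version gates. The 2.3 cutoff presumably corresponds to CPU autocast gaining float16 support (earlier releases only handle bfloat16 on CPU). An illustrative gate built on the flag added above, not taken from this commit:

# Illustrative only: gating fp16 CPU autocast on the new module-level flag.
import torch

from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3


def cpu_autocast(fp16: bool = True):
    if fp16 and not is_torch_greater_or_equal_than_2_3:
        raise RuntimeError("float16 autocast on CPU needs torch >= 2.3")
    dtype = torch.float16 if fp16 else torch.bfloat16
    return torch.autocast(device_type="cpu", dtype=dtype)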
src/transformers/trainer.py

@@ -69,7 +69,11 @@ from .models.auto.modeling_auto import (
     MODEL_MAPPING_NAMES,
 )
 from .optimization import Adafactor, get_scheduler
-from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
+from .pytorch_utils import (
+    ALL_LAYERNORM_LAYERS,
+    is_torch_greater_or_equal_than_1_13,
+    is_torch_greater_or_equal_than_2_3,
+)
 from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer_callback import (
     CallbackHandler,
@@ -620,7 +624,8 @@ class Trainer:
         if (args.fp16 or args.bf16) and args.half_precision_backend == "auto":
             if args.device == torch.device("cpu"):
                 if args.fp16:
-                    raise ValueError("Tried to use `fp16` but it is not supported on cpu")
+                    if not is_torch_greater_or_equal_than_2_3:
+                        raise ValueError("Tried to use `fp16` but it is not supported on cpu")
                 else:
                     args.half_precision_backend = "cpu_amp"
             logger.info(f"Using {args.half_precision_backend} half precision backend")

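Based only on the hunk above, the decision in Trainer now works out to: fp16 on a CPU device raises only when torch < 2.3, bf16 on CPU keeps selecting the "cpu_amp" backend, and otherwise the backend stays "auto". A minimal re-expression of that guard, illustrative and outside the class:

# Illustrative re-expression of the guard above for a CPU device.
def cpu_half_precision_backend(fp16: bool, bf16: bool, torch_ge_2_3: bool) -> str:
    backend = "auto"
    if fp16:
        if not torch_ge_2_3:
            raise ValueError("Tried to use `fp16` but it is not supported on cpu")
    elif bf16:
        backend = "cpu_amp"
    return backend


assert cpu_half_precision_backend(fp16=True, bf16=False, torch_ge_2_3=True) == "auto"
assert cpu_half_precision_backend(fp16=False, bf16=True, torch_ge_2_3=False) == "cpu_amp"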
src/transformers/training_args.py

@@ -67,7 +67,7 @@ if is_torch_available():
     import torch
     import torch.distributed as dist

-    from .pytorch_utils import is_torch_greater_or_equal_than_2_0
+    from .pytorch_utils import is_torch_greater_or_equal_than_2_0, is_torch_greater_or_equal_than_2_3

 if is_accelerate_available():
     from accelerate.state import AcceleratorState, PartialState
@@ -1618,6 +1618,7 @@ class TrainingArguments:
         if (
             self.framework == "pt"
             and is_torch_available()
+            and (self.device.type == "cpu" and not is_torch_greater_or_equal_than_2_3)
             and (self.device.type != "cuda")
             and (self.device.type != "mlu")
             and (self.device.type != "npu")
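
The added clause narrows the fp16-on-unsupported-device check in `__post_init__` so it only fires for a CPU device on torch < 2.3; the fp16 flags themselves are checked further down in the same condition (assumed from the surrounding code, not shown in this hunk). A small illustrative evaluation of just the clauses visible above:

# Illustrative only: how the added clause changes when the ValueError fires.
def fp16_cpu_error_fires(device_type: str, torch_ge_2_3: bool) -> bool:
    # Mirrors only the visible clauses; the fp16 flags are assumed True here.
    return (
        (device_type == "cpu" and not torch_ge_2_3)
        and device_type != "cuda"
        and device_type != "mlu"
        and device_type != "npu"
    )


assert fp16_cpu_error_fires("cpu", torch_ge_2_3=False) is True    # unchanged: still an error
assert fp16_cpu_error_fires("cpu", torch_ge_2_3=True) is False    # new: fp16 on CPU allowed
assert fp16_cpu_error_fires("cuda", torch_ge_2_3=False) is False  # unaffected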