Compare commits

...

5 Commits

Author SHA1 Message Date
Zach Mueller 91aed0be2d Rm 2024-04-10 12:01:57 -04:00
Zach Mueller 929e797a08 Bring back old behavior 2024-04-10 12:01:14 -04:00
Zach Mueller e5c9ed42a1 Clean 2024-04-10 11:16:35 -04:00
Zach Mueller f110bcbd45 Bring back 1:1 2024-04-10 11:16:22 -04:00
Zach Mueller 4a584ebe11 Allow smart context to use accelerator 2024-04-10 10:46:29 -04:00
1 changed file with 7 additions and 1 deletion

View File

@ -69,7 +69,10 @@ from .models.auto.modeling_auto import (
MODEL_MAPPING_NAMES, MODEL_MAPPING_NAMES,
) )
from .optimization import Adafactor, get_scheduler from .optimization import Adafactor, get_scheduler
from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13 from .pytorch_utils import (
ALL_LAYERNORM_LAYERS,
is_torch_greater_or_equal_than_1_13,
)
from .tokenization_utils_base import PreTrainedTokenizerBase from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import ( from .trainer_callback import (
CallbackHandler, CallbackHandler,
@ -205,6 +208,7 @@ if is_accelerate_available():
from accelerate import Accelerator, skip_first_batches from accelerate import Accelerator, skip_first_batches
from accelerate import __version__ as accelerate_version from accelerate import __version__ as accelerate_version
from accelerate.utils import ( from accelerate.utils import (
AutocastKwargs,
DistributedDataParallelKwargs, DistributedDataParallelKwargs,
DistributedType, DistributedType,
GradientAccumulationPlugin, GradientAccumulationPlugin,
@ -3094,6 +3098,8 @@ class Trainer:
""" """
if self.use_cpu_amp: if self.use_cpu_amp:
ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype) ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
elif self.accelerator is not None:
ctx_manager = self.accelerator.autocast(autocast_handler=AutocastKwargs(cache_enabled=cache_enabled))
else: else:
ctx_manager = contextlib.nullcontext() ctx_manager = contextlib.nullcontext()