Compare commits

...

5 Commits

Author SHA1 Message Date
Zach Mueller 91aed0be2d Rm 2024-04-10 12:01:57 -04:00
Zach Mueller 929e797a08 Bring back old behavior 2024-04-10 12:01:14 -04:00
Zach Mueller e5c9ed42a1 Clean 2024-04-10 11:16:35 -04:00
Zach Mueller f110bcbd45 Bring back 1:1 2024-04-10 11:16:22 -04:00
Zach Mueller 4a584ebe11 Allow smart context to use accelerator 2024-04-10 10:46:29 -04:00
1 changed file with 7 additions and 1 deletion


@@ -69,7 +69,10 @@ from .models.auto.modeling_auto import (
     MODEL_MAPPING_NAMES,
 )
 from .optimization import Adafactor, get_scheduler
-from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
+from .pytorch_utils import (
+    ALL_LAYERNORM_LAYERS,
+    is_torch_greater_or_equal_than_1_13,
+)
 from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer_callback import (
     CallbackHandler,
@@ -205,6 +208,7 @@ if is_accelerate_available():
     from accelerate import Accelerator, skip_first_batches
     from accelerate import __version__ as accelerate_version
     from accelerate.utils import (
+        AutocastKwargs,
         DistributedDataParallelKwargs,
         DistributedType,
         GradientAccumulationPlugin,
@@ -3094,6 +3098,8 @@ class Trainer:
         """
         if self.use_cpu_amp:
             ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
+        elif self.accelerator is not None:
+            ctx_manager = self.accelerator.autocast(autocast_handler=AutocastKwargs(cache_enabled=cache_enabled))
         else:
             ctx_manager = contextlib.nullcontext()
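
Taken together, the hunks change Trainer's smart autocast context so that, when CPU AMP is not active but an Accelerator instance exists, the Trainer hands autocasting off to Accelerate instead of falling through to a null context. Below is a minimal standalone sketch (not part of the diff) of the Accelerate API the new elif branch relies on; the toy Linear model, the bf16 setting, and the variable names are assumptions for illustration only.

import torch
from accelerate import Accelerator
from accelerate.utils import AutocastKwargs

# Assumption: bf16 mixed precision is supported on this machine; use "fp16" or "no" otherwise.
accelerator = Accelerator(mixed_precision="bf16")
model = accelerator.prepare(torch.nn.Linear(8, 2))

# AutocastKwargs carries the cache_enabled flag through to torch.autocast,
# mirroring what the new branch forwards from the Trainer's context manager.
handler = AutocastKwargs(cache_enabled=True)
with accelerator.autocast(autocast_handler=handler):
    out = model(torch.randn(4, 8, device=accelerator.device))

Wrapping the flag in AutocastKwargs, rather than passing it directly, follows Accelerate's kwargs-handler pattern, which is why the second hunk above imports AutocastKwargs alongside the other accelerate.utils handlers.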