parent 51227e26ab
commit b2e4b091f0
@@ -565,9 +565,11 @@ class Trainer:
             self.scaler = ShardedGradScaler()
         elif self.fsdp is not None:
             if self.amp_dtype == torch.float16:
-                from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+                from torch.distributed.fsdp.sharded_grad_scaler import (
+                    ShardedGradScaler as FSDPShardedGradScaler,
+                )
 
-                self.scaler = ShardedGradScaler()
+                self.scaler = FSDPShardedGradScaler()
             else:
                 self.do_grad_scaling = False
                 self.use_cuda_amp = False
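Context for this hunk: in Python, an import written inside a function body binds the imported name as a local for the entire function, so re-importing ShardedGradScaler under its original name would leave the earlier branch's self.scaler = ShardedGradScaler() pointing at an unbound local and raise UnboundLocalError on that path. The FSDPShardedGradScaler alias is presumably what sidesteps that. A standalone sketch of the pitfall, where os.path.join and posixpath.join are purely illustrative stand-ins for the module-level and FSDP scalers (none of these names come from the Trainer itself):

from os.path import join  # module-level binding (stands in for the module-level ShardedGradScaler)


def broken(first_branch: bool):
    if first_branch:
        return join("a", "b")  # UnboundLocalError: `join` is compiled as local to the whole function
    from posixpath import join  # unaliased function-scope import makes `join` local everywhere
    return join("a", "b")


def fixed(first_branch: bool):
    if first_branch:
        return join("a", "b")  # fine: the module-level `join` is still visible here
    from posixpath import join as posix_join  # aliasing leaves the module-level name untouched
    return posix_join("a", "b")

broken(False) and fixed(True) both work; broken(True) raises UnboundLocalError, which mirrors why the aliased import is safer inside a method that already uses the other ShardedGradScaler.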
@@ -1366,6 +1368,8 @@ class Trainer:
                 transformer_cls_to_wrap = get_module_class_from_name(
                     model, self.args.fsdp_transformer_layer_cls_to_wrap
                 )
+                if transformer_cls_to_wrap is None:
+                    raise Exception("Could not find the transformer layer class to wrap in the model.")
                 auto_wrap_policy = functools.partial(
                     transformer_auto_wrap_policy,
                     # Transformer layer class to wrap
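The second hunk adds a fail-fast guard: the name-to-class lookup returns None when no submodule of the model matches fsdp_transformer_layer_cls_to_wrap, and letting None flow into the transformer auto-wrap policy would likely only surface later as a harder-to-read error. A rough, self-contained sketch of that kind of lookup plus the new check; find_layer_class is an illustrative stand-in, not the library's get_module_class_from_name:

import torch.nn as nn


def find_layer_class(module: nn.Module, name: str):
    """Return the class of the first (sub)module whose class name matches, else None."""
    if module.__class__.__name__ == name:
        return module.__class__
    for child in module.children():
        found = find_layer_class(child, name)
        if found is not None:
            return found
    return None


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
layer_cls = find_layer_class(model, "TransformerLayer")  # this toy model has no such layer
if layer_cls is None:
    # Same fail-fast behaviour the hunk adds to the Trainer
    raise Exception("Could not find the transformer layer class to wrap in the model.")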