parent 51227e26ab
commit b2e4b091f0
@@ -565,9 +565,11 @@ class Trainer:
             self.scaler = ShardedGradScaler()
         elif self.fsdp is not None:
             if self.amp_dtype == torch.float16:
-                from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+                from torch.distributed.fsdp.sharded_grad_scaler import (
+                    ShardedGradScaler as FSDPShardedGradScaler,
+                )
 
-                self.scaler = ShardedGradScaler()
+                self.scaler = FSDPShardedGradScaler()
             else:
                 self.do_grad_scaling = False
                 self.use_cuda_amp = False
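A note on why the new import is aliased (my reading of the hunk; the commit itself does not say): the surrounding method already uses a module-level `ShardedGradScaler` (fairscale's) in the branch above. In Python, a `from ... import Name` anywhere inside a function makes `Name` local to the entire function, so the unaliased FSDP import would have turned the earlier `self.scaler = ShardedGradScaler()` into an `UnboundLocalError`. A minimal, torch-free sketch of that mechanism, with purely illustrative names:

ShardedGradScaler = dict  # stand-in for a module-level (e.g. fairscale) import

def configure_scaler(use_fsdp):
    if not use_fsdp:
        # Intended to read the module-level name, but the import below makes
        # ShardedGradScaler local to the whole function, so this line raises
        # UnboundLocalError whenever this branch runs.
        return ShardedGradScaler()
    # Shadowing import: rebinding the same name makes it function-local.
    from types import SimpleNamespace as ShardedGradScaler
    return ShardedGradScaler()

try:
    configure_scaler(use_fsdp=False)
except UnboundLocalError as err:
    print(err)

Aliasing the FSDP scaler to `FSDPShardedGradScaler`, as the hunk does, leaves the module-level name untouched.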
@@ -1366,6 +1368,8 @@ class Trainer:
                 transformer_cls_to_wrap = get_module_class_from_name(
                     model, self.args.fsdp_transformer_layer_cls_to_wrap
                 )
+                if transformer_cls_to_wrap is None:
+                    raise Exception("Could not find the transformer layer class to wrap in the model.")
                 auto_wrap_policy = functools.partial(
                     transformer_auto_wrap_policy,
                     # Transformer layer class to wrap
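The new guard fails fast when `fsdp_transformer_layer_cls_to_wrap` does not match any class in the model, instead of letting `None` flow into the auto-wrap policy. A sketch of the contract the guard relies on, assuming the helper walks the module tree and returns the first class whose name matches, or `None`; this is an illustration under that assumption, not the library's actual implementation, and the function name here is hypothetical:

import torch.nn as nn

def find_module_class_by_name(module: nn.Module, name: str):
    # Return the class of the first (sub)module whose class name matches,
    # or None when nothing matches -- the case the new guard catches.
    if module.__class__.__name__ == name:
        return module.__class__
    for child in module.children():
        found = find_module_class_by_name(child, name)
        if found is not None:
            return found
    return None

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
print(find_module_class_by_name(model, "ReLU"))       # <class 'torch.nn.modules.activation.ReLU'>
print(find_module_class_by_name(model, "GPT2Block"))  # None, so Trainer now raises

The resolved class is then fed to `functools.partial(transformer_auto_wrap_policy, ...)`, presumably via torch's `transformer_layer_cls` parameter, so the guard ensures FSDP never receives an empty wrap target.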