fix FSDP ShardedGradScaler (#18358)

Rename the FSDP ShardedGradScaler import to FSDPShardedGradScaler so it no longer shadows the FairScale ShardedGradScaler used elsewhere in the Trainer.
Sourab Mangrulkar 2022-07-30 10:07:56 +05:30 committed by GitHub
parent 51227e26ab
commit b2e4b091f0
1 changed file with 6 additions and 2 deletions


@@ -565,9 +565,11 @@ class Trainer:
                 self.scaler = ShardedGradScaler()
             elif self.fsdp is not None:
                 if self.amp_dtype == torch.float16:
-                    from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+                    from torch.distributed.fsdp.sharded_grad_scaler import (
+                        ShardedGradScaler as FSDPShardedGradScaler,
+                    )
 
-                    self.scaler = ShardedGradScaler()
+                    self.scaler = FSDPShardedGradScaler()
                 else:
                     self.do_grad_scaling = False
                     self.use_cuda_amp = False
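
Note on this hunk: an import inside a function binds the imported name as a local variable for the entire function body, so importing the FSDP scaler under the name ShardedGradScaler would shadow the module-level FairScale ShardedGradScaler that the sharded_ddp branch above relies on. The FSDPShardedGradScaler alias sidesteps that. A minimal, self-contained sketch of the scoping pitfall (names here are illustrative stand-ins, not Trainer code):

class ShardedGradScaler:  # stand-in for the module-level FairScale import
    pass


def build_scaler(use_fsdp: bool):
    if use_fsdp:
        # Binding the name "ShardedGradScaler" here makes it a local variable
        # throughout build_scaler, not just inside this branch.
        from decimal import Decimal as ShardedGradScaler  # stand-in for the FSDP class

        return ShardedGradScaler()
    # Without a distinct alias, this branch no longer sees the module-level
    # class and raises UnboundLocalError instead.
    return ShardedGradScaler()


try:
    build_scaler(False)
except UnboundLocalError as err:
    print(err)  # exact message varies by Python version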
@@ -1366,6 +1368,8 @@ class Trainer:
                 transformer_cls_to_wrap = get_module_class_from_name(
                     model, self.args.fsdp_transformer_layer_cls_to_wrap
                 )
+                if transformer_cls_to_wrap is None:
+                    raise Exception("Could not find the transformer layer class to wrap in the model.")
                 auto_wrap_policy = functools.partial(
                     transformer_auto_wrap_policy,
                     # Transformer layer class to wrap
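
Note on this hunk: the new guard fails fast when the class name supplied via fsdp_transformer_layer_cls_to_wrap does not match any module class in the model, rather than passing None through to the auto-wrap policy and failing later with a less obvious error. A rough sketch of the kind of lookup being guarded (the helper below is illustrative, not the transformers implementation of get_module_class_from_name):

import torch.nn as nn


def find_module_class_by_name(model: nn.Module, name: str):
    # Return the class of the first submodule whose class name matches, else None.
    for module in model.modules():
        if module.__class__.__name__ == name:
            return module.__class__
    return None


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())

cls = find_module_class_by_name(model, "ReLU")
print(cls)  # <class 'torch.nn.modules.activation.ReLU'>

cls = find_module_class_by_name(model, "BertLayer")
if cls is None:
    # Mirrors the added check: raise up front instead of configuring the
    # auto-wrap policy with None.
    raise Exception("Could not find the transformer layer class to wrap in the model.")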