Remove hardcoded prints in Trainer (#23432)
This commit is contained in:
parent
a574de302f
commit
0f2c738207
|
@ -1105,10 +1105,10 @@ class Trainer:
|
|||
for module in opt_model.modules():
|
||||
if isinstance(module, nn.Embedding):
|
||||
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
|
||||
print(f"skipped {module}: {skipped/2**20}M params")
|
||||
logger.info(f"skipped {module}: {skipped/2**20}M params")
|
||||
manager.register_module_override(module, "weight", {"optim_bits": 32})
|
||||
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||
print(f"skipped: {skipped/2**20}M params")
|
||||
logger.info(f"skipped: {skipped/2**20}M params")
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer(self.optimizer)
|
||||
|
|
Loading…
Reference in New Issue