[pl] restore lr logging behavior for glue, ner examples (#6314)
This commit is contained in:
parent
be1520d3a3
commit
0203d6517f
|
@ -245,7 +245,8 @@ class BaseTransformer(pl.LightningModule):
|
|||
|
||||
class LoggingCallback(pl.Callback):
|
||||
def on_batch_end(self, trainer, pl_module):
|
||||
lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
|
||||
lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
|
||||
lrs = {f"lr_group_{i}": lr for i, lr in enumerate(lr_scheduler.get_lr())}
|
||||
pl_module.logger.log_metrics(lrs)
|
||||
|
||||
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
|
|
|
@ -44,8 +44,8 @@ class GLUETransformer(BaseTransformer):
|
|||
outputs = self(**inputs)
|
||||
loss = outputs[0]
|
||||
|
||||
# tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
|
||||
tensorboard_logs = {"loss": loss}
|
||||
lr_scheduler = self.trainer.lr_schedulers[0]["scheduler"]
|
||||
tensorboard_logs = {"loss": loss, "rate": lr_scheduler.get_last_lr()[-1]}
|
||||
return {"loss": loss, "log": tensorboard_logs}
|
||||
|
||||
def prepare_data(self):
|
||||
|
|
Loading…
Reference in New Issue