Added passing parameters to "reduce_lr_on_plateau" scheduler (#27860)
parent 56be5e80e6
commit fe8d1302c7
@@ -53,19 +53,22 @@ def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
     return LambdaLR(optimizer, _get_constant_lambda, last_epoch=last_epoch)
 
 
-def get_reduce_on_plateau_schedule(optimizer: Optimizer):
+def get_reduce_on_plateau_schedule(optimizer: Optimizer, **kwargs):
     """
     Create a schedule with a constant learning rate that decreases when a metric has stopped improving.
 
     Args:
         optimizer ([`~torch.optim.Optimizer`]):
             The optimizer for which to schedule the learning rate.
+        kwargs (`dict`, *optional*):
+            Extra parameters to be passed to the scheduler. See `torch.optim.lr_scheduler.ReduceLROnPlateau`
+            for possible parameters.
 
     Return:
         `torch.optim.lr_scheduler.ReduceLROnPlateau` with the appropriate schedule.
     """
 
-    return ReduceLROnPlateau(optimizer)
+    return ReduceLROnPlateau(optimizer, **kwargs)
 
 
 def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int):
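For context, a minimal sketch of what the new `**kwargs` pass-through enables. The import path and parameter values are illustrative (the function is defined in `src/transformers/optimization.py`); the keyword arguments shown are standard `torch.optim.lr_scheduler.ReduceLROnPlateau` options.

import torch
from torch.optim import AdamW
from transformers.optimization import get_reduce_on_plateau_schedule

model = torch.nn.Linear(10, 2)
optimizer = AdamW(model.parameters(), lr=1e-3)

# Keyword arguments are now forwarded verbatim to ReduceLROnPlateau.
scheduler = get_reduce_on_plateau_schedule(
    optimizer,
    mode="min",    # reduce the LR when the monitored metric stops decreasing
    factor=0.5,    # multiply the LR by 0.5 on each plateau
    patience=2,    # wait 2 calls to step() without improvement first
)

# ReduceLROnPlateau steps on a metric value, not on a step count.
val_loss = 1.0  # placeholder metric for this sketch
scheduler.step(val_loss)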
@@ -359,9 +362,15 @@ def get_scheduler(
     """
     name = SchedulerType(name)
     schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
-    if name == SchedulerType.CONSTANT or name == SchedulerType.REDUCE_ON_PLATEAU:
+    if name == SchedulerType.CONSTANT:
         return schedule_func(optimizer)
 
+    if scheduler_specific_kwargs is None:
+        scheduler_specific_kwargs = {}
+
+    if name == SchedulerType.REDUCE_ON_PLATEAU:
+        return schedule_func(optimizer, **scheduler_specific_kwargs)
+
     # All other schedulers require `num_warmup_steps`
     if num_warmup_steps is None:
         raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
@@ -376,9 +385,6 @@ def get_scheduler(
     if num_training_steps is None:
         raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
 
-    if scheduler_specific_kwargs is None:
-        scheduler_specific_kwargs = {}
-
     return schedule_func(
         optimizer,
         num_warmup_steps=num_warmup_steps,
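With the dispatch above, `scheduler_specific_kwargs` passed to `get_scheduler` now reaches `ReduceLROnPlateau` instead of being dropped. A minimal sketch of the end-to-end call, with illustrative values:

import torch
from torch.optim import AdamW
from transformers import get_scheduler

model = torch.nn.Linear(10, 2)
optimizer = AdamW(model.parameters(), lr=1e-3)

# "reduce_lr_on_plateau" is the string value of SchedulerType.REDUCE_ON_PLATEAU;
# it requires neither num_warmup_steps nor num_training_steps.
scheduler = get_scheduler(
    "reduce_lr_on_plateau",
    optimizer=optimizer,
    scheduler_specific_kwargs={"factor": 0.1, "patience": 3},
)

val_loss = 1.0  # placeholder metric for this sketch
scheduler.step(val_loss)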