Added passing parameters to "reduce_lr_on_plateau" scheduler (#27860)

Author: Charbel Abi Daher, 2023-12-08 14:06:10 +01:00 (committed by GitHub)
parent 56be5e80e6
commit fe8d1302c7
1 changed file with 12 additions and 6 deletions


@@ -53,19 +53,22 @@ def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
 
     return LambdaLR(optimizer, _get_constant_lambda, last_epoch=last_epoch)
 
 
-def get_reduce_on_plateau_schedule(optimizer: Optimizer):
+def get_reduce_on_plateau_schedule(optimizer: Optimizer, **kwargs):
     """
     Create a schedule with a constant learning rate that decreases when a metric has stopped improving.
 
     Args:
         optimizer ([`~torch.optim.Optimizer`]):
             The optimizer for which to schedule the learning rate.
+        kwargs (`dict`, *optional*):
+            Extra parameters to be passed to the scheduler. See `torch.optim.lr_scheduler.ReduceLROnPlateau`
+            for possible parameters.
 
     Return:
         `torch.optim.lr_scheduler.ReduceLROnPlateau` with the appropriate schedule.
     """
-    return ReduceLROnPlateau(optimizer)
+    return ReduceLROnPlateau(optimizer, **kwargs)
 
 
 def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int):
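With this hunk applied, any `torch.optim.lr_scheduler.ReduceLROnPlateau` option (e.g. `mode`, `factor`, `patience`) can be forwarded through the helper. A minimal sketch of the new call path; the toy model, optimizer choice, and metric value are illustrative, not part of the commit:

    import torch
    from transformers.optimization import get_reduce_on_plateau_schedule

    model = torch.nn.Linear(10, 2)  # toy model for illustration
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # Extra kwargs are forwarded verbatim to torch.optim.lr_scheduler.ReduceLROnPlateau
    scheduler = get_reduce_on_plateau_schedule(optimizer, mode="min", factor=0.5, patience=2)

    # ReduceLROnPlateau steps on a monitored metric, not on a step count
    scheduler.step(0.42)  # placeholder validation loss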
@@ -359,9 +362,15 @@ def get_scheduler(
     """
     name = SchedulerType(name)
     schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
-    if name == SchedulerType.CONSTANT or name == SchedulerType.REDUCE_ON_PLATEAU:
+    if name == SchedulerType.CONSTANT:
         return schedule_func(optimizer)
 
+    if scheduler_specific_kwargs is None:
+        scheduler_specific_kwargs = {}
+
+    if name == SchedulerType.REDUCE_ON_PLATEAU:
+        return schedule_func(optimizer, **scheduler_specific_kwargs)
+
     # All other schedulers require `num_warmup_steps`
     if num_warmup_steps is None:
         raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
@@ -376,9 +385,6 @@ def get_scheduler(
     if num_training_steps is None:
         raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
 
-    if scheduler_specific_kwargs is None:
-        scheduler_specific_kwargs = {}
-
     return schedule_func(
         optimizer,
         num_warmup_steps=num_warmup_steps,
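End to end, the same options now reach `ReduceLROnPlateau` through `get_scheduler` via its pre-existing `scheduler_specific_kwargs` argument (whose default-to-`{}` handling this commit hoists above the early returns). A minimal sketch; the model, optimizer, and kwarg values are again illustrative:

    import torch
    from transformers import get_scheduler

    model = torch.nn.Linear(10, 2)  # toy model for illustration
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # scheduler_specific_kwargs is now forwarded to ReduceLROnPlateau
    scheduler = get_scheduler(
        name="reduce_lr_on_plateau",
        optimizer=optimizer,
        scheduler_specific_kwargs={"factor": 0.5, "patience": 2},
    )
    scheduler.step(0.42)  # placeholder validation loss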