Add inverse sqrt learning rate scheduler (#21495)

* added inverse sqrt lr scheduler

* Updated get_scheduler in src/transformers/optimization.py

* Updated src/transformers/__init__.py

* Added inverse sqrt lr scheduler test

* Updated docs/source/en/main_classes/optimizer_schedules.mdx

* Ran style and quality scripts

* Fix get_inverse_sqrt_schedule docstring

* Comment implementation URL
This commit is contained in:
Adrian Sager La Ganga 2023-02-07 21:00:50 +01:00 committed by GitHub
parent b9af152efb
commit a3034c7004
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 54 additions and 0 deletions

View File

@ -60,6 +60,8 @@ The `.optimization` module provides:
[[autodoc]] get_polynomial_decay_schedule_with_warmup
[[autodoc]] get_inverse_sqrt_schedule
### Warmup (TensorFlow)
[[autodoc]] WarmUp

View File

@ -2588,6 +2588,7 @@ else:
"get_constant_schedule_with_warmup",
"get_cosine_schedule_with_warmup",
"get_cosine_with_hard_restarts_schedule_with_warmup",
"get_inverse_sqrt_schedule",
"get_linear_schedule_with_warmup",
"get_polynomial_decay_schedule_with_warmup",
"get_scheduler",
@ -5659,6 +5660,7 @@ if TYPE_CHECKING:
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
get_cosine_with_hard_restarts_schedule_with_warmup,
get_inverse_sqrt_schedule,
get_linear_schedule_with_warmup,
get_polynomial_decay_schedule_with_warmup,
get_scheduler,

View File

@ -220,6 +220,42 @@ def get_polynomial_decay_schedule_with_warmup(
return LambdaLR(optimizer, lr_lambda, last_epoch)
def get_inverse_sqrt_schedule(
optimizer: Optimizer, num_warmup_steps: int, timescale: int = None, last_epoch: int = -1
):
"""
Create a schedule with an inverse square-root learning rate, from the initial lr set in the optimizer, after a
warmup period which increases lr linearly from 0 to the initial lr set in the optimizer.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
timescale (`int`, *optional*, defaults to `num_warmup_steps`):
Time scale.
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
# Note: this implementation is adapted from
# https://github.com/google-research/big_vision/blob/f071ce68852d56099437004fd70057597a95f6ef/big_vision/utils.py#L930
if timescale is None:
timescale = num_warmup_steps
def lr_lambda(current_step: int):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
shift = timescale - num_warmup_steps
decay = 1.0 / math.sqrt((current_step + shift) / timescale)
return decay
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
TYPE_TO_SCHEDULER_FUNCTION = {
SchedulerType.LINEAR: get_linear_schedule_with_warmup,
SchedulerType.COSINE: get_cosine_schedule_with_warmup,
@ -227,6 +263,7 @@ TYPE_TO_SCHEDULER_FUNCTION = {
SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
SchedulerType.CONSTANT: get_constant_schedule,
SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
}
@ -263,6 +300,9 @@ def get_scheduler(
if name == SchedulerType.CONSTANT_WITH_WARMUP:
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
if name == SchedulerType.INVERSE_SQRT:
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
# All other schedulers require `num_training_steps`
if num_training_steps is None:
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

View File

@ -363,6 +363,7 @@ class SchedulerType(ExplicitEnum):
POLYNOMIAL = "polynomial"
CONSTANT = "constant"
CONSTANT_WITH_WARMUP = "constant_with_warmup"
INVERSE_SQRT = "inverse_sqrt"
class TrainerMemoryTracker:

View File

@ -7019,6 +7019,10 @@ def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs):
requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch"])
def get_inverse_sqrt_schedule(*args, **kwargs):
requires_backends(get_inverse_sqrt_schedule, ["torch"])
def get_linear_schedule_with_warmup(*args, **kwargs):
requires_backends(get_linear_schedule_with_warmup, ["torch"])

View File

@ -33,6 +33,7 @@ if is_torch_available():
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
get_cosine_with_hard_restarts_schedule_with_warmup,
get_inverse_sqrt_schedule,
get_linear_schedule_with_warmup,
get_polynomial_decay_schedule_with_warmup,
)
@ -145,6 +146,10 @@ class ScheduleInitTest(unittest.TestCase):
{**common_kwargs, "power": 2.0, "lr_end": 1e-7},
[0.0, 5.0, 10.0, 7.656, 5.625, 3.906, 2.5, 1.406, 0.625, 0.156],
),
get_inverse_sqrt_schedule: (
{"num_warmup_steps": 2},
[0.0, 5.0, 10.0, 8.165, 7.071, 6.325, 5.774, 5.345, 5.0, 4.714],
),
}
for scheduler_func, data in scheds.items():