# mmpose/tests/test_engine/test_schedulers/test_quadratic_warmup.py
#
# 109 lines
# 4.2 KiB
# Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
import torch.nn.functional as F
import torch.optim as optim
from mmengine.optim.scheduler import _ParamScheduler
from mmengine.testing import assert_allclose
from mmpose.engine.schedulers import (QuadraticWarmupLR,
QuadraticWarmupMomentum,
QuadraticWarmupParamScheduler)
class ToyModel(torch.nn.Module):
    """Tiny network that only exists to give the optimizer parameters.

    It chains two single-channel 1x1 convolutions with a ReLU between them.
    """

    def __init__(self):
        super().__init__()
        # Keep the creation order (conv1, then conv2): it determines the
        # order of the parameters handed to the optimizer.
        self.conv1 = torch.nn.Conv2d(
            in_channels=1, out_channels=1, kernel_size=1)
        self.conv2 = torch.nn.Conv2d(
            in_channels=1, out_channels=1, kernel_size=1)

    def forward(self, x):
        """Apply ``conv1``, a ReLU and then ``conv2`` to ``x``."""
        hidden = self.conv1(x)
        activated = F.relu(hidden)
        return self.conv2(activated)
class TestQuadraticWarmupScheduler(TestCase):
    """Compare the quadratic warmup schedulers with hand-computed values.

    Every test builds the expected per-step values of an optimizer
    hyper-parameter (``lr`` or ``momentum``), creates a scheduler on the
    optimizer from :meth:`setUp`, and checks the optimizer's value before
    each scheduler step.
    """

    def setUp(self):
        """Create a fresh model and SGD optimizer for every test method.

        unittest runs ``setUp() -> test method -> tearDown() ->
        doCleanups()`` for each test, so no state leaks between tests.
        """
        self.model = ToyModel()
        self.optimizer = optim.SGD(
            self.model.parameters(), lr=0.05, momentum=0.01, weight_decay=5e-4)

    @staticmethod
    def _expected_targets(base_value, warmup_steps, total_steps, group_scale):
        """Build the expected value sequences for a quadratic warmup.

        Args:
            base_value (float): Value configured on the optimizer, reached
                once the warmup is over.
            warmup_steps (int): Number of steps the warmup lasts.
            total_steps (int): Total number of steps that will be checked.
            group_scale (float): Multiplier used to derive the sequence of a
                second parameter group from the first one.

        Returns:
            list[list[float]]: One sequence per parameter group. During the
            warmup step ``i`` holds ``base_value * ((i + 1) / warmup_steps)
            ** 2``; afterwards it holds ``base_value``.
        """
        warmup_factor = [
            pow((i + 1) / float(warmup_steps), 2) for i in range(warmup_steps)
        ]
        single_targets = [x * base_value for x in warmup_factor]
        single_targets += [base_value] * (total_steps - warmup_steps)
        # NOTE: ``_test_scheduler_value`` zips the targets with
        # ``optimizer.param_groups``; the optimizer from ``setUp`` appears to
        # have a single group, in which case only the first sequence is
        # actually compared.
        return [single_targets, [x * group_scale for x in single_targets]]

    def _test_scheduler_value(self,
                              schedulers,
                              targets,
                              epochs=10,
                              param_name='lr'):
        """Step the scheduler(s) and check the optimizer value at each step.

        Args:
            schedulers: A single scheduler or a list of schedulers.
            targets (list[list[float]]): Expected values, one sequence per
                parameter group, indexed by step.
            epochs (int): Number of steps to run. Defaults to 10.
            param_name (str): Key to read from every parameter group.
                Defaults to ``'lr'``.
        """
        if isinstance(schedulers, _ParamScheduler):
            schedulers = [schedulers]
        for epoch in range(epochs):
            # The value is checked *before* stepping, so ``target[0]`` is the
            # value in effect right after the scheduler was constructed.
            for param_group, target in zip(self.optimizer.param_groups,
                                           targets):
                expected = target[epoch]
                actual = param_group[param_name]
                assert_allclose(
                    expected,
                    actual,
                    msg='{} is wrong in epoch {}: expected {}, got {}'.format(
                        param_name, epoch, expected, actual),
                    atol=1e-5,
                    rtol=0)
            for scheduler in schedulers:
                scheduler.step()

    def test_quadratic_warmup_scheduler(self):
        """Check the generic scheduler on ``lr`` with a 5-step warmup."""
        # The test expects a ValueError when no ``end`` is given.
        with self.assertRaises(ValueError):
            QuadraticWarmupParamScheduler(self.optimizer, param_name='lr')
        epochs = 10
        iters = 5
        targets = self._expected_targets(0.05, iters, epochs, epochs)
        scheduler = QuadraticWarmupParamScheduler(
            self.optimizer, param_name='lr', end=iters)
        self._test_scheduler_value(scheduler, targets, epochs)

    def test_quadratic_warmup_scheduler_convert_iterbased(self):
        """Check the scheduler built by ``build_iter_from_epoch``.

        ``end`` is given in epochs; the expected values assume the warmup
        lasts ``end * epoch_length`` iterations.
        """
        epochs = 10
        end = 5
        epoch_length = 11
        iters = end * epoch_length
        total_iters = epochs * epoch_length
        targets = self._expected_targets(0.05, iters, total_iters, epochs)
        scheduler = QuadraticWarmupParamScheduler.build_iter_from_epoch(
            self.optimizer,
            param_name='lr',
            end=end,
            epoch_length=epoch_length)
        self._test_scheduler_value(scheduler, targets, total_iters)

    def test_quadratic_warmup_lr(self):
        """Check ``QuadraticWarmupLR`` against the optimizer's ``lr``."""
        epochs = 10
        iters = 5
        targets = self._expected_targets(0.05, iters, epochs, epochs)
        scheduler = QuadraticWarmupLR(self.optimizer, end=iters)
        self._test_scheduler_value(scheduler, targets, epochs)

    def test_quadratic_warmup_momentum(self):
        """Check ``QuadraticWarmupMomentum`` against ``momentum``."""
        epochs = 10
        iters = 5
        targets = self._expected_targets(0.01, iters, epochs, epochs)
        scheduler = QuadraticWarmupMomentum(self.optimizer, end=iters)
        self._test_scheduler_value(
            scheduler, targets, epochs, param_name='momentum')