mmengine/docs/zh_cn/tutorials/optim_wrapper.md

503 lines
22 KiB
Markdown
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 优化器封装OptimWrapper
在[执行器教程](./runner.md)和[模型教程](./model.md)中我们或多或少地提到了优化器封装OptimWrapper的概念但是却没有介绍为什么我们需要优化器封装相比于 Pytorch 原生的优化器,优化器封装又有怎样的优势,这些问题会在本教程中得到一一解答。我们将通过对比的方式帮助大家理解,优化器封装的优势,以及如何使用它。
优化器封装顾名思义,是 Pytorch 原生优化器Optimizer高级抽象它在增加了更多功能的同时提供了一套统一的接口。优化器封装支持不同的训练策略包括混合精度训练、梯度累加和梯度截断。我们可以根据需求选择合适的训练策略。优化器封装还定义了一套标准的参数更新流程用户可以基于这一套流程实现同一套代码不同训练策略的切换。
## 优化器封装 vs 优化器
这里我们分别基于 Pytorch 内置的优化器和 MMEngine 的优化器封装进行单精度训练、混合精度训练和梯度累加,对比二者实现上的区别。
### 训练模型
**1.1 基于 Pytorch 的 SGD 优化器实现单精度训练**
```python
import torch
from torch.optim import SGD
import torch.nn as nn
import torch.nn.functional as F
inputs = [torch.zeros(10, 1, 1)] * 10
targets = [torch.ones(10, 1, 1)] * 10
model = nn.Linear(1, 1)
optimizer = SGD(model.parameters(), lr=0.01)
optimizer.zero_grad()
for input, target in zip(inputs, targets):
output = model(input)
loss = F.l1_loss(output, target)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```
**1.2 使用 MMEngine 的优化器封装实现单精度训练**
```python
from mmengine.optim import OptimWrapper
optim_wrapper = OptimWrapper(optimizer=optimizer)
for input, target in zip(inputs, targets):
output = model(input)
loss = F.l1_loss(output, target)
optim_wrapper.update_params(loss)
```
![image](https://user-images.githubusercontent.com/57566630/185605436-17f08083-b219-4b38-b714-eb891f7a8e56.png)
优化器封装的 `update_params` 实现了标准的梯度计算、参数更新和梯度清零流程,可以直接用来更新模型参数。
**2.1 基于 Pytorch 的 SGD 优化器实现混合精度训练**
```python
from torch.cuda.amp import autocast
model = model.cuda()
inputs = [torch.zeros(10, 1, 1, 1)] * 10
targets = [torch.ones(10, 1, 1, 1)] * 10
for input, target in zip(inputs, targets):
with autocast():
output = model(input.cuda())
loss = F.l1_loss(output, target.cuda())
loss.backward()
optimizer.step()
optimizer.zero_grad()
```
**2.2 基于 MMEngine 的 优化器封装实现混合精度训练**
```python
from mmengine.optim import AmpOptimWrapper
optim_wrapper = AmpOptimWrapper(optimizer=optimizer)
for input, target in zip(inputs, targets):
with optim_wrapper.optim_context(model):
output = model(input.cuda())
loss = F.l1_loss(output, target.cuda())
optim_wrapper.update_params(loss)
```
![image](https://user-images.githubusercontent.com/57566630/185606060-2fdebd90-c17a-4a8c-aaf1-540d47975c59.png)
开启混合精度训练需要使用 `AmpOptimWrapper`,他的 optim_context 接口类似 `autocast`,会开启混合精度训练的上下文。除此之外他还能加速分布式训练时的梯度累加,这个我们会在下一个示例中介绍。
**3.1 基于 Pytorch 的 SGD 优化器实现混合精度训练和梯度累加**
```python
for idx, (input, target) in enumerate(zip(inputs, targets)):
with autocast():
output = model(input.cuda())
loss = F.l1_loss(output, target.cuda())
loss.backward()
if idx % 2 == 0:
optimizer.step()
optimizer.zero_grad()
```
**3.2 基于 MMEngine 的优化器封装实现混合精度训练和梯度累加**
```python
optim_wrapper = AmpOptimWrapper(optimizer=optimizer, accumulative_counts=2)
for input, target in zip(inputs, targets):
with optim_wrapper.optim_context(model):
output = model(input.cuda())
loss = F.l1_loss(output, target.cuda())
optim_wrapper.update_params(loss)
```
![image](https://user-images.githubusercontent.com/57566630/185608932-91a082d4-1bf4-4329-b283-98fbbc20b5f7.png)
我们只需要配置 `accumulative_counts` 参数,并调用 `update_params` 接口就能实现梯度累加的功能。除此之外,分布式训练情况下,如果我们配置梯度累加的同时开启了 `optim_wrapper` 上下文,可以避免梯度累加阶段不必要的梯度同步。
优化器封装同样提供了更细粒度的接口,方便用户实现一些自定义的参数更新逻辑:
- `backward`:传入损失,用于计算参数梯度。
- `step`:同 `optimizer.step`,用于更新参数。
- `zero_grad`:同 `optimizer.zero_grad`,用于参数的梯度。
我们可以使用上述接口实现和 Pytorch 优化器相同的参数更新逻辑:
```python
for idx, (input, target) in enumerate(zip(inputs, targets)):
optimizer.zero_grad()
with optim_wrapper.optim_context(model):
output = model(input.cuda())
loss = F.l1_loss(output, target.cuda())
optim_wrapper.backward(loss)
if idx % 2 == 0:
optim_wrapper.step()
optim_wrapper.zero_grad()
```
我们同样可以为优化器封装配置梯度裁减策略:
```python
# 基于 torch.nn.utils.clip_grad_norm_ 对梯度进行裁减
optim_wrapper = AmpOptimWrapper(
optimizer=optimizer, clip_grad=dict(max_norm=1))
# 基于 torch.nn.utils.clip_grad_value_ 对梯度进行裁减
optim_wrapper = AmpOptimWrapper(
optimizer=optimizer, clip_grad=dict(clip_value=0.2))
```
### 获取学习率/动量
优化器封装提供了 `get_lr``get_momentum` 接口用于获取优化器的一个参数组的学习率:
```python
import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper
model = nn.Linear(1, 1)
optimizer = SGD(model.parameters(), lr=0.01)
optim_wrapper = OptimWrapper(optimizer)
print(optimizer.param_groups[0]['lr']) # 0.01
print(optimizer.param_groups[0]['momentum']) # 0
print(optim_wrapper.get_lr()) # {'lr': [0.01]}
print(optim_wrapper.get_momentum()) # {'momentum': [0]}
```
```
0.01
0
{'lr': [0.01]}
{'momentum': [0]}
```
### 导出/加载状态字典
优化器封装和优化器一样,提供了 `state_dict``load_state_dict` 接口,用于导出/加载优化器状态,对于 `AmpOptimWrapper`,优化器封装还会额外导出混合精度训练相关的参数:
```python
import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper, AmpOptimWrapper
model = nn.Linear(1, 1)
optimizer = SGD(model.parameters(), lr=0.01)
optim_wrapper = OptimWrapper(optimizer=optimizer)
amp_optim_wrapper = AmpOptimWrapper(optimizer=optimizer)
# 导出状态字典
optim_state_dict = optim_wrapper.state_dict()
amp_optim_state_dict = amp_optim_wrapper.state_dict()
print(optim_state_dict)
print(amp_optim_state_dict)
optim_wrapper_new = OptimWrapper(optimizer=optimizer)
amp_optim_wrapper_new = AmpOptimWrapper(optimizer=optimizer)
# 加载状态字典
amp_optim_wrapper_new.load_state_dict(amp_optim_state_dict)
optim_wrapper_new.load_state_dict(optim_state_dict)
```
```
{'state': {}, 'param_groups': [{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'params': [0, 1]}]}
{'state': {}, 'param_groups': [{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'params': [0, 1]}], 'loss_scaler': {'scale': 65536.0, 'growth_factor': 2.0, 'backoff_factor': 0.5, 'growth_interval': 2000, '_growth_tracker': 0}}
```
### 使用多个优化器
考虑到生成对抗网络之类的算法通常需要使用多个优化器来训练生成器和判别器,因此优化器封装提供了优化器封装的容器类:`OptimWrapperDict` 来管理多个优化器封装。`OptimWrapperDict` 以字典的形式存储优化器封装,并允许用户像字典一样访问、遍历其中的元素,即优化器封装实例。
与普通的优化器封装不同,`OptimWrapperDict` 没有实现 `update_params``optim_context`, `backward`、`step` 等方法,无法被直接用于训练模型。我们建议直接访问 `OptimWrapperDict` 管理的优化器实例,来实现参数更新逻辑。
你或许会好奇,既然 `OptimWrapperDict` 没有训练的功能,那为什么不直接使用 `dict` 来管理多个优化器?事实上,`OptimWrapperDict` 的核心功能是支持批量导出/加载所有优化器封装的状态字典;支持获取多个优化器封装的学习率、动量。如果没有 `OptimWrapperDict``MMEngine` 就需要在很多位置对优化器封装的类型做 `if else` 判断,以获取所有优化器封装的状态。
```python
from torch.optim import SGD
import torch.nn as nn
from mmengine.optim import OptimWrapper, OptimWrapperDict
gen = nn.Linear(1, 1)
disc = nn.Linear(1, 1)
optimizer_gen = SGD(gen.parameters(), lr=0.01)
optimizer_disc = SGD(disc.parameters(), lr=0.01)
optim_wapper_gen = OptimWrapper(optimizer=optimizer_gen)
optim_wapper_disc = OptimWrapper(optimizer=optimizer_disc)
optim_dict = OptimWrapperDict(gen=optim_wapper_gen, disc=optim_wapper_disc)
print(optim_dict.get_lr()) # {'gen.lr': [0.01], 'disc.lr': [0.01]}
print(optim_dict.get_momentum()) # {'gen.momentum': [0], 'disc.momentum': [0]}
```
```
{'gen.lr': [0.01], 'disc.lr': [0.01]}
{'gen.momentum': [0], 'disc.momentum': [0]}
```
如上例所示,`OptimWrapperDict` 可以非常方便的导出所有优化器封装的学习率和动量,同样的,优化器封装也能够导出/加载所有优化器封装的状态字典。
### 在[执行器](./runner.md)中配置优化器封装
优化器封装需要接受 `optimizer` 参数,因此我们首先需要为优化器封装配置 `optimizer`。MMEngine 会自动将 PyTorch 中的所有优化器都添加进 `OPTIMIZERS` 注册表中,用户可以用字典的形式来指定优化器,所有支持的优化器见 [PyTorch 优化器列表](https://pytorch.org/docs/stable/optim.html#algorithms)。
以配置一个 SGD 优化器封装为例:
```python
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)
```
这样我们就配置好了一个优化器类型为 SGD 的优化器封装,学习率、动量等参数如配置所示。考虑到 `OptimWrapper` 为标准的单精度训练,因此我们也可以不配置 `type` 字段:
```python
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optim_wrapper = dict(optimizer=optimizer)
```
要想开启混合精度训练和梯度累加,需要将 `type` 切换成 `AmpOptimWrapper`,并指定 `accumulative_counts` 参数:
```python
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optim_wrapper = dict(type='AmpOptimWrapper', optimizer=optimizer, accumulative_counts=2)
```
```{note}
如果你是第一次阅读 MMEngine 的教程文档,并且尚未了解[配置类](../advanced_tutorials/config.md)、[注册器](../advanced_tutorials/registry.md) 等概念,建议可以先跳过以下进阶教程,先去阅读其他文档。当然了,如果你已经具备了这些储备知识,我们强烈建议阅读进阶教程,在进阶教程中,我们将学会:
1. 如何在配置文件中定制化地在优化器中配置模型参数的学习率、衰减系数等。
2. 如何自定义一个优化器构造策略,实现真正意义上的“优化器配置自由”。
除了配置类和注册器等前置知识,我们建议在开始进阶教程之前,先深入了解 Pytorch 原生优化器构造时的 params 参数。
```
## 进阶配置
PyTorch 的优化器支持对模型中的不同参数设置不同的超参数例如对一个分类模型的骨干backbone和分类头head设置不同的学习率
```python
from torch.optim import SGD
import torch.nn as nn
model = nn.ModuleDict(dict(backbone=nn.Linear(1, 1), head=nn.Linear(1, 1)))
optimizer = SGD([{'params': model.backbone.parameters()},
{'params': model.head.parameters(), 'lr': 1e-3}],
lr=0.01,
momentum=0.9)
```
上面的例子中,模型的骨干部分使用了 0.01 学习率,而模型的头部则使用了 1e-3 学习率。用户可以将模型的不同部分参数和对应的超参组成一个字典的列表传给优化器,来实现对模型优化的细粒度调整。
在 MMEngine 中我们通过优化器封装构造器optimizer wrapper constructor让用户能够直接通过设置优化器封装配置文件中的 `paramwise_cfg` 字段而非修改代码来实现对模型的不同部分设置不同的超参。
### 为不同类型的参数设置不同的超参系数
MMEngine 提供的默认优化器封装构造器支持对模型中不同类型的参数设置不同的超参系数。例如,我们可以在 `paramwise_cfg` 中设置 `norm_decay_mult=0`从而将正则化层normalization layer的权重weight和偏置bias的权值衰减系数weight decay设置为 0来实现 [Bag of Tricks](https://arxiv.org/abs/1812.01187) 论文中提到的不对正则化层进行权值衰减的技巧。
具体示例如下,我们将 `ToyModel` 中所有正则化层(`head.bn`)的权重衰减系数设置为 0
```python
from mmengine.optim import build_optim_wrapper
from collections import OrderedDict
class ToyModel(nn.Module):
def __init__(self):
super().__init__()
self.backbone = nn.ModuleDict(
dict(layer0=nn.Linear(1, 1), layer1=nn.Linear(1, 1)))
self.head = nn.Sequential(
OrderedDict(
linear=nn.Linear(1, 1),
bn=nn.BatchNorm1d(1)))
optim_wrapper = dict(
optimizer=dict(type='SGD', lr=0.01, weight_decay=0.0001),
paramwise_cfg=dict(norm_decay_mult=0))
optimizer = build_optim_wrapper(ToyModel(), optim_wrapper)
```
```
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer0.bias:lr=0.01
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer0.bias:weight_decay=0.0001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer1.bias:lr=0.01
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer1.bias:weight_decay=0.0001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.linear.bias:lr=0.01
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.linear.bias:weight_decay=0.0001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.bn.weight:weight_decay=0.0
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.bn.bias:weight_decay=0.0
```
除了可以对正则化层的权重衰减进行配置外MMEngine 的默认优化器封装构造器的 `paramwise_cfg` 还支持对更多不同类型的参数设置超参系数,支持的配置如下:
`lr_mult`:所有参数的学习率系数
`decay_mult`:所有参数的衰减系数
`bias_lr_mult`:偏置的学习率系数(不包括正则化层的偏置以及可变形卷积的 offset
`bias_decay_mult`:偏置的权值衰减系数(不包括正则化层的偏置以及可变形卷积的 offset
`norm_decay_mult`:正则化层权重和偏置的权值衰减系数
`flat_decay_mult`:一维参数的权值衰减系数
`dwconv_decay_mult`Depth-wise 卷积的权值衰减系数
`bypass_duplicate`:是否跳过重复的参数,默认为 `False`
`dcn_offset_lr_mult`可变形卷积Deformable Convolution的学习率系数
### 为模型不同部分的参数设置不同的超参系数
此外,与上文 PyTorch 的示例一样,在 MMEngine 中我们也同样可以对模型中的任意模块设置不同的超参,只需要在 `paramwise_cfg` 中设置 `custom_keys` 即可。
例如我们想将 `backbone.layer0` 所有参数的学习率设置为 0衰减系数设置为 0`backbone` 其余子模块的学习率设置为 0.01`head` 所有参数的学习率设置为 0.001,可以这样配置:
```python
optim_wrapper = dict(
optimizer=dict(type='SGD', lr=0.01, weight_decay=0.0001),
paramwise_cfg=dict(
custom_keys={
'backbone.layer0': dict(lr_mult=0, decay_mult=0),
'backbone': dict(lr_mult=1),
'head': dict(lr_mult=0.1)
}))
optimizer = build_optim_wrapper(ToyModel(), optim_wrapper)
```
```
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer0.weight:lr=0.0
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer0.weight:weight_decay=0.0
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer0.weight:lr_mult=0
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer0.weight:decay_mult=0
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer0.bias:lr=0.0
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer0.bias:weight_decay=0.0
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer0.bias:lr_mult=0
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer0.bias:decay_mult=0
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer1.weight:lr=0.01
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer1.weight:weight_decay=0.0001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer1.weight:lr_mult=1
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer1.bias:lr=0.01
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer1.bias:weight_decay=0.0001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- backbone.layer1.bias:lr_mult=1
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.linear.weight:lr=0.001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.linear.weight:weight_decay=0.0001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.linear.weight:lr_mult=0.1
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.linear.bias:lr=0.001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.linear.bias:weight_decay=0.0001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.linear.bias:lr_mult=0.1
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.bn.weight:lr=0.001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.bn.weight:weight_decay=0.0001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.bn.weight:lr_mult=0.1
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.bn.bias:lr=0.001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.bn.bias:weight_decay=0.0001
08/23 22:02:43 - mmengine - INFO - paramwise_options -- head.bn.bias:lr_mult=0.1
```
上例中,模型的状态字典的 `key` 如下:
```python
for name, val in ToyModel().named_parameters():
print(name)
```
```
backbone.layer0.weight
backbone.layer0.bias
backbone.layer1.weight
backbone.layer1.bias
head.linear.weight
head.linear.bias
head.bn.weight
head.bn.bias
```
custom_keys 中每一个字段的含义如下:
1. `'backbone': dict(lr_mult=1)`:将名字前缀为 `backbone` 的参数的学习率系数设置为 1
2. `'backbone.layer0': dict(lr_mult=0, decay_mult=0)`:将名字前缀为 `backbone.layer0` 的参数学习率系数设置为 0衰减系数设置为 0该配置优先级比第一条高
3. `'head': dict(lr_mult=0.1)`:将名字前缀为 `head` 的参数的学习率系数设置为 0.1
### 自定义优化器构造策略
与 MMEngine 中的其他模块一样,优化器封装构造器也同样由[注册表](../advanced_tutorials/registry.md)管理。我们可以通过实现自定义的优化器封装构造器来实现自定义的超参设置策略。
例如,我们想实现一个叫做 `LayerDecayOptimWrapperConstructor` 的优化器封装构造器,能够对模型不同深度的层自动设置递减的学习率:
```python
from mmengine.optim import DefaultOptimWrapperConstructor
from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS
from mmengine.logging import print_log
@OPTIM_WRAPPER_CONSTRUCTORS.register_module(force=True)
class LayerDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
def __init__(self, optim_wrapper_cfg, paramwise_cfg=None):
super().__init__(optim_wrapper_cfg, paramwise_cfg=None)
self.decay_factor = paramwise_cfg.get('decay_factor', 0.5)
super().__init__(optim_wrapper_cfg, paramwise_cfg)
def add_params(self, params, module, prefix='' ,lr=None):
if lr is None:
lr = self.base_lr
for name, param in module.named_parameters(recurse=False):
param_group = dict()
param_group['params'] = [param]
param_group['lr'] = lr
params.append(param_group)
full_name = f'{prefix}.{name}' if prefix else name
print_log(f'{full_name} : lr={lr}', logger='current')
for name, module in module.named_children():
chiled_prefix = f'{prefix}.{name}' if prefix else name
self.add_params(
params, module, chiled_prefix, lr=lr * self.decay_factor)
class ToyModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.layer = nn.ModuleDict(dict(linear=nn.Linear(1, 1)))
self.linear = nn.Linear(1, 1)
model = ToyModel()
optim_wrapper = dict(
optimizer=dict(type='SGD', lr=0.01, weight_decay=0.0001),
paramwise_cfg=dict(decay_factor=0.5),
constructor='LayerDecayOptimWrapperConstructor')
optimizer = build_optim_wrapper(model, optim_wrapper)
```
```
08/23 22:20:26 - mmengine - INFO - layer.linear.weight : lr=0.0025
08/23 22:20:26 - mmengine - INFO - layer.linear.bias : lr=0.0025
08/23 22:20:26 - mmengine - INFO - linear.weight : lr=0.005
08/23 22:20:26 - mmengine - INFO - linear.bias : lr=0.005
```
`add_params` 被第一次调用时,`params` 参数为空列表(`list``module` 为模型(`model`)。详细的重载规则参考[优化器封装构造器文档](mmengine.optim.DefaultOptimWrapperConstructor)。
类似地,如果想构造多个优化器,也需要实现自定义的构造器:
```python
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class MultipleOptimiWrapperConstructor:
...
```
### 在训练过程中调整超参
优化器中的超参数在构造时只能设置为一个定值,仅仅使用优化器封装,并不能在训练过程中调整学习率等参数。在 MMEngine 中我们实现了参数调度器Parameter Scheduler以便能够在训练过程中调整参数。关于参数调度器的用法请见[优化器参数调整策略](./param_scheduler.md)