mirror of https://github.com/open-mmlab/mmpose
62 lines
1.9 KiB
Python
62 lines
1.9 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch.nn as nn
|
|
from mmengine import build_from_cfg
|
|
|
|
from mmpose.registry import Registry
|
|
|
|
# Module-level registries. The string handed to ``Registry`` is the registry's
# name; the ``build_*`` helpers in this file look components up in them.
TRANSFORMER = Registry('Transformer')
BACKBONES = Registry('BACKBONES')
POSITIONAL_ENCODING = Registry('position encoding')
LINEAR_LAYERS = Registry('linear layers')
|
|
|
|
|
|
def build_positional_encoding(cfg, default_args=None):
    """Build positional encoding.

    Args:
        cfg (dict): Config handed to ``build_from_cfg``; the component is
            looked up in the ``POSITIONAL_ENCODING`` registry.
        default_args (dict, optional): Forwarded unchanged to
            ``build_from_cfg``. Defaults to None.

    Returns:
        Whatever ``build_from_cfg`` returns for ``cfg``.
    """
    return build_from_cfg(cfg, POSITIONAL_ENCODING, default_args)
|
|
|
|
|
|
def build_backbone(cfg, default_args=None):
    """Build a backbone from a config.

    Args:
        cfg (dict): Config handed to ``build_from_cfg``; the component is
            looked up in the ``BACKBONES`` registry.
        default_args (dict, optional): Forwarded unchanged to
            ``build_from_cfg``. Defaults to None.

    Returns:
        Whatever ``build_from_cfg`` returns for ``cfg``.
    """
    backbone = build_from_cfg(cfg, BACKBONES, default_args=default_args)
    return backbone
|
|
|
|
|
|
def build_transformer(cfg, default_args=None):
    """Build a transformer from a config.

    Args:
        cfg (dict): Config handed to ``build_from_cfg``; the component is
            looked up in the ``TRANSFORMER`` registry.
        default_args (dict, optional): Forwarded unchanged to
            ``build_from_cfg``. Defaults to None.

    Returns:
        Whatever ``build_from_cfg`` returns for ``cfg``.
    """
    transformer = build_from_cfg(cfg, TRANSFORMER, default_args=default_args)
    return transformer
|
|
|
|
|
|
# Register ``nn.Linear`` under the key 'Linear'; ``build_linear_layer`` uses
# that key when it is called with ``cfg=None``.
LINEAR_LAYERS.register_module(name='Linear', module=nn.Linear)
|
|
|
|
|
|
def build_linear_layer(cfg, *args, **kwargs):
    """Create a linear layer from a config dict.

    Args:
        cfg (None or dict): The linear layer config. ``None`` selects the
            ``'Linear'`` entry of ``LINEAR_LAYERS``. A dict should contain:

            - type (str): Key of the layer in ``LINEAR_LAYERS``.
            - layer args: Remaining items, passed as keyword arguments to
              the layer's ``__init__``.
        args (argument list): Positional arguments passed to the
            ``__init__`` method of the selected linear layer.
        kwargs (keyword arguments): Keyword arguments passed to the
            ``__init__`` method of the selected linear layer.

    Returns:
        nn.Module: Created linear layer.

    Raises:
        TypeError: If ``cfg`` is neither None nor a dict.
        KeyError: If ``cfg`` has no ``"type"`` key, or the type is not
            registered in ``LINEAR_LAYERS``.
    """
    if cfg is None:
        layer_cfg = {'type': 'Linear'}
    elif not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    elif 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')
    else:
        # Work on a shallow copy so popping 'type' leaves the caller's
        # dict untouched.
        layer_cfg = cfg.copy()

    type_key = layer_cfg.pop('type')
    if type_key not in LINEAR_LAYERS:
        raise KeyError(f'Unrecognized linear type {type_key}')

    layer_cls = LINEAR_LAYERS.get(type_key)
    # Whatever is left in the config is forwarded as extra keyword arguments.
    return layer_cls(*args, **kwargs, **layer_cfg)
|