# tensorlayer3/densenet.py
# DenseNet-BC implemented with TensorLayer on the TensorFlow backend,
# with variants for CIFAR-10 and ImageNet.
import os

# The backend must be selected before tensorlayer is imported.
os.environ['TL_BACKEND'] = 'tensorflow'

import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import (
    AdaptiveMeanPool2d, BatchNorm, Concat, Conv2d, Dense, Flatten, MaxPool2d,
    MeanPool2d, Module, SequentialLayer
)

# Bottleneck layer: 1x1 conv -> BN/ReLU -> 3x3 conv -> BN/ReLU, whose output
# is concatenated with the input along the channel axis.
class _DenseLayer(Module):
    def __init__(self, in_channels, growth_rate, bn_size):
        super(_DenseLayer, self).__init__()
        W_init = tl.initializers.truncated_normal(stddev=5e-2)
        self.layer_list = []
        self.layer_list.append(Conv2d(bn_size * growth_rate, (1, 1), in_channels=in_channels, W_init=W_init))
        self.layer_list.append(BatchNorm(num_features=bn_size * growth_rate, act='relu'))
        self.layer_list.append(Conv2d(growth_rate, (3, 3), in_channels=bn_size * growth_rate, W_init=W_init))
        self.layer_list.append(BatchNorm(num_features=growth_rate, act='relu'))
        self.dense_layer = SequentialLayer(self.layer_list)
        # With the TensorFlow backend tensors are NHWC, so concatenate on the
        # last (channel) axis, not axis 1 (the PyTorch NCHW convention).
        self.concat = Concat(-1)

    # Override the forward function
    def forward(self, x):
        new_features = self.dense_layer(x)
        return self.concat([x, new_features])
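
# A minimal usage sketch (shapes assume the NHWC layout selected by
# TL_BACKEND above): each _DenseLayer appends growth_rate channels to its
# input, so (N, H, W, C) comes out as (N, H, W, C + growth_rate), e.g.:
#
#     layer = _DenseLayer(in_channels=24, growth_rate=12, bn_size=4)
#     layer.set_eval()
#     y = layer(tf.ones((2, 32, 32, 24)))  # expected shape: (2, 32, 32, 36)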

# A stack of num_layers dense layers built so that layer i receives
# in_channels + growth_rate * i channels.
class _DenseBlock(Module):
    def __init__(self, num_layers, in_channels, bn_size, growth_rate):
        super(_DenseBlock, self).__init__()
        self.layer_list = []
        for i in range(num_layers):
            self.layer_list.append(_DenseLayer(in_channels + growth_rate * i, growth_rate, bn_size))
        self.dense_block = SequentialLayer(self.layer_list)

    # Override the forward function
    def forward(self, x):
        return self.dense_block(x)
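
# Channel bookkeeping: a block fed in_channels channels emits
# in_channels + num_layers * growth_rate channels, since every layer
# concatenates growth_rate new feature maps onto the running tensor.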

# Transition between dense blocks: 1x1 conv to compress channels, then 2x2
# mean pooling to halve the spatial resolution.
class _Transition(Module):
    def __init__(self, in_channels, out_channels):
        super(_Transition, self).__init__()
        W_init = tl.initializers.truncated_normal(stddev=5e-2)
        self.layer_list = []
        self.layer_list.append(Conv2d(out_channels, (1, 1), in_channels=in_channels, W_init=W_init))
        self.layer_list.append(BatchNorm(num_features=out_channels, act='relu'))
        self.layer_list.append(MeanPool2d((2, 2), strides=(2, 2)))
        self.transition_layer = SequentialLayer(self.layer_list)

    # Override the forward function
    def forward(self, x):
        return self.transition_layer(x)
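
# Shape sketch: an NHWC input (N, H, W, in_channels) leaves the transition as
# (N, H // 2, W // 2, out_channels); the call sites below pass
# out_channels = int(in_channels * theta) with theta = 0.5.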

class DenseNet_BC(Module):
    def __init__(self, growth_rate=12, block_config=(6, 12, 24, 16),
                 bn_size=4, theta=0.5, num_classes=10):
        super(DenseNet_BC, self).__init__()
        W_init = tl.initializers.truncated_normal(stddev=5e-2)
        W_init2 = tl.initializers.truncated_normal(stddev=0.04)
        b_init2 = tl.initializers.constant(value=0.1)
        # The initial convolution uses 2 * growth_rate filters.
        num_init_feature = 2 * growth_rate
        self.layer_list = []
        # num_classes == 10 means CIFAR-10: a bare 3x3 stride-1 stem. The
        # ImageNet stem adds BN/ReLU and 3x3 max pooling after a 7x7 conv.
        if num_classes == 10:
            self.layer_list.append(Conv2d(num_init_feature, (3, 3), strides=(1, 1), in_channels=3, W_init=W_init))
        else:
            self.layer_list.append(Conv2d(num_init_feature, (7, 7), strides=(2, 2), padding="valid", in_channels=3, W_init=W_init))
            self.layer_list.append(BatchNorm(num_features=num_init_feature, act='relu'))
            self.layer_list.append(MaxPool2d((3, 3), strides=(2, 2)))
        num_feature = num_init_feature
        for i, num_layers in enumerate(block_config):
            self.layer_list.append(_DenseBlock(num_layers, num_feature, bn_size, growth_rate))
            num_feature = num_feature + growth_rate * num_layers
            # Every block except the last is followed by a compressing transition.
            if i != len(block_config) - 1:
                self.layer_list.append(_Transition(num_feature, int(num_feature * theta)))
                num_feature = int(num_feature * theta)
        self.layer_list.append(BatchNorm(num_features=num_feature, act='relu'))
        self.layer_list.append(AdaptiveMeanPool2d((1, 1)))
        self.features = SequentialLayer(self.layer_list)
        self.flatten = Flatten()
        # TensorLayer's Dense takes n_units first; in_channels is a keyword.
        self.classifier = Dense(n_units=num_classes, in_channels=num_feature, W_init=W_init2, b_init=b_init2)

    def forward(self, x):
        features = self.features(x)
        # `.view()` is a PyTorch tensor method; flatten with a layer instead.
        out = self.flatten(features)
        out = self.classifier(out)
        return out
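
# A small sketch of the channel bookkeeping above (a hypothetical helper, not
# part of the original API): num_feature starts at 2 * growth_rate, grows by
# growth_rate per layer inside each block, and is scaled by theta at every
# transition. For given arguments it yields the classifier's in_channels.
def _final_num_features(growth_rate=12, block_config=(6, 12, 24, 16), theta=0.5):
    num_feature = 2 * growth_rate
    for i, num_layers in enumerate(block_config):
        num_feature += growth_rate * num_layers
        if i != len(block_config) - 1:
            num_feature = int(num_feature * theta)
    return num_feature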

# DenseNet_BC for ImageNet
def DenseNet121():
    return DenseNet_BC(growth_rate=32, block_config=(6, 12, 24, 16), num_classes=1000)


def DenseNet169():
    return DenseNet_BC(growth_rate=32, block_config=(6, 12, 32, 32), num_classes=1000)


def DenseNet201():
    return DenseNet_BC(growth_rate=32, block_config=(6, 12, 48, 32), num_classes=1000)


def DenseNet161():
    return DenseNet_BC(growth_rate=48, block_config=(6, 12, 36, 24), num_classes=1000)
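
# The names encode depth: 2 convs per bottleneck layer plus 3 transition
# convs, the stem conv, and the classifier. E.g. for DenseNet121:
# 2 * (6 + 12 + 24 + 16) + 3 + 1 + 1 = 121.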

# DenseNet_BC for CIFAR
def densenet_BC_100():
    return DenseNet_BC(growth_rate=12, block_config=(16, 16, 16))
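
# Depth check for DenseNet-BC-100: 2 * (16 + 16 + 16) bottleneck convs
# + 2 transitions + 1 stem conv + 1 classifier = 100.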

def builddensenet(name="densenet-100"):
    if name == "densenet-100":
        return densenet_BC_100()
    elif name == "densenet-121":
        return DenseNet121()
    else:
        raise ValueError("unknown network name: %s" % name)
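
# Example usage (weights are randomly initialized; no pretrained checkpoint
# loading is wired up in this file):
#
#     net = builddensenet("densenet-121")
#     net.set_eval()
#     logits = net(tf.random.normal((1, 224, 224, 3)))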

def test():
    net = densenet_BC_100()
    # torchsummary is a PyTorch tool and cannot inspect a TensorLayer Module;
    # check the model with a forward pass on a random NHWC batch instead.
    net.set_eval()
    x = tf.random.normal((2, 32, 32, 3))
    y = net(x)
    print(y.shape)  # expected: (2, 10)


if __name__ == '__main__':
    test()