forked from TensorLayer/tensorlayer3
154 lines
5.8 KiB
Python
154 lines
5.8 KiB
Python
import os
|
|
os.environ['TL_BACKEND'] = 'tensorflow'
|
|
import time
|
|
import multiprocessing
|
|
import tensorflow as tf
|
|
from tensorlayer.models import TrainOneStep
|
|
from tensorlayer.layers import Module
|
|
import tensorlayer as tl
|
|
from torchsummary import summary
|
|
from tensorlayer import logging
|
|
from tensorlayer.files import (assign_weights, maybe_download_and_extract)
|
|
from tensorlayer.layers import (BatchNorm, Conv2d, Dense, Elementwise, AdaptiveMeanPool2d, MaxPool2d , MeanPool2d,Concat,Dropout)
|
|
from tensorlayer.layers import Module, SequentialLayer
|
|
|
|
|
|
class _DenseLayer(Module):
    """One bottleneck layer of a DenseNet-BC dense block.

    Applies ``Conv2d(1x1) -> BatchNorm+ReLU -> Conv2d(3x3) -> BatchNorm+ReLU``
    and concatenates the ``growth_rate`` new feature maps onto the input along
    the channel axis, so the output has ``in_channels + growth_rate`` channels.

    Args:
        in_channels: Number of channels of the incoming tensor.
        growth_rate: Number of feature maps this layer adds.
        bn_size: Bottleneck width multiplier; the 1x1 convolution emits
            ``bn_size * growth_rate`` channels.
    """

    def __init__(self, in_channels, growth_rate, bn_size):
        super(_DenseLayer, self).__init__()
        W_init = tl.initializers.truncated_normal(stddev=5e-2)
        bottleneck_channels = bn_size * growth_rate

        # NOTE(review): the DenseNet-BC paper uses pre-activation ordering
        # (BN -> ReLU -> Conv); this port uses Conv -> BN+ReLU. Kept as-is.
        self.layer_list = [
            Conv2d(bottleneck_channels, (1, 1), in_channels=in_channels, W_init=W_init),
            BatchNorm(num_features=bottleneck_channels, act='relu'),
            Conv2d(growth_rate, (3, 3), in_channels=bottleneck_channels, W_init=W_init),
            BatchNorm(num_features=growth_rate, act='relu'),
        ]
        self.dense_layer = SequentialLayer(self.layer_list)

        # The layers above are built without an explicit data_format, so they
        # use TensorLayer's default channels_last (NHWC) layout: channels are
        # the LAST axis. Concatenating on axis 1 (height) fails whenever
        # in_channels != growth_rate.
        self.concat = Concat(concat_dim=-1)

    def forward(self, x):
        """Return ``x`` with this layer's new feature maps appended on the channel axis."""
        new_features = self.dense_layer(x)
        return self.concat([x, new_features])
|
|
|
|
|
|
class _DenseBlock(Module):
    """A dense block: ``num_layers`` ``_DenseLayer`` instances applied in sequence.

    Every dense layer appends ``growth_rate`` channels to its input. Layer
    ``i`` therefore receives ``in_channels + i * growth_rate`` channels, and the
    block outputs ``in_channels + num_layers * growth_rate`` channels.

    Args:
        num_layers: Number of dense layers in the block.
        in_channels: Channels of the tensor entering the block.
        bn_size: Bottleneck multiplier forwarded to every dense layer.
        growth_rate: Channels added by each dense layer.
    """

    def __init__(self, num_layers, in_channels, bn_size, growth_rate):
        super(_DenseBlock, self).__init__()
        # Input width of layer i grows by growth_rate for every preceding layer.
        self.layer_list = [
            _DenseLayer(in_channels + growth_rate * i, growth_rate, bn_size)
            for i in range(num_layers)
        ]
        self.dense_block = SequentialLayer(self.layer_list)

    def forward(self, x):
        """Run ``x`` through every dense layer of the block in order."""
        return self.dense_block(x)
|
|
|
|
|
|
class _Transition(Module):
    """Transition layer placed between two dense blocks.

    Applies a 1x1 convolution to ``out_channels`` (channel compression), then
    BatchNorm+ReLU, then 2x2 average pooling with stride 2.

    Args:
        in_channels: Channels produced by the preceding dense block.
        out_channels: Channels after compression.
    """

    def __init__(self, in_channels, out_channels):
        super(_Transition, self).__init__()
        W_init = tl.initializers.truncated_normal(stddev=5e-2)
        self.layer_list = [
            Conv2d(out_channels, (1, 1), in_channels=in_channels, W_init=W_init),
            BatchNorm(num_features=out_channels, act='relu'),
            # Stride-2 pooling downsamples the feature maps between blocks.
            MeanPool2d((2, 2), strides=(2, 2)),
        ]
        self.transition_layer = SequentialLayer(self.layer_list)

    def forward(self, x):
        """Apply the compression conv, BN+ReLU and average pooling to ``x``."""
        return self.transition_layer(x)
|
|
|
|
class DenseNet_BC(Module):
    """DenseNet-BC (bottleneck + compression) image classifier.

    Args:
        growth_rate: Feature maps added by every dense layer. The stem
            convolution emits ``2 * growth_rate`` channels.
        block_config: Number of dense layers in each dense block. A
            ``_Transition`` follows every block except the last.
        bn_size: Bottleneck multiplier forwarded to every dense layer.
        theta: Compression factor; each transition keeps
            ``int(channels * theta)`` channels.
        num_classes: Number of classifier outputs.
        cifar_stem: If True, use a single 3x3 stride-1 stem convolution. If
            False, use a 7x7 stride-2 conv + BatchNorm+ReLU + 3x3 stride-2
            max-pool stem. ``None`` (default) keeps the original rule: the
            small stem is used iff ``num_classes == 10``.
    """

    def __init__(self, growth_rate=12, block_config=(6, 12, 24, 16),
                 bn_size=4, theta=0.5, num_classes=10, cifar_stem=None):
        super(DenseNet_BC, self).__init__()
        W_init = tl.initializers.truncated_normal(stddev=5e-2)
        W_init2 = tl.initializers.truncated_normal(stddev=0.04)
        b_init2 = tl.initializers.constant(value=0.1)

        if cifar_stem is None:
            cifar_stem = num_classes == 10

        # The stem convolution has 2 * growth_rate filters.
        num_init_feature = 2 * growth_rate
        self.layer_list = []
        if cifar_stem:
            self.layer_list.append(
                Conv2d(num_init_feature, (3, 3), strides=(1, 1), in_channels=3, W_init=W_init))
        else:
            self.layer_list.append(
                Conv2d(num_init_feature, (7, 7), strides=(2, 2), padding="valid",
                       in_channels=3, W_init=W_init))
            self.layer_list.append(BatchNorm(num_features=num_init_feature, act='relu'))
            self.layer_list.append(MaxPool2d((3, 3), strides=(2, 2)))

        # Track the channel count as blocks widen it and transitions compress it.
        num_feature = num_init_feature
        last_block = len(block_config) - 1
        for i, num_layers in enumerate(block_config):
            self.layer_list.append(_DenseBlock(num_layers, num_feature, bn_size, growth_rate))
            num_feature += growth_rate * num_layers
            if i != last_block:
                compressed = int(num_feature * theta)
                self.layer_list.append(_Transition(num_feature, compressed))
                num_feature = compressed

        self.layer_list.append(BatchNorm(num_features=num_feature, act='relu'))
        self.layer_list.append(AdaptiveMeanPool2d((1, 1)))
        self.features = SequentialLayer(self.layer_list)

        # The pooled features are flattened to (batch, num_feature) before the classifier.
        self.flatten = tl.layers.Flatten()
        # Dense's first positional parameter is n_units and its second is act.
        # Passing (num_feature, num_classes) positionally would create
        # num_feature output units and use num_classes as the activation.
        self.classifier = Dense(n_units=num_classes, in_channels=num_feature,
                                W_init=W_init2, b_init=b_init2)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        features = self.features(x)
        # TensorLayer tensors have no PyTorch-style .view()/.size(); use a layer.
        out = self.flatten(features)
        return self.classifier(out)
|
|
|
|
|
|
# DenseNet_BC for ImageNet
|
|
def DenseNet121():
    """Return DenseNet-121: growth rate 32, dense blocks (6, 12, 24, 16), 1000 classes."""
    layers_per_block = (6, 12, 24, 16)
    return DenseNet_BC(
        growth_rate=32,
        block_config=layers_per_block,
        num_classes=1000,
    )
|
|
|
|
|
|
def DenseNet169():
    """Return DenseNet-169: growth rate 32, dense blocks (6, 12, 32, 32), 1000 classes."""
    layers_per_block = (6, 12, 32, 32)
    return DenseNet_BC(
        growth_rate=32,
        block_config=layers_per_block,
        num_classes=1000,
    )
|
|
|
|
|
|
def DenseNet201():
    """Return DenseNet-201: growth rate 32, dense blocks (6, 12, 48, 32), 1000 classes."""
    layers_per_block = (6, 12, 48, 32)
    return DenseNet_BC(
        growth_rate=32,
        block_config=layers_per_block,
        num_classes=1000,
    )
|
|
|
|
|
|
def DenseNet161():
    """Return DenseNet-161: growth rate 48, dense blocks (6, 12, 36, 24), 1000 classes."""
    layers_per_block = (6, 12, 36, 24)
    return DenseNet_BC(
        growth_rate=48,
        block_config=layers_per_block,
        num_classes=1000,
    )
|
|
|
|
|
|
# DenseNet_BC for cifar
|
|
def densenet_BC_100():
    """Return DenseNet-BC-100: growth rate 12, three dense blocks of 16 layers.

    All other arguments keep ``DenseNet_BC``'s defaults (10 classes).
    """
    three_blocks = (16,) * 3
    return DenseNet_BC(block_config=three_blocks, growth_rate=12)
|
|
|
|
def builddensenet(name="densenet-100"):
    """Build a DenseNet-BC model by name.

    Args:
        name: One of ``"densenet-100"`` (default), ``"densenet-121"``,
            ``"densenet-161"``, ``"densenet-169"`` or ``"densenet-201"``.

    Returns:
        A newly constructed ``DenseNet_BC`` instance.

    Raises:
        ValueError: If ``name`` is not a known model name.
    """
    factories = {
        "densenet-100": densenet_BC_100,
        "densenet-121": DenseNet121,
        "densenet-161": DenseNet161,
        "densenet-169": DenseNet169,
        "densenet-201": DenseNet201,
    }
    try:
        factory = factories[name]
    except KeyError:
        # Raise instead of print + exit(0): a library call must not terminate
        # the process, and exit(0) would report success on a bad name.
        raise ValueError(
            f"not found the net: {name!r}; expected one of {sorted(factories)}"
        ) from None
    return factory()
|
|
|
|
def test():
    """Smoke-test ``densenet_BC_100`` with one forward pass on random input.

    Prints the shape of the logits. For the batch of 2 CIFAR-sized images
    used here, it should be (2, 10).
    """
    net = densenet_BC_100()
    # Use inference mode so BatchNorm runs with its moving statistics.
    net.set_eval()
    # torchsummary.summary only understands PyTorch nn.Module objects (and
    # returns None), so it cannot inspect this TensorLayer model. Run a real
    # forward pass instead. Inputs are NHWC (TensorLayer's default layout).
    x = tf.random.normal((2, 32, 32, 3))
    y = net(x)
    print(y.shape)
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the smoke test only when executed as a script, never on import.
    test()
|
|
|
|
|
|
|