# tensorlayer3/examples/DenseNet/densenet.py
#! /usr/bin/python
# -*- coding: utf-8 -*-
import os, sys
# set up environment variable for backend
os.environ['TL_BACKEND'] = 'tensorflow'
# import essential requirements
import tensorlayer as tl
import tensorflow as tf
from tensorlayer import logging
from tensorlayer.layers import *
__all__ = [
'DenseNet'
]
class ReLULayer(Module):
    """Stateless module that applies the element-wise ReLU activation.

    Exists so that an activation can be inserted into a
    ``SequentialLayer`` pipeline as a layer object.
    """

    def __init__(self):
        super(ReLULayer, self).__init__()

    def forward(self, input):
        activated = tf.nn.relu(input)
        return activated
class DenseLayer(Module):
    """One DenseNet composite layer.

    Pipeline: BN -> ReLU -> (optional 1x1 bottleneck conv) -> BN -> ReLU
    -> 3x3 conv producing ``growth_rate`` feature maps (-> optional
    dropout).  The layer output is the input concatenated with the new
    feature maps along the channel axis (dense connectivity).

    Parameters
    ----------
    num_channels : int
        Number of channels of the incoming tensor.
    bottleneck : bool
        If True, insert a 1x1 conv reducing to ``4 * growth_rate``
        channels before the 3x3 conv (DenseNet-B).
    growth_rate : int
        Number of feature maps this layer adds.
    dropout_rate : float
        Dropout rate applied after the 3x3 conv; 0 disables dropout.
    """

    def __init__(self, num_channels, bottleneck, growth_rate, dropout_rate):
        super(DenseLayer, self).__init__()
        self.bottleneck = bottleneck
        self.dropout_rate = dropout_rate
        self.growth_rate = growth_rate
        self.layers_list = []
        self.layers_list.append(BatchNorm2d(num_features=num_channels))
        self.layers_list.append(ReLULayer())
        num_features = num_channels
        if self.bottleneck:
            # 1x1 bottleneck reduces the input to 4*growth_rate channels
            # before the more expensive 3x3 conv (DenseNet-B).
            num_features = 4 * growth_rate
            self.layers_list.append(
                Conv2d(n_filter=num_features, in_channels=num_channels, padding='VALID',
                       filter_size=(1, 1),
                       W_init=tl.initializers.truncated_normal(stddev=5e-2)))
            self.layers_list.append(BatchNorm2d(num_features=num_features))
            self.layers_list.append(ReLULayer())
        self.layers_list.append(
            Conv2d(n_filter=self.growth_rate, in_channels=num_features, filter_size=(3, 3),
                   W_init=tl.initializers.truncated_normal(stddev=5e-2)))
        if self.dropout_rate > 0:
            self.layers_list.append(Dropout(keep=1 - dropout_rate))
        self.denseLayer = SequentialLayer(self.layers_list)
        # Concatenate along axis 3 (channels, assuming NHWC layout).
        self.concat = Concat(concat_dim=3)

    def forward(self, input):
        """Return ``concat([input, denseLayer(input)])`` on the channel axis."""
        # NOTE: the original forward contained unreachable statements after
        # this return that referenced attributes (self.BN2, self.conv,
        # self.dropout) defined only in commented-out code; they have been
        # removed.
        store_input = tf.identity(input)
        output = self.denseLayer(input)
        return self.concat([store_input, output])
class DenseBlock(Module):
    """A stack of ``num_blocks`` DenseLayers.

    Each DenseLayer concatenates its output onto its input, so the
    channel count seen by successive layers grows by ``growth_rate``.
    """

    def __init__(self, num_blocks, num_features, growth_rate, bottleneck, dropout_rate):
        super(DenseBlock, self).__init__()
        channels = num_features
        self.layers_list = []
        for _ in range(num_blocks):
            layer = DenseLayer(channels, bottleneck, growth_rate, dropout_rate)
            self.layers_list.append(layer)
            channels += growth_rate
        self.denseBlock = SequentialLayer(self.layers_list)

    def forward(self, x):
        """Run the input through every DenseLayer in sequence."""
        return self.denseBlock(x)
class TransitionLayer(Module):
    """Transition between two dense blocks.

    Pipeline: BN -> ReLU -> 1x1 conv compressing the channel count by
    ``compression`` -> 2x2 average pool halving the spatial resolution.

    Parameters
    ----------
    num_channels : int
        Channels of the incoming tensor.
    dropout_rate : float
        Stored but currently not applied; kept for interface
        compatibility with existing callers.
    compression : float
        Channel compression factor (e.g. 0.5 for DenseNet-C).
    """

    def __init__(self, num_channels, dropout_rate, compression):
        super(TransitionLayer, self).__init__()
        # NOTE(review): dropout_rate is never used in forward; the dropout
        # only existed in previously commented-out code.
        self.dropout_rate = dropout_rate
        self.layers_list = []
        self.layers_list.append(BatchNorm2d(num_features=num_channels))
        self.layers_list.append(ReLULayer())
        self.layers_list.append(
            Conv2d(n_filter=int(num_channels * compression), in_channels=num_channels,
                   filter_size=(1, 1)))
        self.layers_list.append(MeanPool2d(filter_size=(2, 2), strides=(2, 2), padding='VALID'))
        self.transitionLayer = SequentialLayer(self.layers_list)

    def forward(self, input):
        """Compress channels and downsample the input."""
        return self.transitionLayer(input)
class DenseNet_model(Module):
    """Complete DenseNet.

    Architecture: stem conv (dataset-dependent) -> alternating
    DenseBlocks and TransitionLayers -> BN -> global average pool ->
    softmax classifier.

    Parameters
    ----------
    num_classes : int
        Number of output classes (required).
    dense_layers : list of int
        Number of DenseLayers in each DenseBlock (required).
    growth_rate : int
        Feature maps added by every DenseLayer.
    dataset : str
        'cifar' (3x3 stem conv) or 'imagenet' (7x7/2 stem conv followed
        by BN, ReLU and 3x3/2 max-pool).
    bottleneck : bool
        Enable 1x1 bottleneck convs in the DenseLayers (DenseNet-B).
    dropout_rate : float
        Dropout rate inside the DenseLayers; 0 disables.
    compression : float
        Channel compression at each transition (DenseNet-C).
        NOTE(review): the default of 0 would zero out the channel count
        after the first transition; all factory functions in this file
        pass 0.5 — confirm before relying on the default.
    """

    def __init__(self, num_classes, dense_layers=None, growth_rate=12, dataset='imagenet',
                 bottleneck=False, dropout_rate=0.2, compression=0):
        super(DenseNet_model, self).__init__()
        # Fail fast on missing required configuration.  Exit with a
        # non-zero status so scripts can detect the failure (the original
        # used sys.exit(0), signalling success on error).
        if num_classes is None:
            logging.error('The number for classification is empty.')
            sys.exit(1)
        if dense_layers is None:
            logging.error('You need to specify the structure for dense layer.')
            sys.exit(1)
        self.num_classes = num_classes
        self.dense_layers = dense_layers
        self.growth_rate = growth_rate
        self.bottleneck = bottleneck
        self.dropout_rate = dropout_rate
        self.compression = compression
        self.dataset = dataset
        # Stem width: 2 * growth_rate channels, as in the DenseNet paper.
        self.num_channels = 2 * growth_rate
        self.layers_list = []
        if self.dataset == 'cifar':
            self.layers_list.append(
                Conv2d(self.num_channels, in_channels=3, filter_size=(3, 3),
                       W_init=tl.initializers.truncated_normal(stddev=5e-2)))
        elif self.dataset == 'imagenet':
            self.layers_list.append(
                Conv2d(self.num_channels, in_channels=3, filter_size=(7, 7), strides=(2, 2),
                       W_init=tl.initializers.truncated_normal(stddev=5e-2)))
            self.layers_list.append(BatchNorm2d(num_features=self.num_channels))
            self.layers_list.append(ReLULayer())
            self.layers_list.append(MaxPool2d((3, 3), strides=(2, 2)))
        else:
            logging.error('Incorrect dataset: %s' % self.dataset)
            sys.exit(1)
        for i, num_layers in enumerate(self.dense_layers):
            self.layers_list.append(
                DenseBlock(num_layers, self.num_channels, self.growth_rate,
                           self.bottleneck, self.dropout_rate))
            self.num_channels += self.growth_rate * num_layers
            # No transition after the final dense block.
            if i != len(self.dense_layers) - 1:
                self.layers_list.append(
                    TransitionLayer(self.num_channels, self.dropout_rate, self.compression))
                self.num_channels = int(self.num_channels * self.compression)
        self.layers_list.append(BatchNorm2d(num_features=self.num_channels))
        self.layers_list.append(GlobalMeanPool2d())
        self.layers_list.append(
            Dense(self.num_classes, act="softmax",
                  W_init=tl.initializers.truncated_normal(stddev=0.04)))
        self.model = SequentialLayer(self.layers_list)

    def forward(self, input):
        """Run the full network and return class probabilities."""
        return self.model(input)
def DenseNetBC_model(num_classes, dense_layers, growth_rate, dataset='imagenet'):
    """Build a DenseNet-BC model (bottleneck layers + 0.5 compression)."""
    model = DenseNet_model(num_classes, dense_layers, growth_rate, dataset,
                           bottleneck=True, compression=0.5)
    return model
def DenseNet100_model():
    """DenseNet-BC-100 (growth rate 12) for CIFAR, 10 classes."""
    block_config = [16, 16, 16]
    return DenseNetBC_model(10, block_config, 12, dataset='cifar')
def DenseNet121_model():
    """DenseNet-BC-121 (growth rate 32) for ImageNet, 1000 classes."""
    block_config = [6, 12, 24, 16]
    return DenseNetBC_model(1000, block_config, 32)
def DenseNet161_model():
    """DenseNet-BC-161 variant (growth rate 32) for ImageNet, 1000 classes."""
    block_config = [6, 12, 48, 32]
    return DenseNetBC_model(1000, block_config, 32)
def DenseNet169_model():
    """DenseNet-BC-169 (growth rate 32) for ImageNet, 1000 classes."""
    block_config = [6, 12, 32, 32]
    return DenseNetBC_model(1000, block_config, 32)
def DenseNet201_model():
    """DenseNet-BC-201 variant (growth rate 48) for ImageNet, 1000 classes."""
    block_config = [6, 12, 36, 24]
    return DenseNetBC_model(1000, block_config, 48)
def DenseNet(name='densenet-100'):
    """Return a predefined DenseNet model selected by name.

    Valid names: 'densenet-100', 'densenet-121', 'densenet-161',
    'densenet-169', 'densenet-201'.  Logs an error and exits on any
    other name.
    """
    builders = {
        'densenet-100': DenseNet100_model,
        'densenet-121': DenseNet121_model,
        'densenet-161': DenseNet161_model,
        'densenet-169': DenseNet169_model,
        'densenet-201': DenseNet201_model,
    }
    builder = builders.get(name)
    if builder is None:
        logging.error("Invalid mode for Densenet: %s" % name)
        sys.exit(1)
    return builder()