# Source: tensorlayer3/Densnet.py
# (page metadata: 126 lines, 4.9 KiB, Python)

import tensorlayer as tl
from tensorlayer.layers import Module, Dense, Elementwise, SequentialLayer, Flatten, BatchNorm2d, Conv2d, Dropout, Concat, MeanPool2d
import math
__all__ = ['densnet100', 'densnet121']
class BasicBlock(Module):
    """DenseNet basic layer: BN-ReLU -> 3x3 conv (-> dropout), concatenated to its input.

    Parameters
    ----------
    in_planes : int
        Number of channels of the incoming feature map.
    out_planes : int
        Number of new channels produced by the 3x3 convolution (the growth rate).
    droprate : float
        Probability of dropping an activation. Dropout is skipped entirely when
        this is 0.
    """

    def __init__(self, in_planes, out_planes, droprate=0.0):
        super(BasicBlock, self).__init__()
        # Channel count is inferred on first call, like the other BatchNorm2d
        # layers in this file.
        self.Bn1 = BatchNorm2d(act=tl.ReLU)
        self.CONV1 = Conv2d(
            n_filter=out_planes, in_channels=in_planes, filter_size=(3, 3),
            strides=(1, 1), padding='SAME', b_init=None,
        )
        self.droprate = droprate
        # TensorLayer's Dropout takes the *keep* probability, so convert from
        # the drop rate. Passing the drop rate directly would invert dropout.
        self.DROP = Dropout(keep=1.0 - self.droprate)
        self.CAT = Concat()

    def forward(self, inputs):
        """Return ``inputs`` concatenated with ``out_planes`` newly computed features."""
        s = self.Bn1(inputs)
        s1 = self.CONV1(s)
        if self.droprate > 0:
            s1 = self.DROP(s1)
        # Dense connectivity: stack the new features onto the input using
        # Concat's default axis (assumed to be channels-last) — confirm.
        out = self.CAT([inputs, s1])
        return out
class BottleneckBlock(Module):
    """DenseNet-BC bottleneck layer, concatenated to its input.

    Structure: BN-ReLU -> 1x1 conv to ``4 * out_planes`` channels, then
    BN-ReLU -> 3x3 conv to ``out_planes`` channels. Optional dropout follows
    each convolution.

    Parameters
    ----------
    in_planes : int
        Number of channels of the incoming feature map.
    out_planes : int
        Number of new channels produced by the block (the growth rate).
    dropRate : float
        Probability of dropping an activation. Dropout is skipped entirely when
        this is 0.
    """

    def __init__(self, in_planes, out_planes, dropRate=0.0):
        super(BottleneckBlock, self).__init__()
        inter_planes = out_planes * 4
        self.Bn1 = BatchNorm2d(act=tl.ReLU)
        self.CONV1 = Conv2d(
            n_filter=inter_planes, in_channels=in_planes, filter_size=(1, 1),
            strides=(1, 1), padding='VALID', b_init=None,
        )
        # The second BN needs the ReLU too. Without it the 3x3 conv sees a
        # purely linear BN output, unlike every other conv in this network.
        self.Bn2 = BatchNorm2d(act=tl.ReLU)
        self.CONV2 = Conv2d(
            n_filter=out_planes, in_channels=inter_planes, filter_size=(3, 3),
            strides=(1, 1), padding='SAME', b_init=None,
        )
        self.droprate = dropRate
        # TensorLayer's Dropout takes the *keep* probability, not the drop rate.
        self.DROP = Dropout(keep=1.0 - self.droprate)
        self.CAT = Concat()

    def forward(self, x):
        """Return ``x`` concatenated with ``out_planes`` newly computed features."""
        out = self.CONV1(self.Bn1(x))
        if self.droprate > 0:
            out = self.DROP(out)
        out = self.CONV2(self.Bn2(out))
        if self.droprate > 0:
            out = self.DROP(out)
        return self.CAT([x, out])
class TransitionBlock(Module):
    """Transition between dense blocks.

    Structure: BN-ReLU -> 1x1 conv to ``out_planes`` (-> dropout), then a 2x2
    stride-2 mean pool that halves the spatial size.

    Parameters
    ----------
    in_planes : int
        Number of channels of the incoming feature map.
    out_planes : int
        Number of channels after the 1x1 convolution.
    dropRate : float
        Probability of dropping an activation. Dropout is skipped entirely when
        this is 0.
    """

    def __init__(self, in_planes, out_planes, dropRate=0.0):
        super(TransitionBlock, self).__init__()
        self.Bn1 = BatchNorm2d(act=tl.ReLU)
        self.CONV1 = Conv2d(
            n_filter=out_planes, in_channels=in_planes, filter_size=(1, 1),
            strides=(1, 1), padding='VALID', b_init=None,
        )
        self.droprate = dropRate
        # TensorLayer's Dropout takes the *keep* probability, not the drop rate.
        self.DROP = Dropout(keep=1.0 - self.droprate)
        self.avg = MeanPool2d(filter_size=(2, 2), strides=(2, 2), padding='VALID')

    def forward(self, x):
        """Compress channels with the 1x1 conv, then downsample 2x spatially."""
        out = self.CONV1(self.Bn1(x))
        if self.droprate > 0:
            out = self.DROP(out)
        return self.avg(out)
class DenseBlock(Module):
    """A chain of ``nb_layers`` dense layers of type ``block``.

    Layer ``k`` (0-based) is constructed with ``in_planes + k * growth_rate``
    input channels and ``growth_rate`` output channels. The layers are run in
    order through a SequentialLayer.
    """

    def __init__(self, nb_layers, in_planes, growth_rate, block, dropRate=0.0):
        super(DenseBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, growth_rate, nb_layers, dropRate)

    def _make_layer(self, block, in_planes, growth_rate, nb_layers, dropRate):
        """Instantiate every ``block`` layer and wrap them in a SequentialLayer."""
        stacked = [
            block(in_planes + k * growth_rate, growth_rate, dropRate)
            for k in range(nb_layers)
        ]
        return SequentialLayer(stacked)

    def forward(self, x):
        """Run ``x`` through every layer of the block in sequence."""
        return self.layer(x)
class DenseNet(Module):
    """Three-dense-block DenseNet.

    Structure: 3x3 stem conv -> [dense block -> transition] x 2 -> dense block
    -> BN-ReLU -> 8x8 mean pool -> flatten -> dense classifier.

    Parameters
    ----------
    depth : int
        Total network depth. Each dense block gets ``int((depth - 4) / 3)``
        layers; when ``bottleneck`` is true that count is halved before
        truncation, since each BottleneckBlock holds two convolutions.
    num_classes : int
        Number of output units of the final Dense layer.
    growth_rate : int
        Channels added by every dense layer. The stem produces
        ``2 * growth_rate`` channels.
    reduction : float
        Channel compression factor applied by each transition (floored).
    bottleneck : bool
        Use BottleneckBlock (DenseNet-BC) if true, otherwise BasicBlock.
    dropRate : float
        Drop probability forwarded to every dense and transition layer.

    Notes
    -----
    NOTE(review): the stem conv keeps the spatial size and the two transitions
    each halve it. The fixed 8x8 final pool therefore reduces to 1x1 only for
    32x32 inputs; confirm that this is the intended input size.
    """

    def __init__(self, depth, num_classes, growth_rate=12,
                 reduction=0.5, bottleneck=True, dropRate=0.0):
        super(DenseNet, self).__init__()
        in_planes = 2 * growth_rate
        # Layers per dense block. The result must be bound to ``n``, which all
        # three DenseBlocks below use.
        n = (depth - 4) / 3
        if bottleneck:
            n = n / 2
            block = BottleneckBlock
        else:
            block = BasicBlock
        n = int(n)

        # 1st conv before any dense block
        self.CONV1 = Conv2d(in_channels=3, n_filter=in_planes, filter_size=(3, 3),
                            strides=(1, 1), padding='SAME', b_init=None)
        # 1st block: each of the n layers adds growth_rate channels; the
        # transition then compresses by `reduction`.
        self.block1 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
        in_planes = int(in_planes + n * growth_rate)
        self.trans1 = TransitionBlock(in_planes, math.floor(in_planes * reduction), dropRate=dropRate)
        in_planes = math.floor(in_planes * reduction)
        # 2nd block
        self.block2 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
        in_planes = int(in_planes + n * growth_rate)
        self.trans2 = TransitionBlock(in_planes, math.floor(in_planes * reduction), dropRate=dropRate)
        in_planes = math.floor(in_planes * reduction)
        # 3rd block (no transition afterwards)
        self.block3 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
        in_planes = int(in_planes + n * growth_rate)
        # global average pooling and classifier
        self.Bn1 = BatchNorm2d(act=tl.ReLU)
        self.fc = Dense(n_units=num_classes)
        self.in_planes = in_planes  # channel width entering the classifier head
        self.avg = MeanPool2d(filter_size=(8, 8), strides=(8, 8), padding='VALID')
        self.flatten = Flatten()

    def forward(self, x):
        """Return the classifier outputs (no softmax applied) for the batch ``x``."""
        out = self.CONV1(x)
        out = self.trans1(self.block1(out))
        out = self.trans2(self.block2(out))
        out = self.block3(out)
        out = self.Bn1(out)
        out = self.avg(out)
        out = self.flatten(out)
        return self.fc(out)
def densnet121():
    """Build a depth-121 DenseNet with 1000 output classes.

    With the default bottleneck setting this yields int((121 - 4) / 3 / 2) = 19
    bottleneck layers in each of three dense blocks.
    NOTE(review): that is not the four-block ImageNet DenseNet-121 layout
    (6/12/24/16 layers). Confirm which architecture is intended.
    """
    return DenseNet(depth=121, num_classes=1000)
def densnet100():
    """Build a depth-100 DenseNet with 10 output classes.

    With the default bottleneck setting this yields int((100 - 4) / 3 / 2) = 16
    bottleneck layers per dense block, with growth rate 12.
    """
    return DenseNet(depth=100, num_classes=10)