# tensorlayer3/Densnet.py
#
# 124 lines
# 4.8 KiB
# Python

import tensorlayer as tl
from tensorlayer.layers import Module, Dense, Elementwise, SequentialLayer, Flatten, BatchNorm2d, Conv2d, Dropout, Concat, MeanPool2d
import math
__all__ = ['dnet100', 'dnet121']
class BasicBlk(Module):
    """Dense layer without bottleneck: BN+ReLU -> 3x3 conv -> optional dropout.

    The convolution output is concatenated with the block's input, so the
    result carries the input channels plus ``out_planes`` new channels.

    Args:
        in_planes: channel count of the incoming feature map.
        out_planes: number of filters of the 3x3 convolution.
        droprate: probability of dropping an activation; dropout is skipped
            entirely when this is not greater than 0.
    """

    def __init__(self, in_planes, out_planes, droprate = 0.0):
        super(BasicBlk, self).__init__()
        # TensorLayer's BatchNorm layers receive the channel count through
        # ``num_features``; they have no ``in_channels`` keyword.
        self.bn1 = BatchNorm2d(num_features=in_planes, act=tl.ReLU)
        self.conv1 = Conv2d(
            n_filter=out_planes, in_channels=in_planes, filter_size=(3, 3),
            strides=(1, 1), padding='SAME', b_init=None,
        )
        self.droprate = droprate
        # ``Dropout`` takes the *keep* probability, i.e. 1 - drop rate.
        self.drop = Dropout(keep=1.0 - self.droprate)
        self.cat = Concat()

    def forward(self, inputs):
        z = self.bn1(inputs)
        z1 = self.conv1(z)
        if self.droprate > 0:
            z1 = self.drop(z1)
        # Dense connectivity: pass the input through alongside the new features.
        out = self.cat([inputs, z1])
        return out
class BtnBlk(Module):
    """Bottleneck dense layer: BN+ReLU -> 1x1 conv -> BN+ReLU -> 3x3 conv.

    The 1x1 convolution produces ``4 * out_planes`` intermediate channels and
    the 3x3 convolution reduces them to ``out_planes``. Dropout (when enabled)
    is applied after each convolution. The result is concatenated with the
    block's input.

    Args:
        in_planes: channel count of the incoming feature map.
        out_planes: number of filters of the final 3x3 convolution.
        dropRate: probability of dropping an activation; dropout is skipped
            entirely when this is not greater than 0.
    """

    def __init__(self, in_planes, out_planes, dropRate=0.0):
        super(BtnBlk, self).__init__()
        inter_planes = out_planes * 4
        self.bn1 = BatchNorm2d(act=tl.ReLU)
        self.conv1 = Conv2d(
            n_filter=inter_planes, in_channels=in_planes, filter_size=(1, 1),
            strides=(1, 1), padding='VALID', b_init=None,
        )
        # The second normalisation also needs its ReLU (BN-ReLU-Conv), like
        # every other BatchNorm in this file; without it the two convolutions
        # are separated only by an affine transform.
        self.bn2 = BatchNorm2d(act=tl.ReLU)
        self.conv2 = Conv2d(
            n_filter=out_planes, in_channels=inter_planes, filter_size=(3, 3),
            strides=(1, 1), padding='SAME', b_init=None,
        )
        self.droprate = dropRate
        # ``Dropout`` takes the *keep* probability, i.e. 1 - drop rate.
        self.drop = Dropout(keep=1.0 - self.droprate)
        self.cat = Concat()

    def forward(self, x):
        out = self.conv1(self.bn1(x))
        if self.droprate > 0:
            out = self.drop(out)
        out = self.conv2(self.bn2(out))
        if self.droprate > 0:
            out = self.drop(out)
        # Dense connectivity: pass the input through alongside the new features.
        return self.cat([x, out])
class TrBlk(Module):
    """Transition layer: BN+ReLU -> 1x1 conv -> optional dropout -> 2x2 mean pool.

    The 1x1 convolution maps ``in_planes`` channels to ``out_planes`` and the
    stride-2 mean pooling halves each spatial dimension.

    Args:
        in_planes: channel count of the incoming feature map.
        out_planes: number of filters of the 1x1 convolution.
        dropRate: probability of dropping an activation; dropout is skipped
            entirely when this is not greater than 0.
    """

    def __init__(self, in_planes, out_planes, dropRate=0.0):
        super(TrBlk, self).__init__()
        self.bn1 = BatchNorm2d(act=tl.ReLU)
        self.conv1 = Conv2d(
            n_filter=out_planes, in_channels=in_planes, filter_size=(1, 1),
            strides=(1, 1), padding='VALID', b_init=None,
        )
        self.droprate = dropRate
        # ``Dropout`` takes the *keep* probability, i.e. 1 - drop rate.
        self.drop = Dropout(keep=1.0 - self.droprate)
        self.avg = MeanPool2d(filter_size=(2, 2), strides=(2, 2), padding='VALID')

    def forward(self, x):
        out = self.conv1(self.bn1(x))
        if self.droprate > 0:
            out = self.drop(out)
        return self.avg(out)
class DBlk(Module):
    """A dense block: ``nb_layers`` instances of ``block`` run in sequence.

    Args:
        nb_layers: how many ``block`` layers to stack.
        in_planes: channel count fed to the first layer.
        growth_rate: passed to every layer as its output-channel argument;
            layer ``i`` is also told its input has
            ``in_planes + i * growth_rate`` channels.
        block: layer class, called as ``block(in_channels, growth_rate, dropRate)``.
        dropRate: drop rate forwarded unchanged to every layer.
    """

    def __init__(self, nb_layers, in_planes, growth_rate, block, dropRate=0.0):
        super(DBlk, self).__init__()
        self.layer = self._make_layer(block, in_planes, growth_rate, nb_layers, dropRate)

    def _make_layer(self, block, in_planes, growth_rate, nb_layers, dropRate):
        """Instantiate the stacked layers and wrap them in a SequentialLayer."""
        stacked = [
            block(in_planes + idx * growth_rate, growth_rate, dropRate)
            for idx in range(nb_layers)
        ]
        return SequentialLayer(stacked)

    def forward(self, x):
        """Run ``x`` through every stacked layer in order."""
        return self.layer(x)
class DNet(Module):
    """Three-dense-block network with two transition layers and a Dense classifier.

    Pipeline: 3x3 conv -> dense block -> transition -> dense block ->
    transition -> dense block -> BN+ReLU -> 8x8 mean pool -> flatten -> Dense.

    Args:
        depth: nominal network depth; each dense block gets
            ``(depth - 4) // 3`` layers, halved when ``bottleneck`` is set.
        num_classes: number of units of the final Dense layer.
        growth_rate: channels added by every dense layer; the stem
            convolution produces ``2 * growth_rate`` channels.
        reduction: factor applied (with floor) to the channel count in each
            transition layer.
        bottleneck: use ``BtnBlk`` layers when truthy, ``BasicBlk`` otherwise.
        dropRate: drop rate forwarded to every dense and transition layer.
    """

    def __init__(self, depth, num_classes, growth_rate=12,
                 reduction=0.5, bottleneck=True, dropRate=0.0):
        super(DNet, self).__init__()
        in_planes = 2 * growth_rate
        # Layers per dense block. Integer floor division keeps the count an
        # int throughout instead of truncating a float afterwards.
        n = int((depth - 4) // 3)
        if bottleneck:
            # A bottleneck layer holds two convolutions, so half as many fit.
            n //= 2
            block = BtnBlk
        else:
            block = BasicBlk

        # 1st conv before any dense block
        self.conv1 = Conv2d(in_channels=3, n_filter=in_planes, filter_size=(3, 3), strides=(1, 1),
                            padding='SAME', b_init=None)

        # 1st block
        self.block1 = DBlk(n, in_planes, growth_rate, block, dropRate)
        in_planes = int(in_planes + n * growth_rate)
        reduced = int(math.floor(in_planes * reduction))
        self.trans1 = TrBlk(in_planes, reduced, dropRate=dropRate)
        in_planes = reduced

        # 2nd block
        self.block2 = DBlk(n, in_planes, growth_rate, block, dropRate)
        in_planes = int(in_planes + n * growth_rate)
        reduced = int(math.floor(in_planes * reduction))
        self.trans2 = TrBlk(in_planes, reduced, dropRate=dropRate)
        in_planes = reduced

        # 3rd block
        self.block3 = DBlk(n, in_planes, growth_rate, block, dropRate)
        in_planes = int(in_planes + n * growth_rate)

        # global average pooling and classifier
        self.bn1 = BatchNorm2d(act=tl.ReLU)
        self.fc = Dense(n_units=num_classes)
        # Channel count leaving the last dense block.
        self.in_planes = in_planes
        # NOTE(review): the 8x8 window only collapses the map to 1x1 when the
        # last dense block emits 8x8 features (e.g. 32x32 inputs, since each
        # transition halves the spatial size) -- confirm for other input sizes.
        self.avg = MeanPool2d(filter_size=(8, 8), strides=(8, 8), padding='VALID')
        self.flatten = Flatten()

    def forward(self, x):
        """Return the Dense-layer output for input batch ``x``."""
        out = self.conv1(x)
        out = self.trans1(self.block1(out))
        out = self.trans2(self.block2(out))
        out = self.block3(out)
        out = self.bn1(out)
        out = self.avg(out)
        out = self.flatten(out)
        return self.fc(out)
def dnet121():
    """Build a ``DNet`` with depth 121 and 1000 output classes (other settings default)."""
    return DNet(depth=121, num_classes=1000)
def dnet100():
    """Build a ``DNet`` with depth 100 and 10 output classes (other settings default)."""
    return DNet(depth=100, num_classes=10)