# tensorlayer3/tensorlayer/layers/convolution/quan_conv_bn.py
#
# 241 lines
# 9.2 KiB
# Python

#! /usr/bin/python
# -*- coding: utf-8 -*-
import tensorflow as tf
import tensorlayer as tl
from tensorlayer import logging
from tensorlayer.layers.core import Module
from tensorflow.python.training import moving_averages
from tensorlayer.layers.utils import (quantize_active_overflow, quantize_weight_overflow)
from tensorlayer.backend import BACKEND
# Public API of this module.
__all__ = ['QuanConv2dWithBN']
class QuanConv2dWithBN(Module):
    """The :class:`QuanConv2dWithBN` class is a quantized convolutional layer with batch normalization (BN).

    BN is folded into the convolution: the folded weights are quantized to ``bitW`` bits and the input of
    this layer (the output of the previous layer) is quantized to ``bitA`` bits.
    Note that the folded bias vector is not quantized.

    Parameters
    ----------
    n_filter : int
        The number of filters.
    filter_size : tuple of int
        The filter size (height, width).
    strides : tuple of int
        The sliding window strides (stride_height, stride_width). Must have length 2.
    padding : str
        The padding algorithm type: "SAME" or "VALID".
    act : activation function
        The activation function of this layer.
    decay : float
        A decay factor for the moving averages of the batch mean and variance.
        Suggest to use a large value for large dataset.
    epsilon : float
        Small value added to the variance to avoid dividing by zero.
    is_train : boolean
        If True, batch statistics (computed from a full-precision convolution of the input) are used and the
        moving averages are updated. If False, the stored moving averages are used and left untouched.
    beta_init : initializer or None
        The initializer for initializing beta, if None, skip beta (it is treated as zero).
        Usually you should not skip beta unless you know what happened.
    gamma_init : initializer or None
        The initializer for initializing gamma, if None, skip gamma (it is treated as one).
    bitW : int
        The bits of this layer's (folded) weights.
    bitA : int
        The bits of the output of previous layer.
    use_gemm : boolean
        If True, use gemm instead of ``tf.matmul`` for inferencing. Not implemented yet (TODO):
        passing True raises an exception.
    W_init : initializer
        The initializer for the weight matrix.
    W_init_args : dictionary
        The arguments for the weight matrix initializer (stored, but not used by this implementation).
    data_format : str
        "channels_last" (NHWC, default) or "channels_first" (NCHW). "NHWC" and "NCHW" are accepted too.
    dilation_rate : tuple of int
        Specifying the dilation rate to use for dilated convolution.
    in_channels : int
        The number of in channels. If given, the weights are created immediately.
    name : str
        A unique layer name.

    Examples
    ---------
    >>> import tensorlayer as tl
    >>> net = tl.layers.Input([50, 256, 256, 3])
    >>> layer = tl.layers.QuanConv2dWithBN(n_filter=64, filter_size=(5,5),strides=(1,1),padding='SAME',name='qcnnbn1')
    >>> print(layer)
    >>> net = tl.layers.QuanConv2dWithBN(n_filter=64, filter_size=(5,5),strides=(1,1),padding='SAME',name='qcnnbn1')(net)
    >>> print(net)
    """

    def __init__(
        self,
        n_filter=32,
        filter_size=(3, 3),
        strides=(1, 1),
        padding='SAME',
        act=None,
        decay=0.9,
        epsilon=1e-5,
        is_train=False,
        gamma_init=tl.initializers.truncated_normal(stddev=0.02),
        beta_init=tl.initializers.truncated_normal(stddev=0.02),
        bitW=8,
        bitA=8,
        use_gemm=False,
        W_init=tl.initializers.truncated_normal(stddev=0.02),
        W_init_args=None,
        data_format="channels_last",
        dilation_rate=(1, 1),
        in_channels=None,
        name='quan_cnn2d_bn',
    ):
        super(QuanConv2dWithBN, self).__init__(act=act, name=name)
        self.n_filter = n_filter
        self.filter_size = filter_size
        self.strides = strides
        self.padding = padding
        self.decay = decay
        self.epsilon = epsilon
        self.is_train = is_train
        self.gamma_init = gamma_init
        self.beta_init = beta_init
        self.bitW = bitW
        self.bitA = bitA
        self.use_gemm = use_gemm
        self.W_init = W_init
        self.W_init_args = W_init_args
        self.data_format = data_format
        self.dilation_rate = dilation_rate
        self.in_channels = in_channels

        logging.info(
            "QuanConv2dWithBN %s: n_filter: %d filter_size: %s strides: %s pad: %s act: %s " % (
                self.name, n_filter, filter_size, str(strides), padding,
                self.act.__class__.__name__ if self.act is not None else 'No Activation'
            )
        )

        if BACKEND == 'mindspore':
            raise NotImplementedError("MindSpore backend does not implement this method")
        # Validate the arguments before any weight is created.
        if use_gemm:
            raise Exception("TODO. The current version use tf.matmul for inferencing.")
        if len(strides) != 2:
            raise ValueError("len(strides) should be 2.")

        if self.in_channels:
            self.build(None)
            self._built = True

    def __repr__(self):
        actstr = self.act.__class__.__name__ if self.act is not None else 'No Activation'
        s = (
            '{classname}(in_channels={in_channels}, out_channels={n_filter}, kernel_size={filter_size}'
            ', strides={strides}, padding={padding}'
        )
        if self.dilation_rate != (1, ) * len(self.dilation_rate):
            s += ', dilation={dilation_rate}'
        # Separate the activation name from the previous field.
        s += ', ' + actstr
        if self.name is not None:
            s += ', name=\'{name}\''
        s += ')'
        return s.format(classname=self.__class__.__name__, **self.__dict__)

    def build(self, inputs_shape):
        """Create the kernel and the BN parameters/statistics.

        ``inputs_shape`` is only read when ``in_channels`` was not given. ``data_format`` is normalized to
        "NHWC"/"NCHW"; those values are accepted on input as well, so calling ``build`` again does not fail.
        """
        if self.data_format in ('channels_last', 'NHWC'):
            self.data_format = 'NHWC'
            if self.in_channels is None:
                self.in_channels = inputs_shape[-1]
            self._strides = [1, self.strides[0], self.strides[1], 1]
            self._dilation_rate = [1, self.dilation_rate[0], self.dilation_rate[1], 1]
        elif self.data_format in ('channels_first', 'NCHW'):
            self.data_format = 'NCHW'
            if self.in_channels is None:
                self.in_channels = inputs_shape[1]
            self._strides = [1, 1, self.strides[0], self.strides[1]]
            self._dilation_rate = [1, 1, self.dilation_rate[0], self.dilation_rate[1]]
        else:
            raise Exception("data_format should be either channels_last or channels_first")

        # The kernel is always HWIO, independent of data_format.
        self.filter_shape = (self.filter_size[0], self.filter_size[1], self.in_channels, self.n_filter)
        self.W = self._get_weights("filters", shape=self.filter_shape, init=self.W_init)

        para_bn_shape = (self.n_filter, )
        if self.gamma_init is not None:
            self.scale_para = self._get_weights(
                "scale_para", shape=para_bn_shape, init=self.gamma_init, trainable=self.is_train
            )
        else:
            self.scale_para = None

        if self.beta_init is not None:
            self.offset_para = self._get_weights(
                "offset_para", shape=para_bn_shape, init=self.beta_init, trainable=self.is_train
            )
        else:
            self.offset_para = None

        # Standard BN initialization of the running statistics: mean 0, variance 1.
        self.moving_mean = self._get_weights(
            "moving_mean", shape=para_bn_shape, init=tl.initializers.constant(0.0), trainable=False
        )
        self.moving_variance = self._get_weights(
            "moving_variance", shape=para_bn_shape, init=tl.initializers.constant(1.0), trainable=False
        )

    def _moments_axes(self, rank):
        """Return every axis of a rank-``rank`` output except its channel axis (used by ``tf.nn.moments``)."""
        channel_axis = 1 if self.data_format == 'NCHW' else rank - 1
        return [axis for axis in range(rank) if axis != channel_axis]

    def forward(self, inputs):
        """Apply the BN-folded, quantized convolution (and activation) to ``inputs``."""
        if not self._forward_state:
            if not self._built:
                self.build(tl.get_tensor_shape(inputs))
                self._built = True
            self._forward_state = True

        x = inputs  # full-precision input; only used for the batch statistics
        inputs = quantize_active_overflow(inputs, self.bitA)  # Do not remove

        if self.is_train:
            # Batch statistics of the plain (unfolded, full-precision) convolution. The moving averages are
            # updated only here, so running in inference mode never modifies them.
            outputs = tf.nn.conv2d(
                input=x, filters=self.W, strides=self._strides, padding=self.padding, data_format=self.data_format,
                dilations=self._dilation_rate, name=self.name
            )
            mean, variance = tf.nn.moments(outputs, axes=self._moments_axes(len(outputs.get_shape())))
            update_moving_mean = moving_averages.assign_moving_average(
                self.moving_mean, mean, self.decay, zero_debias=False
            )  # if zero_debias=True, has bias
            update_moving_variance = moving_averages.assign_moving_average(
                self.moving_variance, variance, self.decay, zero_debias=False
            )  # if zero_debias=True, has bias
            mean, var = self.mean_var_with_update(update_moving_mean, update_moving_variance, mean, variance)
        else:
            mean, var = self.moving_mean, self.moving_variance

        w_fold = self._w_fold(self.W, self.scale_para, var, self.epsilon)
        W_ = quantize_weight_overflow(w_fold, self.bitW)
        # Same strides/dilations as the statistics convolution so both see the same receptive field.
        conv_fold = tf.nn.conv2d(
            inputs, W_, strides=self._strides, padding=self.padding, data_format=self.data_format,
            dilations=self._dilation_rate
        )

        # The BN shift (-gamma * mean / sqrt(var + eps), plus beta when present) is always folded in.
        bias_fold = self._bias_fold(self.offset_para, self.scale_para, mean, var, self.epsilon)
        conv_fold = tf.nn.bias_add(conv_fold, bias_fold, data_format=self.data_format, name='bn_bias_add')

        if self.act:
            conv_fold = self.act(conv_fold)
        return conv_fold

    def mean_var_with_update(self, update_moving_mean, update_moving_variance, mean, variance):
        """Return ``(mean, variance)`` made dependent on the two moving-average update ops."""
        with tf.control_dependencies([update_moving_mean, update_moving_variance]):
            return tf.identity(mean), tf.identity(variance)

    def _w_fold(self, w, gama, var, epsilon):
        """Fold the BN scale into the kernel: ``gama * w / sqrt(var + epsilon)``; ``gama=None`` means 1."""
        scaled = w if gama is None else tf.multiply(gama, w)
        return tf.math.divide(scaled, tf.sqrt(var + epsilon))

    def _bias_fold(self, beta, gama, mean, var, epsilon):
        """Return the folded bias ``beta - gama * mean / sqrt(var + epsilon)``.

        ``gama=None`` is treated as 1 and ``beta=None`` as 0.
        """
        scaled_mean = mean if gama is None else tf.multiply(gama, mean)
        shift = tf.math.divide(scaled_mean, tf.sqrt(var + epsilon))
        if beta is None:
            return tf.negative(shift)
        return tf.subtract(beta, shift)