#! /usr/bin/python
# -*- coding: utf-8 -*-
import tensorlayer as tl
from tensorlayer import logging
from tensorlayer.layers.core import Module
from tensorlayer.layers.utils import (quantize_active_overflow, quantize_weight_overflow)
__all__ = ['QuanConv2d']
class QuanConv2d(Module):
"""The :class:`QuanConv2d` class is a quantized convolutional layer without BN, which weights are 'bitW' bits and the output of the previous layer
are 'bitA' bits while inferencing.
Note that, the bias vector would not be binarized.
Parameters
----------
n_filter : int
The number of filters.
filter_size : tuple of int
The filter size (height, width).
    strides : tuple of int
        The sliding window strides for the two spatial dimensions (height, width).
act : activation function
The activation function of this layer.
padding : str
The padding algorithm type: "SAME" or "VALID".
    bitW : int
        The number of bits used to quantize this layer's weights.
    bitA : int
        The number of bits used to quantize the output of the previous layer (this layer's input).
    use_gemm : boolean
        If True, use gemm instead of ``tf.matmul`` for inference (TODO: gemm is not supported yet).
data_format : str
"channels_last" (NHWC, default) or "channels_first" (NCHW).
    dilation_rate : tuple of int
        The dilation rate to use for dilated convolution.
    W_init : initializer
        The initializer for the weight matrix.
    b_init : initializer or None
        The initializer for the bias vector. If None, skip biases.
    in_channels : int
        The number of input channels. If None, it is inferred from the input shape
        on the first forward pass.
name : None or str
A unique layer name.
Examples
---------
With TensorLayer
>>> net = tl.layers.Input([8, 12, 12, 64], name='input')
>>> quanconv2d = tl.layers.QuanConv2d(
... n_filter=32, filter_size=(5, 5), strides=(1, 1), act=tl.ReLU, padding='SAME', name='quancnn2d'
... )(net)
>>> print(quanconv2d)
    >>> # output shape : (8, 12, 12, 32)
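
    When ``in_channels`` is given, the layer builds its weights immediately (a minimal
    sketch; the bit widths and shapes here are illustrative):

    >>> layer = tl.layers.QuanConv2d(
    ...     bitW=4, bitA=4, n_filter=16, filter_size=(3, 3), strides=(2, 2),
    ...     act=tl.ReLU, padding='SAME', in_channels=64, name='quancnn2d_prebuilt'
    ... )
    >>> quanconv2d = layer(net)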
"""
def __init__(
self,
bitW=8,
bitA=8,
n_filter=32,
filter_size=(3, 3),
strides=(1, 1),
act=None,
padding='SAME',
use_gemm=False,
data_format="channels_last",
dilation_rate=(1, 1),
W_init=tl.initializers.truncated_normal(stddev=0.02),
b_init=tl.initializers.constant(value=0.0),
in_channels=None,
name=None # 'quan_cnn2d',
):
super().__init__(name, act=act)
self.bitW = bitW
self.bitA = bitA
self.n_filter = n_filter
self.filter_size = filter_size
self.strides = self._strides = strides
self.padding = padding
self.use_gemm = use_gemm
self.data_format = data_format
self.dilation_rate = self._dilation_rate = dilation_rate
self.W_init = W_init
self.b_init = b_init
self.in_channels = in_channels
if self.in_channels:
self.build(None)
self._built = True
logging.info(
"QuanConv2d %s: n_filter: %d filter_size: %s strides: %s pad: %s act: %s" % (
self.name, n_filter, str(filter_size), str(strides), padding,
self.act.__class__.__name__ if self.act is not None else 'No Activation'
)
)
        if self.use_gemm:
            raise Exception("TODO. The current version uses tf.matmul for inference.")
if len(self.strides) != 2:
raise ValueError("len(strides) should be 2.")
    def __repr__(self):
        # Mirror the logging code in __init__: act is an instance, so use its class name.
        actstr = self.act.__class__.__name__ if self.act is not None else 'No Activation'
s = (
'{classname}(in_channels={in_channels}, out_channels={n_filter}, kernel_size={filter_size}'
', strides={strides}, padding={padding}'
)
if self.dilation_rate != (1, ) * len(self.dilation_rate):
s += ', dilation={dilation_rate}'
if self.b_init is None:
s += ', bias=False'
s += (', ' + actstr)
if self.name is not None:
s += ', name=\'{name}\''
s += ')'
return s.format(classname=self.__class__.__name__, **self.__dict__)
def build(self, inputs_shape):
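        # Normalize data_format to the backend spelling (NHWC / NCHW) and expand the
        # 2-tuple strides / dilation_rate into the 4-element form the conv op expects.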
if self.data_format == 'channels_last':
self.data_format = 'NHWC'
if self.in_channels is None:
self.in_channels = inputs_shape[-1]
self._strides = [1, self._strides[0], self._strides[1], 1]
self._dilation_rate = [1, self._dilation_rate[0], self._dilation_rate[1], 1]
elif self.data_format == 'channels_first':
self.data_format = 'NCHW'
if self.in_channels is None:
self.in_channels = inputs_shape[1]
self._strides = [1, 1, self._strides[0], self._strides[1]]
self._dilation_rate = [1, 1, self._dilation_rate[0], self._dilation_rate[1]]
else:
raise Exception("data_format should be either channels_last or channels_first")
self.filter_shape = (self.filter_size[0], self.filter_size[1], self.in_channels, self.n_filter)
self.W = self._get_weights("filters", shape=self.filter_shape, init=self.W_init)
if self.b_init:
self.b = self._get_weights("biases", shape=(self.n_filter, ), init=self.b_init)
self.bias_add = tl.ops.BiasAdd(data_format=self.data_format)
self.conv2d = tl.ops.Conv2D(
strides=self.strides, padding=self.padding, data_format=self.data_format, dilations=self._dilation_rate
)
def forward(self, inputs):
        if not self._forward_state:
            if not self._built:
                self.build(tl.get_tensor_shape(inputs))
                self._built = True
            self._forward_state = True
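
        # Quantize on the fly: activations to bitA bits and weights to bitW bits.
        # The full-precision master weights (self.W) are left untouched; only the
        # quantized copy W_ is used for the convolution below.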
inputs = quantize_active_overflow(inputs, self.bitA)
W_ = quantize_weight_overflow(self.W, self.bitW)
outputs = self.conv2d(inputs, W_)
if self.b_init:
outputs = self.bias_add(outputs, self.b)
if self.act:
outputs = self.act(outputs)
return outputs
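

if __name__ == "__main__":
    # Minimal usage sketch (an illustrative addition mirroring the docstring
    # example); assumes a backend where tl.layers.Input and tl.ReLU are available.
    net = tl.layers.Input([8, 12, 12, 64], name='input')
    quanconv2d = tl.layers.QuanConv2d(
        n_filter=32, filter_size=(5, 5), strides=(1, 1), act=tl.ReLU, padding='SAME', name='quancnn2d'
    )(net)
    print(quanconv2d)  # expected output shape: (8, 12, 12, 32)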