#! /usr/bin/python
# -*- coding: utf-8 -*-

import tensorflow as tf
from tensorflow.python.ops.rnn_cell import LSTMStateTuple

import tensorlayer as tl
from tensorlayer import logging
from tensorlayer.decorators import deprecated, deprecated_alias
from tensorlayer.backend.ops.load_backend import BACKEND

__all__ = [
    'cabs',
    'compute_alpha',
    'flatten_reshape',
    'get_collection_trainable',
    'get_layers_with_name',
    'get_variables_with_name',
    'initialize_global_variables',
    'initialize_rnn_state',
    'list_remove_repeat',
    'merge_networks',
    'print_all_variables',
    'quantize',
    'quantize_active',
    'quantize_weight',
    'quantize_active_overflow',
    'quantize_weight_overflow',
    'set_name_reuse',
    'ternary_operation',
]

########## Module Public Functions ##########


def cabs(x):
    return tf.minimum(1.0, tf.abs(x), name='cabs')


def compute_alpha(x):
    """Compute the scale parameter (alpha) used by the ternary weight quantization."""
    threshold = _compute_threshold(x)
    alpha1_temp1 = tf.where(tf.greater(x, threshold), x, tf.zeros_like(x, tf.float32))
    alpha1_temp2 = tf.where(tf.less(x, -threshold), x, tf.zeros_like(x, tf.float32))
    alpha_array = tf.add(alpha1_temp1, alpha1_temp2, name=None)
    alpha_array_abs = tf.abs(alpha_array)
    alpha_array_abs1 = tf.where(
        tf.greater(alpha_array_abs, 0), tf.ones_like(alpha_array_abs, tf.float32),
        tf.zeros_like(alpha_array_abs, tf.float32)
    )
    alpha_sum = tf.reduce_sum(input_tensor=alpha_array_abs)
    n = tf.reduce_sum(input_tensor=alpha_array_abs1)
    # alpha = tf.compat.v1.div(alpha_sum, n)
    alpha = tf.math.divide(alpha_sum, n)
    return alpha
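
# Illustrative usage of compute_alpha (a sketch, assuming eager execution with the
# TensorFlow backend): alpha is the mean absolute value of the entries whose
# magnitude exceeds the threshold returned by _compute_threshold(x).
#   >>> w = tf.constant([[0.9, -0.8], [0.05, -0.02]])
#   >>> compute_alpha(w)   # roughly (0.9 + 0.8) / 2 = 0.85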


def flatten_reshape(variable, name='flatten'):
    """Reshapes a high-dimensional input into a 2-D tensor.

    [batch_size, mask_row, mask_col, n_mask] ---> [batch_size, mask_row x mask_col x n_mask]

    Parameters
    ----------
    variable : TensorFlow variable or tensor
        The variable or tensor to be flattened.
    name : str
        A unique layer name.

    Returns
    -------
    Tensor
        Flattened tensor.

    """
    dim = 1
    for d in tl.get_tensor_shape(variable)[1:]:  # variable.get_shape()[1:].as_list():
        dim *= d
    return tl.reshape(variable, shape=[-1, dim])
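
# Illustrative usage of flatten_reshape (a sketch, assuming the TensorFlow backend):
#   >>> x = tf.ones([8, 4, 4, 3])
#   >>> flatten_reshape(x)   # expected shape: [8, 48]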


def get_collection_trainable(name=''):
    variables = []
    for p in tf.compat.v1.trainable_variables():
        # print(p.name.rpartition('/')[0], self.name)
        if p.name.rpartition('/')[0] == name:
            variables.append(p)
    return variables


@deprecated_alias(printable='verbose', end_support_version=1.9)  # TODO remove this line for the 1.9 release
def get_layers_with_name(net, name="", verbose=False):
    """Get a list of layers' output in a network by a given name scope.

    Parameters
    -----------
    net : :class:`Layer`
        The last layer of the network.
    name : str
        Get the layers' output that contain this name.
    verbose : boolean
        If True, print information of all the layers' output.

    Returns
    --------
    list of Tensor
        A list of layers' output (TensorFlow tensor)

    Examples
    ---------
    >>> import tensorlayer as tl
    >>> layers = tl.layers.get_layers_with_name(net, "CNN", True)

    """
    logging.info(" [*] getting layers with %s" % name)

    layers = []
    i = 0

    for layer in net.all_layers:
        # logging.info(type(layer.name))
        if name in layer.name:
            layers.append(layer)

            if verbose:
                logging.info(" got {:3}: {:15} {}".format(i, layer.name, str(layer.get_shape())))
                i = i + 1

    return layers


def get_variable_with_initializer(scope_name, var_name, shape, init=tl.initializers.random_normal(), trainable=True):
    # FIXME: documentation needed
    var_name = scope_name + "/" + var_name
    # FIXME: not sure whether this is correct?
    # TODO mindspore weights shape : [out_channel, in_channel, kernel_h, kernel_w]
    if BACKEND == 'mindspore':
        if len(shape) == 2:
            pass
        else:
            shape = shape[::-1]

    initial_value = init(shape=shape)

    if BACKEND == 'dragon':
        return initial_value

    var = tl.Variable(initial_value=initial_value, name=var_name, trainable=trainable)
    return var
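
# Illustrative usage of get_variable_with_initializer (a sketch; the scope and
# variable names below are hypothetical, and the expected shape convention depends
# on BACKEND, see the FIXME/TODO notes above):
#   >>> w = get_variable_with_initializer("conv1", "filters", shape=(3, 3, 16, 32))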


@deprecated_alias(printable='verbose', end_support_version=1.9)  # TODO remove this line for the 1.9 release
def get_variables_with_name(name=None, train_only=True, verbose=False):
    """Get a list of TensorFlow variables by a given name scope.

    Parameters
    ----------
    name : str
        Get the variables that contain this name.
    train_only : boolean
        If True, only get the trainable variables.
    verbose : boolean
        If True, print the information of all variables.

    Returns
    -------
    list of Tensor
        A list of TensorFlow variables

    Examples
    --------
    >>> import tensorlayer as tl
    >>> dense_vars = tl.layers.get_variables_with_name('dense', True, True)

    """
    if name is None:
        raise Exception("please input a name")

    logging.info(" [*] getting variables with %s" % name)

    # tvar = tf.trainable_variables() if train_only else tf.all_variables()
    if train_only:
        t_vars = tf.compat.v1.trainable_variables()
    else:
        t_vars = tf.compat.v1.global_variables()

    d_vars = [var for var in t_vars if name in var.name]

    if verbose:
        for idx, v in enumerate(d_vars):
            logging.info(" got {:3}: {:15} {}".format(idx, v.name, str(v.get_shape())))

    return d_vars


@deprecated(
    date="2018-09-30", instructions="This API is deprecated in favor of `sess.run(tf.global_variables_initializer())`"
)
def initialize_global_variables(sess):
    """Initialize the global variables of TensorFlow.

    Run ``sess.run(tf.global_variables_initializer())`` for TF 0.12+ or
    ``sess.run(tf.initialize_all_variables())`` for TF 0.11.

    Parameters
    ----------
    sess : Session
        TensorFlow session.

    """
    if sess is None:
        raise AssertionError('The session must be defined')

    sess.run(tf.compat.v1.global_variables_initializer())


def initialize_rnn_state(state, feed_dict=None):
    """Returns the initialized RNN state.
    The inputs are `LSTMStateTuple` or `State` of `RNNCells`, and an optional `feed_dict`.

    Parameters
    ----------
    state : RNN state.
        The TensorFlow RNN state.
    feed_dict : dictionary
        Initial RNN state; if None, returns the zero state.

    Returns
    -------
    RNN state
        The TensorFlow RNN state.

    """
    if isinstance(state, LSTMStateTuple):
        c = state.c.eval(feed_dict=feed_dict)
        h = state.h.eval(feed_dict=feed_dict)
        return c, h
    else:
        new_state = state.eval(feed_dict=feed_dict)
        return new_state


def list_remove_repeat(x):
    """Remove the repeated items in a list, and return the processed list.
    You may need it when creating merged layers such as Concat, Elementwise, etc.

    Parameters
    ----------
    x : list
        Input

    Returns
    -------
    list
        The list after removing its repeated items

    Examples
    --------
    >>> l = [2, 3, 4, 2, 3]
    >>> l = list_remove_repeat(l)
    [2, 3, 4]

    """
    y = []
    for i in x:
        if i not in y:
            y.append(i)

    return y


def merge_networks(layers=None):
    """Merge all parameters, layers and dropout probabilities to a :class:`Layer`.
    The output of the returned network is the first network in the list.

    Parameters
    ----------
    layers : list of :class:`Layer`
        Merge all parameters, layers and dropout probabilities to the first layer in the list.

    Returns
    --------
    :class:`Layer`
        The network after merging all parameters, layers and dropout probabilities to the first network in the list.

    Examples
    ---------
    >>> import tensorlayer as tl
    >>> n1 = ...
    >>> n2 = ...
    >>> n1 = tl.layers.merge_networks([n1, n2])

    """
    if layers is None:
        raise Exception("layers should be a list of TensorLayer's Layers.")
    layer = layers[0]

    all_params = []
    all_layers = []
    all_drop = {}

    for l in layers:
        all_params.extend(l.all_params)
        all_layers.extend(l.all_layers)
        all_drop.update(l.all_drop)

    layer.all_params = list(all_params)
    layer.all_layers = list(all_layers)
    layer.all_drop = dict(all_drop)

    layer.all_layers = list_remove_repeat(layer.all_layers)
    layer.all_params = list_remove_repeat(layer.all_params)

    return layer


def print_all_variables(train_only=False):
    """Print information of trainable or all variables,
    without ``tl.layers.initialize_global_variables(sess)``.

    Parameters
    ----------
    train_only : boolean
        Whether print trainable variables only.
            - If True, print the trainable variables.
            - If False, print all variables.

    """
    # tvar = tf.trainable_variables() if train_only else tf.all_variables()
    if train_only:
        t_vars = tf.compat.v1.trainable_variables()
        logging.info(" [*] printing trainable variables")
    else:
        t_vars = tf.compat.v1.global_variables()
        logging.info(" [*] printing global variables")

    for idx, v in enumerate(t_vars):
        logging.info(" var {:3}: {:15} {}".format(idx, str(v.get_shape()), v.name))


def quantize(x):
    # ref: https://github.com/AngusG/tensorflow-xnor-bnn/blob/master/models/binary_net.py#L70
    # https://github.com/itayhubara/BinaryNet.tf/blob/master/nnUtils.py
    with tf.compat.v1.get_default_graph().gradient_override_map({"Sign": "TL_Sign_QuantizeGrad"}):
        return tf.sign(x)
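
# Illustrative usage of quantize (a sketch): the forward pass is a plain sign(),
# while the registered "TL_Sign_QuantizeGrad" gradient clips the incoming gradient
# to [-1, 1] (straight-through estimator); the override only takes effect in graph mode.
#   >>> quantize(tf.constant([0.3, -0.7]))   # values: [1., -1.]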


def quantize_active(x, bitA):
    if bitA == 32:
        return x
    return _quantize_dorefa(x, bitA)


def quantize_weight(x, bitW, force_quantization=False):
    G = tf.compat.v1.get_default_graph()
    if bitW == 32 and not force_quantization:
        return x
    if bitW == 1:  # BWN
        with G.gradient_override_map({"Sign": "Identity"}):
            E = tf.stop_gradient(tf.reduce_mean(input_tensor=tf.abs(x)))
            return tf.sign(x / E) * E
    x = tf.clip_by_value(x * 0.5 + 0.5, 0.0, 1.0)  # it seems as though most weights are within -1 to 1 region anyways
    return 2 * _quantize_dorefa(x, bitW) - 1
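
# Illustrative numbers for quantize_weight (a sketch, assuming eager execution):
# with bitW = 2, weights are squashed to [0, 1] via x * 0.5 + 0.5, rounded onto
# 2**2 = 4 levels by _quantize_dorefa, then mapped back to [-1, 1].
#   >>> w = tf.constant([-1.0, -0.3, 0.2, 0.9])
#   >>> quantize_weight(w, bitW=2)   # values drawn from {-1, -1/3, 1/3, 1}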


def quantize_active_overflow(x, bitA):
    if bitA == 32:
        return x
    return _quantize_overflow(x, bitA)


def quantize_weight_overflow(x, bitW):
    if bitW == 32:
        return x
    return _quantize_overflow(x, bitW)


@deprecated(date="2018-06-30", instructions="TensorLayer relies on TensorFlow to check name reusing")
def set_name_reuse(enable=True):
    logging.warning('this method is DEPRECATED and has no effect, please remove it from your code.')


def ternary_operation(x):
    """Ternary operation using a threshold computed from the weights."""
    g = tf.compat.v1.get_default_graph()
    with g.gradient_override_map({"Sign": "Identity"}):
        threshold = _compute_threshold(x)
        x = tf.sign(tf.add(tf.sign(tf.add(x, threshold)), tf.sign(tf.add(x, -threshold))))
        return x
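
# Illustrative numbers for ternary_operation (a sketch, assuming eager execution):
# entries above +threshold map to +1, entries below -threshold map to -1, the rest to 0.
#   >>> w = tf.constant([0.9, -0.8, 0.05, -0.02])
#   >>> ternary_operation(w)   # approximately [1., -1., 0., 0.]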


########## Module Private Functions ##########


@tf.RegisterGradient("TL_Sign_QuantizeGrad")
def _quantize_grad(op, grad):
    """Clip and binarize tensor using the straight through estimator (STE) for the gradient."""
    return tf.clip_by_value(grad, -1, 1)


def _quantize_dorefa(x, k):
    G = tf.compat.v1.get_default_graph()
    n = float(2**k - 1)
    with G.gradient_override_map({"Round": "Identity"}):
        return tf.round(x * n) / n
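
# _quantize_dorefa maps x in [0, 1] onto 2**k evenly spaced levels, e.g. (a sketch):
#   >>> _quantize_dorefa(tf.constant([0.0, 0.4, 1.0]), k=2)   # -> [0., 1/3, 1.]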


def _quantize_overflow(x, k):
    G = tf.compat.v1.get_default_graph()
    n = float(2**k - 1)
    max_value = tf.reduce_max(input_tensor=x)
    min_value = tf.reduce_min(input_tensor=x)
    with G.gradient_override_map({"Round": "Identity"}):
        step = tf.stop_gradient((max_value - min_value) / n)
        return tf.round((tf.maximum(tf.minimum(x, max_value), min_value) - min_value) / step) * step + min_value


def _compute_threshold(x):
    """
    ref: https://github.com/XJTUWYD/TWN
    Computing the threshold.
    """
    x_sum = tf.reduce_sum(input_tensor=tf.abs(x), axis=None, keepdims=False, name=None)
    # threshold = tf.compat.v1.div(x_sum, tf.cast(tf.size(input=x), tf.float32), name=None)
    threshold = tf.math.divide(x_sum, tf.cast(tf.size(input=x), tf.float32), name=None)
    threshold = tf.multiply(0.7, threshold, name=None)
    return threshold
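
# In short, _compute_threshold returns 0.7 * mean(|x|), the ternary threshold (delta)
# from the Ternary Weight Networks reference above.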


def mean_var_with_update(update_moving_mean, update_moving_variance, mean, variance):
    with tf.control_dependencies([update_moving_mean, update_moving_variance]):
        return tf.identity(mean), tf.identity(variance)


def w_fold(w, gama, var, epsilon):
    return tf.compat.v1.div(tf.multiply(gama, w), tf.sqrt(var + epsilon))


def bias_fold(beta, gama, mean, var, epsilon):
    return tf.subtract(beta, tf.compat.v1.div(tf.multiply(gama, mean), tf.sqrt(var + epsilon)))
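
# Note: w_fold and bias_fold implement the usual batch-norm folding used for
# quantization-aware training: w' = gamma * w / sqrt(var + eps) and
# b' = beta - gamma * mean / sqrt(var + eps).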