From 6c053306edb32f682a7de73a41b8b55de51e39d6 Mon Sep 17 00:00:00 2001 From: Eric_lai Date: Mon, 9 Aug 2021 09:34:23 +0800 Subject: [PATCH] Update TensorLayer3 --- .../tutorial_imdb_LSTM_simple.py | 74 + tensorlayer/backend/ops/__init__.py | 4 + tensorlayer/backend/ops/mindspore_backend.py | 21 +- tensorlayer/backend/ops/mindspore_nn.py | 149 ++ tensorlayer/backend/ops/paddle_backend.py | 131 +- tensorlayer/backend/ops/paddle_nn.py | 389 ++++ tensorlayer/backend/ops/tensorflow_backend.py | 9 +- tensorlayer/backend/ops/tensorflow_nn.py | 340 ++++ tensorlayer/files/utils.py | 7 + tensorlayer/layers/padding.py | 4 +- tensorlayer/layers/recurrent.py | 1727 ++++++----------- 11 files changed, 1653 insertions(+), 1202 deletions(-) create mode 100644 examples/basic_tutorials/tutorial_imdb_LSTM_simple.py diff --git a/examples/basic_tutorials/tutorial_imdb_LSTM_simple.py b/examples/basic_tutorials/tutorial_imdb_LSTM_simple.py new file mode 100644 index 0000000..d8149f6 --- /dev/null +++ b/examples/basic_tutorials/tutorial_imdb_LSTM_simple.py @@ -0,0 +1,74 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- + +# The same set of code can switch the backend with one line +import os +os.environ['TL_BACKEND'] = 'tensorflow' +# os.environ['TL_BACKEND'] = 'mindspore' +# os.environ['TL_BACKEND'] = 'paddle' + +import tensorlayer as tl +from tensorlayer.layers import Module +from tensorlayer.layers import Dense, LSTM, Embedding +from tensorlayer.dataflow import Dataset +import numpy as np + +X_train, y_train, X_test, y_test = tl.files.load_imdb_dataset('data', nb_words=20000, test_split=0.2) + +Seq_Len = 200 +vocab_size = len(X_train) + 1 + + +class imdbdataset(Dataset): + + def __init__(self, X, y): + self.X = X + self.y = y + + def __getitem__(self, index): + + data = self.X[index] + data = np.concatenate([data[:Seq_Len], [0] * (Seq_Len - len(data))]).astype('int64') # set + label = self.y[index].astype('int64') + return data, label + + def __len__(self): + + return len(self.y) + + +class ImdbNet(Module): + + def __init__(self): + super(ImdbNet, self).__init__() + self.embedding = Embedding(vocabulary_size=vocab_size, embedding_size=64) + self.lstm = LSTM(input_size=64, hidden_size=64) + self.dense1 = Dense(in_channels=64, n_units=64, act=tl.ReLU) + self.dense2 = Dense(in_channels=64, n_units=2) + + def forward(self, x): + x = self.embedding(x) + x, _ = self.lstm(x) + x = tl.ops.reduce_mean(x, axis=1) + x = self.dense1(x) + x = self.dense2(x) + return x + + +n_epoch = 5 +batch_size = 64 +print_freq = 2 + +train_dataset = imdbdataset(X=X_train, y=y_train) +train_dataset = tl.dataflow.FromGenerator( + train_dataset, output_types=[tl.int64, tl.int64], column_names=['data', 'label'] +) +train_loader = tl.dataflow.Dataloader(train_dataset, batch_size=batch_size, shuffle=True) + +net = ImdbNet() +train_weights = net.trainable_weights +optimizer = tl.optimizers.Adam(1e-3) +metric = tl.metric.Accuracy() +loss_fn = tl.cost.softmax_cross_entropy_with_logits +model = tl.models.Model(network=net, loss_fn=loss_fn, optimizer=optimizer, metrics=metric) +model.train(n_epoch=n_epoch, train_dataset=train_loader, print_freq=print_freq, print_train_batch=True) diff --git a/tensorlayer/backend/ops/__init__.py b/tensorlayer/backend/ops/__init__.py index ad780e2..04368bd 100644 --- a/tensorlayer/backend/ops/__init__.py +++ b/tensorlayer/backend/ops/__init__.py @@ -141,3 +141,7 @@ from .load_backend import Maximum from .load_backend import Meshgrid from .load_backend import BatchToSpace from .load_backend import DepthToSpace +from 
.load_backend import rnncell +from .load_backend import lstmcell +from .load_backend import grucell +from .load_backend import rnnbase \ No newline at end of file diff --git a/tensorlayer/backend/ops/mindspore_backend.py b/tensorlayer/backend/ops/mindspore_backend.py index 5e9c9f1..d0bb7c4 100644 --- a/tensorlayer/backend/ops/mindspore_backend.py +++ b/tensorlayer/backend/ops/mindspore_backend.py @@ -720,15 +720,26 @@ def reduce_min(input_tensor, axis=None): class Pad(Cell): - def __init__(self, paddings, mode="REFLECT"): + def __init__(self, paddings, mode="REFLECT", constant_values=0): super(Pad, self).__init__() - if mode not in ["REFLECT", "SYMMETRIC"]: + if mode not in ['CONSTANT', 'REFLECT', 'SYMMETRIC']: raise Exception("Unsupported mode: {}".format(mode)) - self.pad = P.MirrorPad(mode=mode) - self.paddings = Tensor(paddings) + if mode == 'CONSTANT': + self.pad = P.Pad(paddings) + if constant_values-0 == 0: + pass + else: + raise NotImplementedError("constant_values can only be equal to 0.") + else: + self.pad = P.MirrorPad(mode=mode) + self.paddings = Tensor(np.array(self.paddings)) + self.mode = mode def construct(self, x): - return self.pad(x, self.paddings) + if self.mode == 'CONSTANT': + return self.pad(x) + else: + return self.pad(x, self.paddings) def pad(tensor, paddings, mode='CONSTANT', constant_values=0): diff --git a/tensorlayer/backend/ops/mindspore_nn.py b/tensorlayer/backend/ops/mindspore_nn.py index 2d75604..df0dfc2 100644 --- a/tensorlayer/backend/ops/mindspore_nn.py +++ b/tensorlayer/backend/ops/mindspore_nn.py @@ -1800,3 +1800,152 @@ class DorefaConv2D(Cell): outputs = nchw_to_nhwc(outputs) return outputs + + +class rnncell(Cell): + + def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, act): + super(rnncell, self).__init__() + self.weight_ih = weight_ih + self.weight_hh = weight_hh + self.bias_ih = bias_ih + self.bias_hh = bias_hh + self.act_fn = P.ReLU() if act == 'relu' else P.Tanh() + self.transpose = P.Transpose() + + def construct(self, input, h): + self.weight_ih = self.transpose(self.weight_ih, (1, 0)) + i2h = P.matmul(input, self.weight_ih) + if self.bias_ih is not None: + i2h += self.bias_ih + self.weight_hh = self.transpose(self.weight_hh, (1, 0)) + h2h = P.matmul(h, self.weight_hh) + if self.bias_hh is not None: + h2h += self.bias_hh + h = self.act_fn(i2h + h2h) + return h, h + + +class lstmcell(Cell): + + def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh): + super(lstmcell, self).__init__() + self.weight_ih = weight_ih + self.weight_hh = weight_hh + self.bias_ih = bias_ih + self.bias_hh = bias_hh + self.gate_act_fn = P.Sigmoid() + self.act_fn = P.Tanh() + self.transpose = P.Transpose() + self.split = P.Split(axis=-1, output_num=4) + + def construct(self, input, h, c): + + self.weight_ih = self.transpose(self.weight_ih, (1, 0)) + gates = P.matmul(input, self.weight_ih) + if self.bias_ih is not None: + gates += self.bias_ih + self.weight_hh = self.transpose(self.weight_hh, (1, 0)) + gates += P.matmul(h, self.weight_hh) + if self.bias_hh is not None: + gates += self.bias_hh + + gate_slices = self.split(gates) + i = self.gate_act_fn(gate_slices[0]) + f = self.gate_act_fn(gate_slices[1]) + o = self.gate_act_fn(gate_slices[3]) + c = f * c + i * self.act_fn(gate_slices[2]) + h = o * self.act_fn(c) + + return h, h, c + + +class grucell(Cell): + + def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh): + super(grucell, self).__init__() + self.weight_ih = weight_ih + self.weight_hh = weight_hh + self.bias_ih = bias_ih + self.bias_hh = bias_hh 
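+        # standard GRU gating (assumed intent of this cell): sigmoid for the reset/update gates, tanh for the candidate state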
+ self.gate_act_fn = P.Sigmoid() + self.act_fn = P.Tanh() + self.transpose = P.Transpose() + self.split = P.Split(axis=-1, output_num=3) + + def construct(self, input, h): + + self.weight_ih = self.transpose(self.weight_ih, (1, 0)) + x_gates = P.matmul(input, self.weight_ih) + if self.bias_ih is not None: + x_gates += self.bias_ih + self.weight_hh = self.transpose(self.weight_hh, (1, 0)) + h_gates = P.matmul(h, self.weight_hh) + if self.bias_hh is not None: + h_gates += self.bias_hh + + x_r, x_z, x_c = self.split(x_gates) + h_r, h_z, h_c = self.split(h_gates) + + r = self.gate_act_fn(x_r + h_r) + z = self.gate_act_fn(x_r + h_z) + c = self.act_fn(x_c + r * h_c) + h = (h - c) * z + c + + return h, h + + +class rnnbase(Cell): + + def __init__( + self, + mode, + input_size, + hidden_size, + num_layers, + bias, + batch_first, + dropout, + bidirectional, + is_train, + ): + super(rnnbase, self).__init__() + self.mode = mode + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.bidirect = 2 if bidirectional else 1 + self.batch_first = batch_first + if mode == 'LSTM': + self.lstm = ms.nn.LSTM( + input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, has_bias=bias, + batch_first=batch_first, dropout=dropout, bidirectional=bidirectional + ) + elif mode == 'GRU': + + raise NotImplementedError + + elif mode == 'RNN_TANH': + + raise NotImplementedError + + elif mode == 'RNN_RELU': + + raise NotImplementedError + + self.zeros = P.Zeros() + + def construct(self, input, states): + input_shape = input.shape + input_dtype = input.dtype + if self.mode == 'LSTM': + if self.batch_first: + batch_size = input_shape[0] + else: + batch_size = input_shape[1] + if states is None: + h = self.zeros((self.bidirect * self.num_layers, batch_size, self.hidden_size), input_dtype) + c = self.zeros((self.bidirect * self.num_layers, batch_size, self.hidden_size), input_dtype) + states = (h, c) + output, (h, c) = self.lstm(input, states) + return output, (h, c) diff --git a/tensorlayer/backend/ops/paddle_backend.py b/tensorlayer/backend/ops/paddle_backend.py index 3573bd2..918cadd 100644 --- a/tensorlayer/backend/ops/paddle_backend.py +++ b/tensorlayer/backend/ops/paddle_backend.py @@ -4,6 +4,7 @@ from __future__ import absolute_import, division, print_function import paddle as pd import paddle.nn as nn +import numpy as np _dtypeDict = ["float16", "float32", "float64", "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"] # TODO NotImplemented @@ -325,7 +326,7 @@ class Reshape(object): self.shape = shape def __call__(self, tensor): - raise NotImplementedError + return pd.reshape(tensor, shape=self.shape) def reshape(tensor, shape): @@ -352,7 +353,7 @@ class Concat(object): self.axis = axis def __call__(self, values): - raise NotImplementedError + return pd.concat(values, axis=self.axis) def concat(values, axis): @@ -369,7 +370,7 @@ def concat(values, axis): ------- A Tensor resulting from concatenation of the input tensors. """ - raise NotImplementedError + return pd.concat(values, axis) def convert_to_tensor(value, dtype=float32): @@ -407,16 +408,16 @@ def sqrt(x): ------- A Tensor. Has the same type as x. """ - raise NotImplementedError + return pd.sqrt(x) class ReduceSum(object): def __init__(self, axis): - pass + self.axis = axis def construct(self, input): - pass + return pd.sum(input, axis=self.axis) class ReduceMean(object): @@ -447,7 +448,7 @@ def reduce_mean(input_tensor, axis=None): The reduced tensor. 
""" - raise NotImplementedError + return pd.mean(input_tensor, axis) class ReduceMax(object): @@ -478,7 +479,7 @@ def reduce_max(input_tensor, axis=None): The reduced tensor. """ - raise NotImplementedError + return pd.max(input_tensor, axis) def reduce_min(input_tensor, axis=None): @@ -499,21 +500,47 @@ def reduce_min(input_tensor, axis=None): ------- The reduced tensor. """ - raise NotImplementedError + return pd.min(input_tensor, axis) class Pad(object): - def __init__(self, paddings, mode="REFLECT"): + def __init__(self, paddings, mode="REFLECT", constant_values=0): if mode not in ['CONSTANT', 'REFLECT', 'SYMMETRIC']: raise Exception("Unsupported mode: {}".format(mode)) if mode == 'SYMMETRIC': - mode = 'EDGE' + raise NotImplementedError self.paddings = paddings - self.mode = mode + self.mode = mode.lower() + self.constant_values = constant_values def __call__(self, x): - raise NotImplementedError + if len(x.shape) == 3: + data_format = 'NLC' + self.paddings = self.correct_paddings(len(x.shape), self.paddings, data_format) + elif len(x.shape) == 4: + data_format = 'NHWC' + self.paddings = self.correct_paddings(len(x.shape), self.paddings, data_format) + elif len(x.shape) == 5: + data_format = 'NDHWC' + self.paddings = self.correct_paddings(len(x.shape), self.paddings, data_format) + else: + raise NotImplementedError('Please check the input shape.') + return pd.nn.functional.pad(x, self.paddings, self.mode, value=self.constant_values, data_format=data_format) + + def correct_paddings(self, in_shape, paddings, data_format): + if in_shape == 3 and data_format == 'NLC': + correct_output = [paddings[1][0], paddings[1][1]] + elif in_shape == 4 and data_format == 'NHWC': + correct_output = [paddings[2][0], paddings[2][1], + paddings[1][0], paddings[1][1]] + elif in_shape == 5 and data_format == 'NDHWC': + correct_output = [paddings[3][0], paddings[3][1], + paddings[2][0], paddings[2][1], + paddings[1][0], paddings[1][1]] + else: + raise NotImplementedError('Does not support channels first') + return correct_output def pad(tensor, paddings, mode='CONSTANT', constant_values=0): @@ -535,7 +562,7 @@ def pad(tensor, paddings, mode='CONSTANT', constant_values=0): ------- A Tensor. Has the same type as tensor. """ - raise NotImplementedError + return Pad(paddings, mode, constant_values)(tensor) class Unstack(object): @@ -545,7 +572,7 @@ class Unstack(object): self.num = num def __call__(self, values): - raise NotImplementedError + return pd.unstack(values, self.axis, self.num) class Stack(object): @@ -554,7 +581,7 @@ class Stack(object): self.axis = axis def __call__(self, values): - raise NotImplementedError + return pd.stack(values, self.axis) def stack(values, axis=0): @@ -563,7 +590,7 @@ def stack(values, axis=0): Parameters ---------- - values : list + values : list or tuple A list of Tensor objects with the same shape and type. axis : int An int. The axis to stack along. Defaults to the first dimension. @@ -573,7 +600,7 @@ def stack(values, axis=0): ------- A stacked Tensor with the same type as values. """ - raise NotImplementedError + return pd.stack(values, axis=axis) class Meshgrid(object): @@ -583,10 +610,10 @@ class Meshgrid(object): self.index = indexing def __call__(self, inputs): - pass + return pd.meshgrid(inputs) -def meshgrid(x, y): +def meshgrid(*args, **kwargs): """ Broadcasts parameters for evaluation on an N-D grid. @@ -602,7 +629,7 @@ def meshgrid(x, y): A list of N Tensors with rank N. 
""" - pass + return pd.meshgrid(*args, **kwargs) def range(start, limit=None, delta=1, dtype=None): @@ -626,16 +653,19 @@ def range(start, limit=None, delta=1, dtype=None): ------- An 1-D Tensor of type dtype. """ - raise NotImplementedError + return pd.arange(start, step=delta) class ExpandDims(object): def __init__(self, axis): - pass + self.axis = axis def construct(self, input): - pass + input = convert_to_numpy(input) + output = np.expand_dims(input, axis=self.axis) + output = convert_to_tensor(output) + return output def expand_dims(input, axis): @@ -655,7 +685,10 @@ def expand_dims(input, axis): A Tensor with the same data as input, but its shape has an additional dimension of size 1 added. """ - raise NotImplementedError + input = convert_to_numpy(input) + output = np.expand_dims(input, axis=axis) + output = convert_to_tensor(output) + return output class Tile(object): @@ -664,7 +697,7 @@ class Tile(object): pass def __call__(self, input, multiples): - raise NotImplementedError + return pd.tile(input, multiples) def tile(input, multiples): @@ -683,16 +716,16 @@ def tile(input, multiples): ------- A Tensor. Has the same type as input. """ - raise NotImplementedError + return pd.tile(input, multiples) class Cast(object): def __init__(self, dtype): - pass + self.dtype = dtype def __call__(self, input): - pass + return pd.cast(input, self.dtype) def cast(x, dtype): @@ -711,7 +744,7 @@ def cast(x, dtype): ------- A Tensor or SparseTensor or IndexedSlices with same shape as x and same type as dtype. """ - raise NotImplementedError + return pd.cast(x, dtype) class Transpose(object): @@ -722,7 +755,7 @@ class Transpose(object): raise ("The conjugate Parameters not supported") def __call__(self, a): - raise NotImplementedError + return pd.transpose(a, self.perm) def transpose(a, perm=None, conjugate=False): @@ -743,7 +776,7 @@ def transpose(a, perm=None, conjugate=False): A transposed Tensor. """ - raise NotImplementedError + return pd.transpose(a, perm) def gather_nd(params, indices, batch_dims=0): @@ -764,7 +797,7 @@ def gather_nd(params, indices, batch_dims=0): A Tensor. Has the same type as params. """ - pass + return pd.gather_nd(params, indices) def clip_by_value(t, clip_value_min, clip_value_max): @@ -785,7 +818,7 @@ def clip_by_value(t, clip_value_min, clip_value_max): A clipped Tensor or IndexedSlices. """ - pass + return pd.clip(t, clip_value_min, clip_value_max) def split(value, num_or_size_splits, axis=0, num=None): @@ -796,7 +829,7 @@ def split(value, num_or_size_splits, axis=0, num=None): ---------- value : tensor The Tensor to split. - num_or_size_splits : list + num_or_size_splits : list or tuple Either an integer indicating the number of splits along split_dim or a 1-D integer Tensor or Python list containing the sizes of each output tensor along split_dim. axis : int @@ -808,33 +841,33 @@ def split(value, num_or_size_splits, axis=0, num=None): ------- Tensor objects resulting from splitting value. 
""" - pass + pd.split(value, num_or_size_splits, axis) class Floor(object): - def __call__(self, *args, **kwargs): - raise NotImplementedError + def __call__(self, x): + return pd.floor(x) def floor(x): - raise NotImplementedError + return pd.floor(x) def gather(params, indices): - raise NotImplementedError + return pd.gather(params, indices) def linspace(start, stop, num): - raise NotImplementedError + return pd.linspace(start, stop, num) def slice(inputs, starts, sizes): - raise NotImplementedError + return pd.slice(inputs, starts=starts, ends=sizes) def add_n(inputs): - raise NotImplementedError + return pd.add_n(inputs) class OneHot(object): @@ -844,17 +877,19 @@ class OneHot(object): self.dtype = dtype def __call__(self, indices): - raise NotImplementedError + output = pd.nn.functional.one_hot(indices, self.depth) + return output class L2Normalize(object): def __init__(self, axis=None, epsilon=1e-12): super(L2Normalize, self).__init__() - pass + self.axis = axis + self.epsilon = epsilon - def __call__(self, input, *args, **kwargs): - pass + def __call__(self, input): + return pd.nn.functional.normalize(x=input, p=2, axis=self.axis, epsilon=self.epsilon) class EmbeddingLookup(object): @@ -862,7 +897,7 @@ class EmbeddingLookup(object): def __init__(self, max_norm=None): self.max_norm = max_norm - def __call__(self, params, ids, *args, **kwargs): + def __call__(self, params, ids): pass diff --git a/tensorlayer/backend/ops/paddle_nn.py b/tensorlayer/backend/ops/paddle_nn.py index a2fe790..a66c08f 100644 --- a/tensorlayer/backend/ops/paddle_nn.py +++ b/tensorlayer/backend/ops/paddle_nn.py @@ -3,6 +3,12 @@ import paddle as pd import paddle.nn.functional as F +import numpy as np +import paddle.fluid as fluid +from paddle.nn import initializer as I +from paddle.fluid.layers.utils import map_structure, flatten, pack_sequence_as +from paddle.fluid.data_feeder import convert_dtype +from paddle.fluid.dygraph import Layer def padding_format(padding): @@ -1308,3 +1314,386 @@ class DorefaConv2D(object): def __call__(self, inputs, filters): raise NotImplementedError + + +class rnncell(object): + + def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, bias, act): + self.weight_ih = weight_ih + self.weight_hh = weight_hh + self.bias_ih = bias_ih + self.bias_hh = bias_hh + self.bias = bias + self.act_fn = F.relu if act == 'relu' else F.tanh + + def __call__(self, input, h): + + i2h = pd.matmul(input, self.weight_ih, transpose_y=True) + if self.bias_ih is not None: + i2h += self.bias_ih + h2h = pd.matmul(h, self.weight_hh, transpose_y=True) + if self.bias_hh is not None: + h2h += self.bias_hh + h = self.act_fn(i2h + h2h) + return h, h + + +class lstmcell(object): + + def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, bias): + self.weight_ih = weight_ih + self.weight_hh = weight_hh + self.bias_ih = bias_ih + self.bias_hh = bias_hh + self.bias = bias + self.gate_act_fn = F.sigmoid + self.act_fn = F.tanh + + def __call__(self, inputs, h, c): + + gates = pd.matmul(inputs, self.weight_ih, transpose_y=True) + if self.bias_ih is not None: + gates += self.bias_ih + gates += pd.matmul(h, self.weight_hh, transpose_y=True) + if self.bias_hh is not None: + gates += self.bias_hh + + gates_slices = pd.split(gates, num_or_sections=4, axis=-1) + + i = self.gate_act_fn(gates_slices[0]) + f = self.gate_act_fn(gates_slices[1]) + o = self.gate_act_fn(gates_slices[3]) + c = f * c + i * self.act_fn(gates_slices[2]) + h = o * self.act_fn(c) + + return h, h, c + + +class grucell(object): + + def __init__(self, 
weight_ih, weight_hh, bias_ih, bias_hh, bias): + self.weight_ih = weight_ih + self.weight_hh = weight_hh + self.bias_ih = bias_ih + self.bias_hh = bias_hh + self.bias = bias + self.gate_act_fn = F.sigmoid + self.act_fn = F.tanh + + def __call__(self, input, h): + + x_gates = pd.matmul(input, self.weight_ih, transpose_y=True) + if self.bias_ih is not None: + x_gates = x_gates + self.bias_ih + h_gates = pd.matmul(h, self.weight_hh, transpose_y=True) + if self.bias_hh is not None: + h_gates = h_gates + self.bias_hh + + x_r, x_z, x_c = pd.split(x_gates, num_or_sections=3, axis=-1) + h_r, h_z, h_c = pd.split(h_gates, num_or_sections=3, axis=-1) + + r = self.gate_act_fn(x_r + h_r) + z = self.gate_act_fn(x_z + h_z) + c = self.act_fn(x_c + r * h_c) # apply reset gate after mm + h = (h - c) * z + c + + return h, h + + +def split_states(states, bidirectional=False, state_components=1): + r""" + Split states of RNN network into possibly nested list or tuple of + states of each RNN cells of the RNN network. + + Parameters: + states (Tensor|tuple|list): the concatenated states for RNN network. + When `state_components` is 1, states in a Tensor with shape + `(L*D, N, C)` where `L` is the number of layers of the RNN + network, `D` is the number of directions of the RNN network(1 + for unidirectional RNNs and 2 for bidirectional RNNs), `N` is + the batch size of the input to the RNN network, `C` is the + hidden size of the RNN network. + + When `state_components` is larger than 1, `states` is a tuple of + `state_components` Tensors that meet the requirements described + above. + + For SimpleRNNs and GRUs, `state_components` is 1, and for LSTMs, + `state_components` is 2. + bidirectional (bool): whether the state is of a bidirectional RNN + network. Defaults to False. + state_components (int): the number of the components of the states. see + `states` above. Defaults to 1. + + Returns: + A nested list or tuple of RNN cell states. + If `bidirectional` is True, it can be indexed twice to get an RNN + cell state. The first index indicates the layer, the second index + indicates the direction. + If `bidirectional` is False, it can be indexed once to get an RNN + cell state. The index indicates the layer. + Note that if `state_components` is larger than 1, an RNN cell state + can be indexed one more time to get a tensor of shape(N, C), where + `N` is the batch size of the input to the RNN cell, and `C` is the + hidden size of the RNN cell. + """ + if state_components == 1: + states = pd.unstack(states) + if not bidirectional: + return states + else: + return list(zip(states[::2], states[1::2])) + else: + assert len(states) == state_components + states = tuple([pd.unstack(item) for item in states]) + if not bidirectional: + return list(zip(*states)) + else: + states = list(zip(*states)) + return list(zip(states[::2], states[1::2])) + + +def concat_states(states, bidirectional=False, state_components=1): + r""" + Concatenate a possibly nested list or tuple of RNN cell states into a + compact form. + + Parameters: + states (list|tuple): a possibly nested list or tuple of RNN cell + states. + If `bidirectional` is True, it can be indexed twice to get an + RNN cell state. The first index indicates the layer, the second + index indicates the direction. + If `bidirectional` is False, it can be indexed once to get an RNN + cell state. The index indicates the layer. 
+ Note that if `state_components` is larger than 1, an RNN cell + state can be indexed one more time to get a tensor of shape(N, C), + where `N` is the batch size of the input to the RNN cell, and + `C` is the hidden size of the RNN cell. + bidirectional (bool): whether the state is of a bidirectional RNN + network. Defaults to False. + state_components (int): the number of the components of the states. see + `states` above. Defaults to 1. + + Returns: + Concatenated states for RNN network. + When `state_components` is 1, states in a Tensor with shape + `(L\*D, N, C)` where `L` is the number of layers of the RNN + network, `D` is the number of directions of the RNN network(1 for + unidirectional RNNs and 2 for bidirectional RNNs), `N` is the batch + size of the input to the RNN network, `C` is the hidden size of the + RNN network. + + """ + if state_components == 1: + return pd.stack(flatten(states)) + else: + states = flatten(states) + componnets = [] + for i in range(state_components): + componnets.append(states[i::state_components]) + return tuple([pd.stack(item) for item in componnets]) + + +class rnnbase(Layer): + + def __init__( + self, + mode, + input_size, + hidden_size, + num_layers, + bias, + batch_first, + dropout, + bidirectional, + is_train, + ): + super(rnnbase, self).__init__() + self.mode = mode + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.time_major = False if batch_first else True + self.dropout = dropout + self.bidirect = 2 if bidirectional else 1 + self.state_components = 2 if mode == 'LSTM' else 1 + self.rnn = pd.nn.LayerList() + RNN = pd.nn.RNN + BiRNN = pd.nn.BiRNN + weight_ih_attr = None + weight_hh_attr = None + if bias: + bias_ih_attr = None + bias_hh_attr = None + else: + bias_ih_attr = False + bias_hh_attr = False + + kwargs = { + "weight_ih_attr": weight_ih_attr, + "weight_hh_attr": weight_hh_attr, + "bias_ih_attr": bias_ih_attr, + "bias_hh_attr": bias_hh_attr + } + + if mode == "LSTM": + rnn_cls = pd.nn.LSTMCell + elif mode == "GRU": + rnn_cls = pd.nn.GRUCell + elif mode == 'RNN_TANH': + rnn_cls = pd.nn.SimpleRNNCell + kwargs["activation"] = 'tanh' + elif mode == 'RNN_RELU': + rnn_cls = pd.nn.SimpleRNNCell + kwargs["activation"] = 'relu' + + if not bidirectional: + is_reverse = False + cell = rnn_cls(input_size, hidden_size, **kwargs) + self.rnn.append(RNN(cell, is_reverse, self.time_major)) + for i in range(1, num_layers): + cell = rnn_cls(hidden_size, hidden_size, **kwargs) + self.rnn.append(RNN(cell, is_reverse, self.time_major)) + else: + cell_fw = rnn_cls(input_size, hidden_size, **kwargs) + cell_bw = rnn_cls(input_size, hidden_size, **kwargs) + self.rnn.append(BiRNN(cell_fw, cell_bw, self.time_major)) + for i in range(1, num_layers): + cell_fw = rnn_cls(2 * hidden_size, hidden_size, **kwargs) + cell_bw = rnn_cls(2 * hidden_size, hidden_size, **kwargs) + self.rnn.append(BiRNN(cell_fw, cell_bw, self.time_major)) + + self.could_use_cudnn = True + self.could_use_cudnn &= len(self.rnn.parameters()) == num_layers * 4 * self.bidirect + + param_names = [] + for layer in range(self.num_layers): + for direction in range(self.bidirect): + suffix = '_reverse' if direction == 1 else '' + param_names.extend(['weight_ih_l{}{}', 'weight_hh_l{}{}']) + if bias_ih_attr != False: param_names.append('bias_ih_l{}{}') + if bias_hh_attr != False: param_names.append('bias_hh_l{}{}') + param_names = [x.format(layer, suffix) for x in param_names] + for name, param in zip(param_names, self.rnn.parameters()): + 
setattr(self.rnn, name, param) + + self.flatten_parameters() + + def flatten_parameters(self): + """ + Resets parameter data pointer to address in continuous memory block for + cudnn usage. + """ + if self.could_use_cudnn: + # layer.parameters() is depth first and ordered + # for i in layer: for j in direct: w_ih, w_hh, b_ih, b_hh + # need to reorganize to cudnn param layout: + # all bias following all weights + params = self.rnn.parameters(include_sublayers=False) + shape = [np.prod(param.shape) for param in params] + self._all_weights = [None] * len(params) + for i, param in enumerate(params): + offset = 0 if i % 4 < 2 else (2 * self.num_layers * self.bidirect) + layer_idx = i // 4 + self._all_weights[offset + layer_idx * 2 + i % 2] = param + + # Wrap using a list to avoid registed into params and saving, maybe + # need a better way to handle this later. Use `create_parameter` to + # add both to main_program and startup_program for static-graph. + # Use Constant initializer to avoid make effect on random generator. + self._flat_weight = [ + self.rnn.create_parameter( + shape=[np.sum(shape)], dtype=params[0].dtype, default_initializer=I.Constant(0.0) + ) + ] + + # dropout state may also can be hided and avoid saving + # should dropout state be persistable for static-graph + self._dropout_state = self.rnn.create_variable(dtype=fluid.core.VarDesc.VarType.UINT8) + # for static-graph, append coalesce_tensor into startup program + with fluid.program_guard(fluid.default_startup_program(), fluid.default_startup_program()): + with pd.framework.no_grad(): + self.rnn._helper.append_op( + type="coalesce_tensor", inputs={"Input": self._all_weights}, outputs={ + "Output": self._all_weights, + "FusedOutput": self._flat_weight + }, attrs={ + "copy_data": True, + "use_align": False, + "dtype": params[0].dtype + } + ) + + def _cudnn_impl(self, inputs, initial_states, sequence_length): + if not self.time_major: + inputs = pd.tensor.transpose(inputs, [1, 0, 2]) + out = self.rnn._helper.create_variable_for_type_inference(inputs.dtype) + state = [ + self.rnn._helper.create_variable_for_type_inference(inputs.dtype) for i in range(self.state_components) + ] + reserve = self.rnn._helper.create_variable_for_type_inference( + dtype=fluid.core.VarDesc.VarType.UINT8, stop_gradient=True + ) + + inputs = { + 'Input': inputs, + 'WeightList': self._all_weights, + 'PreState': initial_states, + 'SequenceLength': sequence_length + } + attrs = { + 'dropout_prob': self.dropout, + 'is_bidirec': self.bidirect == 2, + 'input_size': self.input_size, + 'hidden_size': self.hidden_size, + 'num_layers': self.num_layers, + 'mode': self.mode, + 'is_test': not self.rnn.training + } + + outputs = { + 'Out': out, + 'State': state, + 'Reserve': reserve, + 'DropoutState': self._dropout_state, + } + + self.rnn._helper.append_op(type="rnn", inputs=inputs, outputs=outputs, attrs=attrs) + out = pd.tensor.transpose(out, [1, 0, 2]) if not self.time_major else out + return out, tuple(state) if len(state) > 1 else state[0] + + def forward(self, inputs, initial_states=None, sequence_length=None): + batch_index = 1 if self.time_major else 0 + dtype = inputs.dtype + if initial_states is None: + state_shape = [self.num_layers * self.bidirect, -1, self.hidden_size] + if self.state_components == 1: + initial_states = fluid.layers.fill_constant_batch_size_like( + inputs, state_shape, dtype, 0, batch_index, 1 + ) + else: + initial_states = tuple( + [ + fluid.layers.fill_constant_batch_size_like(inputs, state_shape, dtype, 0, batch_index, 1) + for _ in 
range(self.state_components) + ] + ) + + if self.could_use_cudnn: + # Add CPU kernel and dispatch in backend later + return self._cudnn_impl(inputs, initial_states, sequence_length) + + states = split_states(initial_states, self.bidirect == 2, self.state_components) + + final_states = [] + + for i, rnn_layer in enumerate(self.rnn): + if i > 0: + inputs = F.dropout(inputs, self.dropout, training=self.rnn.training, mode="upscale_in_train") + outputs, final_state = rnn_layer(inputs, states[i], sequence_length) + final_states.append(final_state) + inputs = outputs + + final_states = concat_states(final_states, self.bidirect == 2, self.state_components) + return outputs, final_states diff --git a/tensorlayer/backend/ops/tensorflow_backend.py b/tensorlayer/backend/ops/tensorflow_backend.py index 99a0b31..6b0009d 100644 --- a/tensorlayer/backend/ops/tensorflow_backend.py +++ b/tensorlayer/backend/ops/tensorflow_backend.py @@ -531,14 +531,15 @@ def reduce_min(input_tensor, axis=None): class Pad(object): - def __init__(self, paddings, mode="REFLECT"): + def __init__(self, paddings, mode="REFLECT", constant_values=0): if mode not in ['CONSTANT', 'REFLECT', 'SYMMETRIC']: raise Exception("Unsupported mode: {}".format(mode)) self.paddings = paddings self.mode = mode + self.constant_values = constant_values def __call__(self, x): - outputs = tf.pad(x, self.paddings, mode=self.mode, constant_values=0) + outputs = tf.pad(x, self.paddings, mode=self.mode, constant_values=self.constant_values) return outputs @@ -884,7 +885,7 @@ class OneHot(object): self.axis = axis self.dtype = dtype - def __call__(self, inputs, *args, **kwargs): + def __call__(self, inputs): outputs = tf.one_hot( inputs, self.depth, on_value=self.on_value, off_value=self.off_value, axis=self.axis, dtype=self.dtype ) @@ -907,7 +908,7 @@ class EmbeddingLookup(object): def __init__(self, max_norm=None): self.max_norm = max_norm - def __call__(self, params, ids, *args, **kwargs): + def __call__(self, params, ids): outputs = tf.nn.embedding_lookup(params=params, ids=ids, max_norm=self.max_norm) return outputs diff --git a/tensorlayer/backend/ops/tensorflow_nn.py b/tensorlayer/backend/ops/tensorflow_nn.py index d8b2d73..6359a0f 100644 --- a/tensorlayer/backend/ops/tensorflow_nn.py +++ b/tensorlayer/backend/ops/tensorflow_nn.py @@ -6,6 +6,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops from tensorflow.python.training import moving_averages from math import floor, ceil +import numpy as np # loss function sparse_softmax_cross_entropy_with_logits = tf.nn.sparse_softmax_cross_entropy_with_logits sigmoid_cross_entropy_with_logits = tf.nn.sigmoid_cross_entropy_with_logits @@ -1913,3 +1914,342 @@ class DorefaConv2D(object): ) return outputs + + +class rnncell(object): + + def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, act): + self.weight_ih = weight_ih + self.weight_hh = weight_hh + self.bias_ih = bias_ih + self.bias_hh = bias_hh + self.act_fn = tf.nn.relu if act == 'relu' else tf.nn.tanh + + def __call__(self, input, h, c=None): + + i2h = tf.matmul(input, self.weight_ih, transpose_b=True) + if self.bias_ih is not None: + i2h += self.bias_ih + h2h = tf.matmul(h, self.weight_hh, transpose_b=True) + if self.bias_hh is not None: + h2h += self.bias_hh + h = self.act_fn(i2h + h2h) + return h, h + + +class lstmcell(object): + + def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, act=None): + self.weight_ih = weight_ih + self.weight_hh = weight_hh + self.bias_ih = bias_ih + self.bias_hh = bias_hh 
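+        # standard LSTM gating (assumed intent of this cell): sigmoid for the input/forget/output gates, tanh for the cell candidate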
+ self.gate_act_fn = tf.sigmoid + self.act_fn = tf.tanh + + def __call__(self, input, h, c): + + gates = tf.matmul(input, self.weight_ih, transpose_b=True) + if self.bias_ih is not None: + gates = gates + self.bias_ih + gates += tf.matmul(h, self.weight_hh, transpose_b=True) + if self.bias_hh is not None: + gates += self.bias_hh + + gate_slices = tf.split(gates, num_or_size_splits=4, axis=-1) + i = self.gate_act_fn(gate_slices[0]) + f = self.gate_act_fn(gate_slices[1]) + o = self.gate_act_fn(gate_slices[3]) + c = f * c + i * self.act_fn(gate_slices[2]) + h = o * self.act_fn(c) + + return h, h, c + + +class grucell(object): + + def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, act=None): + self.weight_ih = weight_ih + self.weight_hh = weight_hh + self.bias_ih = bias_ih + self.bias_hh = bias_hh + self.gate_act_fn = tf.sigmoid + self.act_fn = tf.tanh + + def __call__(self, input, h, c=None): + + x_gates = tf.matmul(input, self.weight_ih, transpose_b=True) + if self.bias_ih is not None: + x_gates = x_gates + self.bias_ih + h_gates = tf.matmul(h, self.weight_hh, transpose_b=True) + if self.bias_hh is not None: + h_gates = h_gates + self.bias_hh + + x_r, x_z, x_c = tf.split(x_gates, num_or_size_splits=3, axis=-1) + h_r, h_z, h_c = tf.split(h_gates, num_or_size_splits=3, axis=-1) + + r = self.gate_act_fn(x_r + h_r) + z = self.gate_act_fn(x_r + h_z) + c = self.act_fn(x_c + r * h_c) + h = (h - c) * z + c + + return h, h + + +class rnnbase(object): + + def __init__( + self, + mode, + input_size, + hidden_size, + num_layers, + bias, + batch_first, + dropout, + bidirectional, + is_train, + weights_fw, + weights_bw, + bias_fw, + bias_bw, + ): + self.mode = mode + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.bias = bias + self.batch_first = batch_first + self.dropout = float(dropout) + self.train = is_train + if not 0 <= dropout < 1: + raise ValueError("dropout should be a number in range [0, 1).") + if dropout > 0 and num_layers == 1: + raise ValueError( + "dropout option adds dropout after all but last " + "recurrent layer, so non-zero dropout expects " + "num_layers greater than 1, but got dropout={} and " + "num_layers={}".format(dropout, num_layers) + ) + self.bidirect = 2 if bidirectional else 1 + + self.weights_fw = weights_fw + self.bias_fw = bias_fw + self.weights_bw = weights_bw + self.bias_bw = bias_bw + + # stdv = 1.0 / np.sqrt(self.hidden_size) + # _init = tf.random_uniform_initializer(minval=-stdv, maxval=stdv) + + self.act_fn = None + if mode == 'LSTM': + # gate_size = 4 * hidden_size + self.rnn_cell = lstmcell + elif mode == 'GRU': + # gate_size = 3 * hidden_size + self.rnn_cell = grucell + elif mode == 'RNN_TANH': + # gate_size = hidden_size + self.rnn_cell = rnncell + self.act_fn = 'tanh' + elif mode == 'RNN_RELU': + # gate_size = hidden_size + self.rnn_cell = rnncell + self.act_fn = 'relu' + + # for layer in range(num_layers): + # for direction in range(self.bidirect): + # layer_input_size = input_size if layer==0 else hidden_size*self.bidirect + # if direction == 0: + # self.w_ih = tf.Variable(initial_value= _init(shape=(gate_size, layer_input_size)),name = 'weight_ih_l'+str(layer), trainable=True) + # self.w_hh = tf.Variable(initial_value=_init(shape=(gate_size, hidden_size)), + # name='weight_hh_l'+str(layer), trainable=True) + # # self.w_ih = self.weights_init('weight_ih_l'+str(layer), shape = (gate_size, layer_input_size), init = _init) + # # self.w_hh = self.weights_init('weight_ih_l' + str(layer), shape=(gate_size, 
hidden_size), + # # init=_init) + # self.weights_fw.append(self.w_ih) + # self.weights_fw.append(self.w_hh) + # if bias: + # self.b_ih = tf.Variable(initial_value=_init(shape=(gate_size,)), + # name='bias_ih_l'+str(layer), trainable=True) + # self.b_hh = tf.Variable(initial_value=_init(shape=(gate_size,)), + # name='bias_hh_l'+str(layer), trainable=True) + # # self.b_ih = self.weights_init('bias_ih_l'+str(layer), shape=(gate_size,), init=_init) + # # self.b_hh = self.weights_init('bias_hh_l'+str(layer), shape=(gate_size,), init=_init) + # self.bias_fw.append(self.b_ih) + # self.bias_fw.append(self.b_hh) + # else: + # self.w_ih = tf.Variable(initial_value= _init(shape=(gate_size, layer_input_size)),name = 'weight_ih_l'+str(layer)+'_reverse', trainable=True) + # self.w_hh = tf.Variable(initial_value=_init(shape=(gate_size, hidden_size)), + # name='weight_hh_l'+str(layer)+'_reverse', trainable=True) + # # self.w_ih = self.weights_init('weight_ih_l'+str(layer)+'_reverse', shape = (gate_size, layer_input_size), init = _init) + # # self.w_hh = self.weights_init('weight_hh_l'+str(layer)+'_reverse', shape=(gate_size, hidden_size), + # # init=_init) + # self.weights_bw.append(self.w_ih) + # self.weights_bw.append(self.w_hh) + # if bias: + # self.b_ih = tf.Variable(initial_value=_init(shape=(gate_size,)), + # name='bias_ih_l'+str(layer)+'_reverse', trainable=True) + # self.b_hh = tf.Variable(initial_value=_init(shape=(gate_size,)), + # name='bias_hh_l'+str(layer)+'_reverse', trainable=True) + # # self.b_ih = self.weights_init('bias_ih_l'+str(layer)+'_reverse', shape=(gate_size,), init=_init) + # # self.b_hh = self.weights_init('bias_hh_l'+str(layer)+'_reverse', shape=(gate_size,), init=_init) + # self.bias_bw.append(self.b_ih) + # self.bias_bw.append(self.b_hh) + + def _bi_rnn_forward(self, x, h, c=None): + time_step, batch_size, input_size = x.shape + h_out = [] + c_out = [] + y = [] + pre_layer = x + for i in range(self.num_layers): + weight_ih_fw = self.weights_fw[2 * i] + weight_hh_fw = self.weights_fw[2 * i + 1] + weight_ih_bw = self.weights_bw[2 * i] + weight_hh_bw = self.weights_bw[2 * i + 1] + if self.bias: + bias_ih_fw = self.bias_fw[2 * i] + bias_hh_fw = self.bias_fw[2 * i + 1] + bias_ih_bw = self.bias_bw[2 * i] + bias_hh_bw = self.bias_bw[2 * i + 1] + else: + bias_ih_fw = None + bias_hh_fw = None + bias_ih_bw = None + bias_hh_bw = None + h_i_fw = h[i, :, :] + h_i_bw = h[i + 1, :, :] + if i != 0 and self.train: + pre_layer = tf.nn.dropout(pre_layer, rate=self.dropout) + if c is not None: + c_i_fw = c[i, :, :] + c_i_bw = c[i + 1, :, :] + for j in range(time_step): + input = pre_layer[j, :, :] + cell_fw = self.rnn_cell(weight_ih_fw, weight_hh_fw, bias_ih_fw, bias_hh_fw, self.act_fn) + cell_bw = self.rnn_cell(weight_ih_bw, weight_hh_bw, bias_ih_bw, bias_hh_bw, self.act_fn) + bw_input = tf.reverse(input, axis=[0]) + step_out_fw, h_i_fw, c_i_fw = cell_fw(input, h_i_fw, c_i_fw) + step_out_bw, h_i_bw, c_i_bw = cell_bw(bw_input, h_i_bw, c_i_bw) + step_out_bw = tf.reverse(step_out_bw, axis=[0]) + step_out = tf.concat([step_out_fw, step_out_bw], axis=-1) + y.append(step_out) + h_out.append(h_i_fw) + h_out.append(h_i_bw) + c_out.append(c_i_fw) + c_out.append(c_i_bw) + pre_layer = tf.stack(y) + y = [] + else: + for j in range(time_step): + input = pre_layer[j, :, :] + cell_fw = self.rnn_cell(weight_ih_fw, weight_hh_fw, bias_ih_fw, bias_hh_fw, self.act_fn) + cell_bw = self.rnn_cell(weight_ih_bw, weight_hh_bw, bias_ih_bw, bias_hh_bw, self.act_fn) + bw_input = tf.reverse(input, axis=[0]) + step_out_fw, 
h_i_fw = cell_fw(input, h_i_fw) + step_out_bw, h_i_bw = cell_bw(bw_input, h_i_bw) + step_out_bw = tf.reverse(step_out_bw, axis=[0]) + step_out = tf.concat([step_out_fw, step_out_bw], axis=-1) + y.append(step_out) + h_out.append(h_i_fw) + h_out.append(h_i_bw) + pre_layer = tf.stack(y) + y = [] + h_out = tf.stack(h_out) + c_out = tf.stack(c_out) if c is not None else None + + return pre_layer, h_out, c_out + + def _rnn_forward(self, x, h, c=None): + pre_layer = x + h_out = [] + c_out = [] + y = [] + time_step, batch_size, input_size = x.shape + for i in range(self.num_layers): + weight_ih = self.weights_fw[2 * i] + weight_hh = self.weights_fw[2 * i + 1] + if self.bias: + bias_ih = self.bias_fw[2 * i] + bias_hh = self.bias_fw[2 * i + 1] + else: + bias_ih = None + bias_hh = None + h_i = h[i, :, :] + if i != 0 and self.train: + pre_layer = tf.nn.dropout(pre_layer, rate=self.dropout) + if c is not None: + c_i = c[i, :, :] + for j in range(time_step): + input = pre_layer[j, :, :] + cell = self.rnn_cell(weight_ih, weight_hh, bias_ih, bias_hh, self.act_fn) + step_out, h_i, c_i = cell(input, h_i, c_i) + y.append(step_out) + h_out.append(h_i) + c_out.append(c_i) + pre_layer = tf.stack(y) + y = [] + else: + for j in range(time_step): + input = pre_layer[j, :, :] + cell = self.rnn_cell(weight_hh, weight_ih, bias_ih, bias_hh, self.act_fn) + step_out, h_i = cell(input, h_i) + y.append(step_out) + h_out.append(h_i) + pre_layer = tf.stack(y) + y = [] + h_out = tf.stack(h_out) + c_out = tf.stack(c_out) if c is not None else None + + return pre_layer, h_out, c_out + + def check_input(self, input_shape): + if len(input_shape) != 3: + raise ValueError("input must have 3 dimensions. But got {}.".format(len(input_shape))) + if self.input_size != input_shape[-1]: + raise ValueError( + "The last dimension of input should be equal to input_size {}.But got {}".format( + self.input_size, input_shape[-1] + ) + ) + + def check_hidden(self, h, batch_size): + expected_hidden_size = (self.num_layers * self.bidirect, batch_size, self.hidden_size) + if h.shape != expected_hidden_size: + raise ValueError('Expected hidden size {}, got {}.'.format(expected_hidden_size, h.shape)) + + def __call__(self, input, states): + if self.batch_first: + input = tf.transpose(input, perm=(1, 0, 2)) + input_dtype = input.dtype + input_shape = input.shape + time_step, batch_size, input_size = input_shape + self.check_input(input_shape) + if self.mode == "LSTM": + if states is not None: + h, c = states + self.check_hidden(h, batch_size) + self.check_hidden(c, batch_size) + else: + h = tf.zeros(shape=(self.num_layers * self.bidirect, batch_size, self.hidden_size), dtype=input_dtype) + c = tf.zeros(shape=(self.num_layers * self.bidirect, batch_size, self.hidden_size), dtype=input_dtype) + if self.bidirect == 1: + y, new_h, new_c = self._rnn_forward(input, h, c) + else: + y, new_h, new_c = self._bi_rnn_forward(input, h, c) + new_states = (new_h, new_c) + else: + if states is not None: + h = states + self.check_hidden(h, batch_size) + else: + h = tf.zeros(shape=(self.num_layers * self.bidirect, batch_size, self.hidden_size), dtype=input_dtype) + if self.bidirect == 1: + y, new_h, _ = self._rnn_forward(input, h) + else: + y, new_h, _ = self._bi_rnn_forward(input, h) + new_states = new_h + if self.batch_first: + y = tf.transpose(y, perm=(1, 0, 2)) + return y, new_states diff --git a/tensorlayer/files/utils.py b/tensorlayer/files/utils.py index 1350577..0ca30ed 100644 --- a/tensorlayer/files/utils.py +++ b/tensorlayer/files/utils.py @@ -1963,6 
+1963,8 @@ def save_npz(save_list=None, name='model.npz'): save_list_var = tf_variables_to_numpy(save_list) elif tl.BACKEND == 'mindspore': save_list_var = ms_variables_to_numpy(save_list) + elif tl.BACKEND == 'paddle': + save_list_var = pd_variables_to_numpy(save_list) else: raise NotImplementedError("This backend is not supported") # print(name, save_list_var) @@ -2050,6 +2052,11 @@ def assign_weights(weights, network): # net = Assign_net(network.all_weights[idx]) # net(assign_param) Assign()(network.all_weights[idx], assign_param) + elif tl.BACKEND == 'paddle': + for idx, param in enumerate(weights): + assign_pd_variable(network.all_weights[idx], param) + else: + raise NotImplementedError ("This backend is not supported") return ops diff --git a/tensorlayer/layers/padding.py b/tensorlayer/layers/padding.py index 84695b7..229baeb 100644 --- a/tensorlayer/layers/padding.py +++ b/tensorlayer/layers/padding.py @@ -41,11 +41,13 @@ class PadLayer(Module): self, padding=None, mode='CONSTANT', + constant_values=0, name=None, # 'pad_layer', ): super().__init__(name) self.padding = padding self.mode = mode + self.constant_values = constant_values logging.info("PadLayer %s: padding: %s mode: %s" % (self.name, self.padding, self.mode)) @@ -65,7 +67,7 @@ class PadLayer(Module): return s.format(classname=self.__class__.__name__, **self.__dict__) def build(self, inputs_shape=None): - self.pad = tl.ops.Pad(paddings=self.padding, mode=self.mode) + self.pad = tl.ops.Pad(paddings=self.padding, mode=self.mode, constant_values=self.constant_values) def forward(self, inputs): outputs = self.pad(inputs) diff --git a/tensorlayer/layers/recurrent.py b/tensorlayer/layers/recurrent.py index 5434cec..a1c3f7d 100644 --- a/tensorlayer/layers/recurrent.py +++ b/tensorlayer/layers/recurrent.py @@ -2,1265 +2,704 @@ # -*- coding: utf-8 -*- import numpy as np -import tensorflow as tf - import tensorlayer as tl from tensorlayer import logging -from tensorlayer.decorators import deprecated_alias +from tensorlayer.backend.ops.load_backend import BACKEND from tensorlayer.layers.core import Module -# TODO: Need to update to version 3.0 __all__ = [ 'RNN', - 'SimpleRNN', - 'GRURNN', - 'LSTMRNN', - 'BiRNN', - # 'ConvRNNCell', - # 'BasicConvLSTMCell', - # 'ConvLSTM', - 'retrieve_seq_length_op', - 'retrieve_seq_length_op2', - 'retrieve_seq_length_op3', - 'target_mask_op', + 'RNNCell', + 'GRU', + 'LSTM', + 'GRUCell', + 'LSTMCell', ] -class RNN(Module): - """ - The :class:`RNN` class is a fixed length recurrent layer for implementing simple RNN, - LSTM, GRU and etc. +class RNNCell(Module): + """An Elman RNN cell with tanh or ReLU non-linearity. Parameters ---------- - cell : TensorFlow cell function - A RNN cell implemented by tf.keras - - E.g. tf.keras.layers.SimpleRNNCell, tf.keras.layers.LSTMCell, tf.keras.layers.GRUCell - - Note TF2.0+, TF1.0+ and TF1.0- are different + input_size : int + The number of expected features in the input `x` + hidden_size : int + The number of features in the hidden state `h` + bias : bool + If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. Default: ``True`` + act : activation function + The non-linearity to use. Can be either 'tanh' or 'relu'. Default: 'tanh' + name : None or str + A unique layer name + -------------------------------------------------------- + inputs : tensor + A tensor with shape `[batch_size, input_size]`. + states : tensor or None + A tensor with shape `[batch_size, hidden_size]`. When states is None, zero state is used. Defaults to None. 
- return_last_output : boolean - Whether return last output or all outputs in a sequence. + Returns + ---------- + outputs : tensor + A tensor with shape `[batch_size, hidden_size]`. + states : tensor + A tensor with shape `[batch_size, hidden_size]`. + Tensor containing the next hidden state for each element in the batch - - If True, return the last output, "Sequence input and single output" - - If False, return all outputs, "Synced sequence input and output" - - In other word, if you want to stack more RNNs on this layer, set to False - - In a dynamic model, `return_last_output` can be updated when it is called in customised forward(). - By default, `False`. - return_seq_2d : boolean - Only consider this argument when `return_last_output` is `False` - - - If True, return 2D Tensor [batch_size * n_steps, n_hidden], for stacking Dense layer after it. - - If False, return 3D Tensor [batch_size, n_steps, n_hidden], for stacking multiple RNN after it. - - In a dynamic model, `return_seq_2d` can be updated when it is called in customised forward(). - By default, `False`. - return_last_state: boolean - Whether to return the last state of the RNN cell. The state is a list of Tensor. - For simple RNN and GRU, last_state = [last_output]; For LSTM, last_state = [last_output, last_cell_state] - - - If True, the layer will return outputs and the final state of the cell. - - If False, the layer will return outputs only. - - In a dynamic model, `return_last_state` can be updated when it is called in customised forward(). - By default, `False`. - in_channels: int - Optional, the number of channels of the previous layer which is normally the size of embedding. - If given, the layer will be built when init. - If None, it will be automatically detected when the layer is forwarded for the first time. - name : str - A unique layer name. Examples -------- - For synced sequence input and output, see `PTB example `__ + With TensorLayer - A simple regression model below. - - >>> inputs = tl.layers.Input([batch_size, num_steps, embedding_size]) - >>> rnn_out, lstm_state = tl.layers.RNN( - >>> cell=tf.keras.layers.LSTMCell(units=hidden_size, dropout=0.1), - >>> in_channels=embedding_size, - >>> return_last_output=True, return_last_state=True, name='lstmrnn' - >>> )(inputs) - >>> outputs = tl.layers.Dense(n_units=1)(rnn_out) - >>> rnn_model = tl.models.Model(inputs=inputs, outputs=[outputs, rnn_state[0], rnn_state[1]], name='rnn_model') - >>> # If LSTMCell is applied, the rnn_state is [h, c] where h the hidden state and c the cell state of LSTM. - - A stacked RNN model. - - >>> inputs = tl.layers.Input([batch_size, num_steps, embedding_size]) - >>> rnn_out1 = tl.layers.RNN( - >>> cell=tf.keras.layers.SimpleRNNCell(units=hidden_size, dropout=0.1), - >>> return_last_output=False, return_seq_2d=False, return_last_state=False - >>> )(inputs) - >>> rnn_out2 = tl.layers.RNN( - >>> cell=tf.keras.layers.SimpleRNNCell(units=hidden_size, dropout=0.1), - >>> return_last_output=True, return_last_state=False - >>> )(rnn_out1) - >>> outputs = tl.layers.Dense(n_units=1)(rnn_out2) - >>> rnn_model = tl.models.Model(inputs=inputs, outputs=outputs) - - An example if the sequences have different length and contain padding. - Similar to the DynamicRNN in TL 1.x. - - If the `sequence_length` is provided in RNN's forwarding and both `return_last_output` and `return_last_state` - are set as `True`, the forward function will automatically ignore the paddings. 
Note that if `return_last_output` - is set as `False`, the synced sequence outputs will still include outputs which correspond with paddings, - but users are free to select which slice of outputs to be used in following procedure. - - The `sequence_length` should be a list of integers which indicates the length of each sequence. - It is recommended to - `tl.layers.retrieve_seq_length_op3 `__ - to calculate the `sequence_length`. - - >>> data = [[[1], [2], [0], [0], [0]], [[1], [2], [3], [0], [0]], [[1], [2], [6], [1], [1]]] - >>> data = tf.convert_to_tensor(data, dtype=tf.float32) - >>> class DynamicRNNExample(tl.models.Model): - >>> def __init__(self): - >>> super(DynamicRNNExample, self).__init__() - >>> self.rnnlayer = tl.layers.RNN( - >>> cell=tf.keras.layers.SimpleRNNCell(units=6, dropout=0.1), in_channels=1, return_last_output=True, - >>> return_last_state=True - >>> ) - >>> def forward(self, x): - >>> z, s = self.rnnlayer(x, sequence_length=tl.layers.retrieve_seq_length_op3(x)) - >>> return z, s - >>> model = DynamicRNNExample() - >>> model.eval() - >>> output, state = model(data) - - - Notes - ----- - Input dimension should be rank 3 : [batch_size, n_steps, n_features], if no, please see layer :class:`Reshape`. + >>> input = tl.layers.Input([4, 16], name='input') + >>> prev_h = tl.layers.Input([4,32]) + >>> cell = tl.layers.RNNCell(input_size=16, hidden_size=32, bias=True, act='tanh', name='rnncell_1') + >>> y, h = cell(input, prev_h) + >>> print(y.shape) """ def __init__( self, - cell, - return_last_output=False, - return_seq_2d=False, - return_last_state=True, - in_channels=None, - name=None, # 'rnn' + input_size, + hidden_size, + bias=True, + act='tanh', + name=None, ): - - super(RNN, self).__init__(name=name) - - self.cell = cell - self.return_last_output = return_last_output - self.return_seq_2d = return_seq_2d - self.return_last_state = return_last_state - - if in_channels is not None: - self.build((None, None, in_channels)) - self._built = True - - logging.info("RNN %s: cell: %s, n_units: %s" % (self.name, self.cell.__class__.__name__, self.cell.units)) + super(RNNCell, self).__init__(name) + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + if act not in ('relu', 'tanh'): + raise ValueError("Activation should be 'tanh' or 'relu'.") + self.act = act + self.build(None) + logging.info("RNNCell %s: input_size: %d hidden_size: %d act: %s" % (self.name, input_size, hidden_size, act)) def __repr__(self): - s = ('{classname}(cell={cellname}, n_units={n_units}') - s += ', name=\'{name}\'' + actstr = self.act + s = ('{classname}(input_size={input_size}, hidden_size={hidden_size}') + s += ', bias=True' if self.bias else ', bias=False' + s += (',' + actstr) + if self.name is not None: + s += ', name=\'{name}\'' s += ')' - return s.format( - classname=self.__class__.__name__, cellname=self.cell.__class__.__name__, n_units=self.cell.units, - **self.__dict__ - ) + return s.format(classname=self.__class__.__name__, **self.__dict__) + + def check_input(self, input_shape): + if input_shape[1] != self.input_size: + raise ValueError( + 'input should have consistent input_size. 
But got {}, expected {}'.format( + input_shape[1], self.input_size + ) + ) + + def check_hidden(self, input_shape, h_shape, hidden_label): + if input_shape[0] != h_shape[0]: + raise ValueError( + 'input batch size{} should match hidden{} batch size{}.'.format( + input_shape[0], hidden_label, h_shape[0] + ) + ) + if h_shape[1] != self.hidden_size: + raise ValueError( + 'hidden{} should have consistent hidden_size. But got {}, expected {}.'.format( + hidden_label, h_shape[1], self.hidden_size + ) + ) def build(self, inputs_shape): - """ - Parameters - ---------- - inputs_shape : tuple - the shape of inputs tensor - """ - # Input dimension should be rank 3 [batch_size, n_steps(max), n_features] - if len(inputs_shape) != 3: - raise Exception("RNN : Input dimension should be rank 3 : [batch_size, n_steps, n_features]") + stdv = 1.0 / np.sqrt(self.hidden_size) + _init = tl.initializers.RandomUniform(minval=-stdv, maxval=stdv) + self.weight_ih_shape = (self.hidden_size, self.input_size) + self.weight_hh_shape = (self.hidden_size, self.hidden_size) + self.weight_ih = self._get_weights("weight_ih", shape=self.weight_ih_shape, init=_init) + self.weight_hh = self._get_weights("weight_hh", shape=self.weight_hh_shape, init=_init) - with tf.name_scope(self.name) as scope: - self.cell.build(tuple(inputs_shape)) + if self.bias: + self.bias_ih_shape = (self.hidden_size, ) + self.bias_hh_shape = (self.hidden_size, ) + self.bias_ih = self._get_weights('bias_ih', shape=self.bias_ih_shape, init=_init) + self.bias_hh = self._get_weights('bias_hh', shape=self.bias_hh_shape, init=_init) + else: + self.bias_ih = None + self.bias_hh = None + self.rnncell = tl.ops.rnncell( + weight_ih=self.weight_ih, weight_hh=self.weight_hh, bias_ih=self.bias_ih, bias_hh=self.bias_hh, act=self.act + ) - if self._trainable_weights is None: - self._trainable_weights = list() - for var in self.cell.trainable_variables: - self._trainable_weights.append(var) + def forward(self, inputs, states=None): + input_shape = tl.get_tensor_shape(inputs) + self.check_input(input_shape) + if states is None: + states = tl.zeros(shape=(input_shape[0], self.hidden_size), dtype=inputs.dtype) + states_shape = tl.get_tensor_shape(states) + self.check_hidden(input_shape, states_shape, hidden_label='h') + output, states = self.rnncell(inputs, states) + return output, states - # @tf.function - def forward(self, inputs, sequence_length=None, initial_state=None, **kwargs): - """ - Parameters - ---------- - inputs : input tensor - The input of a network - sequence_length: None or list of integers - The actual length of each sequence in batch without padding. - If provided, when `return_last_output` and `return_last_state` are `True`, - the RNN will perform in the manner of a dynamic RNN, i.e. - the RNN will return the actual last output / state without padding. - initial_state : None or list of Tensor (RNN State) - If None, `initial_state` is zero state. - **kwargs: dict - Some attributes can be updated during forwarding - such as `return_last_output`, `return_seq_2d`, `return_last_state`. - """ - if kwargs: - for attr in kwargs: - if attr in self.__dict__: - setattr(self, attr, kwargs[attr]) +class LSTMCell(Module): + """A long short-term memory (LSTM) cell. 
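+    The cell applies sigmoid input, forget and output gates and a tanh cell candidate to the current input and the previous states.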
- batch_size = inputs.get_shape().as_list()[0] - total_steps = inputs.get_shape().as_list()[1] + Parameters + ---------- + input_size : int + The number of expected features in the input `x` + hidden_size : int + The number of features in the hidden state `h` + bias : bool + If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. Default: ``True`` + name : None or str + A unique layer name + -------------------------------------------------------- + inputs : tensor + A tensor with shape `[batch_size, input_size]`. + states : tuple or None + A tuple of two tensor `(h, c)`, each of shape `[batch_size, hidden_size]`. When states is None, zero state is used. Defaults: None. - # checking the type and values of sequence_length - if sequence_length is not None: - if isinstance(sequence_length, list): - pass - elif isinstance(sequence_length, tf.Tensor): - pass - elif isinstance(sequence_length, np.ndarray): - sequence_length = sequence_length.tolist() - else: - raise TypeError( - "The argument sequence_length should be either None or a list of integers. " - "Type got %s" % type(sequence_length) + Returns + ---------- + outputs : tensor + A tensor with shape `[batch_size, hidden_size]`. + states : tensor + A tuple of two tensor `(h, c)`, each of shape `[batch_size, hidden_size]`. + Tensors containing the next hidden state and next cell state for each element in the batch. + + + Examples + -------- + With TensorLayer + + >>> input = tl.layers.Input([4, 16], name='input') + >>> prev_h = tl.layers.Input([4,32]) + >>> prev_c = tl.layers.Input([4,32]) + >>> cell = tl.layers.LSTMCell(input_size=16, hidden_size=32, bias=True, name='lstmcell_1') + >>> y, (h, c)= cell(input, (prev_h, prev_c)) + >>> print(y.shape) + + """ + + def __init__( + self, + input_size, + hidden_size, + bias=True, + name=None, + ): + super(LSTMCell, self).__init__(name) + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + self.build(None) + logging.info("LSTMCell %s: input_size: %d hidden_size: %d " % (self.name, input_size, hidden_size)) + + def __repr__(self): + s = ('{classname}(input_size={input_size}, hidden_size={hidden_size}') + s += ', bias=True' if self.bias else ', bias=False' + if self.name is not None: + s += ', name=\'{name}\'' + s += ')' + return s.format(classname=self.__class__.__name__, **self.__dict__) + + def check_input(self, input_shape): + if input_shape[1] != self.input_size: + raise ValueError( + 'input should have consistent input_size. But got {}, expected {}'.format( + input_shape[1], self.input_size ) - if (len(sequence_length) != batch_size): - raise ValueError( - "The argument sequence_length should contain %d " % batch_size + - "elements indicating the initial length of each sequence, but got only %d. " % len(sequence_length) + ) + + def check_hidden(self, input_shape, h_shape, hidden_label): + if input_shape[0] != h_shape[0]: + raise ValueError( + 'input batch size{} should match hidden{} batch size{}.'.format( + input_shape[0], hidden_label, h_shape[0] ) - for i in sequence_length: - if not (type(i) is int or (isinstance(i, tf.Tensor) and i.dtype.is_integer)): - raise TypeError( - "The argument sequence_length should be either None or a list of integers. " - "One element of sequence_length has the type %s" % type(i) - ) - if i > total_steps: - raise ValueError( - "The actual length of a sequence should not be longer than " - "that of the longest sequence (total steps) in this mini-batch. 
" - "Total steps of this mini-batch %d, " % total_steps + - "but got an actual length of a sequence %d" % i - ) + ) + if h_shape[1] != self.hidden_size: + raise ValueError( + 'hidden{} should have consistent hidden_size. But got {}, expected {}.'.format( + hidden_label, h_shape[1], self.hidden_size + ) + ) - sequence_length = tl.layers.retrieve_seq_length_op3(inputs) + def build(self, inputs_shape): + stdv = 1.0 / np.sqrt(self.hidden_size) + _init = tl.initializers.RandomUniform(minval=-stdv, maxval=stdv) + self.weight_ih_shape = (4 * self.hidden_size, self.input_size) + self.weight_hh_shape = (4 * self.hidden_size, self.hidden_size) + self.weight_ih = self._get_weights("weight_ih", shape=self.weight_ih_shape, init=_init) + self.weight_hh = self._get_weights("weight_hh", shape=self.weight_hh_shape, init=_init) - sequence_length = [i - 1 if i >= 1 else 0 for i in sequence_length] - - # set warning - # if (not self.return_last_output) and sequence_length is not None: - # warnings.warn( - # 'return_last_output is set as %s ' % self.return_last_output + - # 'When sequence_length is provided, it is recommended to set as True. ' + - # 'Otherwise, padding will be considered while RNN is forwarding.' - # ) - - # return the last output, iterating each seq including padding ones. No need to store output during each - # time step. - if self.return_last_output and sequence_length is None: - outputs = [-1] + if self.bias: + self.bias_ih_shape = (4 * self.hidden_size, ) + self.bias_hh_shape = (4 * self.hidden_size, ) + self.bias_ih = self._get_weights('bias_ih', shape=self.bias_ih_shape, init=_init) + self.bias_hh = self._get_weights('bias_hh', shape=self.bias_hh_shape, init=_init) else: - outputs = list() + self.bias_ih = None + self.bias_hh = None - # initialize the states if provided - states = initial_state if initial_state is not None else self.cell.get_initial_state(inputs) - if not isinstance(states, list): - states = [states] - - stored_states = list() - - # initialize the cell - self.cell.reset_dropout_mask() - self.cell.reset_recurrent_dropout_mask() - - # recurrent computation - # FIXME: if sequence_length is provided (dynamic rnn), only iterate max(sequence_length) times. 
- for time_step in range(total_steps): - - cell_output, states = self.cell.call(inputs[:, time_step, :], states, training=self.is_train) - stored_states.append(states) - - if self.return_last_output and sequence_length is None: - outputs[-1] = cell_output - else: - outputs.append(cell_output) - - # prepare to return results - if self.return_last_output and sequence_length is None: - outputs = outputs[-1] - - elif self.return_last_output and sequence_length is not None: - outputs = tf.convert_to_tensor(outputs) - outputs = tf.gather(outputs, sequence_length, axis=0) - - outputs_without_padding = [] - for i in range(batch_size): - outputs_without_padding.append(outputs[i][i][:]) - outputs = tf.convert_to_tensor(outputs_without_padding) - else: - if self.return_seq_2d: - # PTB tutorial: stack dense layer after that, or compute the cost from the output - # 2D Tensor [batch_size * n_steps, n_hidden] - outputs = tf.reshape(tf.concat(outputs, 1), [-1, self.cell.units]) - else: - # : stack more RNN layer after that - # 3D Tensor [batch_size, n_steps, n_hidden] - outputs = tf.reshape(tf.concat(outputs, 1), [-1, total_steps, self.cell.units]) - - if self.return_last_state and sequence_length is None: - return outputs, states - elif self.return_last_state and sequence_length is not None: - - stored_states = tf.convert_to_tensor(stored_states) - stored_states = tf.gather(stored_states, sequence_length, axis=0) - - states = [] - for i in range(stored_states.shape[1]): - states.append(tf.convert_to_tensor([stored_states[b, i, b, :] for b in range(batch_size)])) - - return outputs, states - else: - return outputs - - -class SimpleRNN(RNN): - """ - The :class:`SimpleRNN` class is a fixed length recurrent layer for implementing simple RNN. - - Parameters - ---------- - units: int - Positive integer, the dimension of hidden space. - return_last_output : boolean - Whether return last output or all outputs in a sequence. - - If True, return the last output, "Sequence input and single output" - - If False, return all outputs, "Synced sequence input and output" - - In other word, if you want to stack more RNNs on this layer, set to False - - In a dynamic model, `return_last_output` can be updated when it is called in customised forward(). - By default, `False`. - return_seq_2d : boolean - Only consider this argument when `return_last_output` is `False` - - If True, return 2D Tensor [batch_size * n_steps, n_hidden], for stacking Dense layer after it. - - If False, return 3D Tensor [batch_size, n_steps, n_hidden], for stacking multiple RNN after it. - - In a dynamic model, `return_seq_2d` can be updated when it is called in customised forward(). - By default, `False`. - return_last_state: boolean - Whether to return the last state of the RNN cell. The state is a list of Tensor. - For simple RNN, last_state = [last_output] - - - If True, the layer will return outputs and the final state of the cell. - - If False, the layer will return outputs only. - - In a dynamic model, `return_last_state` can be updated when it is called in customised forward(). - By default, `False`. - in_channels: int - Optional, the number of channels of the previous layer which is normally the size of embedding. - If given, the layer will be built when init. - If None, it will be automatically detected when the layer is forwarded for the first time. - name : str - A unique layer name. - `**kwargs`: - Advanced arguments to configure the simple RNN cell. - Please check tf.keras.layers.SimpleRNNCell. 
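# --- sketch (not from this patch) ------------------------------------------
# The RNNCell and LSTMCell defined above operate on a single time step, so a
# full sequence has to be unrolled by the caller.  A minimal sketch, assuming
# the TensorFlow backend and hypothetical shapes (batch of 4, 10 steps,
# 16 features); `seq` and all shape values are illustrative only.
import tensorlayer as tl

batch_size, n_steps, n_features = 4, 10, 16
seq = tl.layers.Input([batch_size, n_steps, n_features])

cell = tl.layers.LSTMCell(input_size=n_features, hidden_size=32)
states = None                                # forward() builds zero (h, c) when states is None
outputs = []
for t in range(n_steps):
    y, states = cell(seq[:, t, :], states)   # y: [batch_size, 32]; states: (h, c)
    outputs.append(y)
# ---------------------------------------------------------------------------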
- - Examples - -------- - - A simple regression model below. - - >>> inputs = tl.layers.Input([batch_size, num_steps, embedding_size]) - >>> rnn_out, lstm_state = tl.layers.SimpleRNN( - >>> units=hidden_size, dropout=0.1, # both units and dropout are used to configure the simple rnn cell. - >>> in_channels=embedding_size, - >>> return_last_output=True, return_last_state=True, name='simplernn' - >>> )(inputs) - >>> outputs = tl.layers.Dense(n_units=1)(rnn_out) - >>> rnn_model = tl.models.Model(inputs=inputs, outputs=[outputs, rnn_state[0]], name='rnn_model') - - Notes - ----- - Input dimension should be rank 3 : [batch_size, n_steps, n_features], if no, please see layer :class:`Reshape`. - - """ - - def __init__( - self, - units, - return_last_output=False, - return_seq_2d=False, - return_last_state=True, - in_channels=None, - name=None, # 'simplernn' - **kwargs - ): - super(SimpleRNN, self).__init__( - cell=tf.keras.layers.SimpleRNNCell(units=units, **kwargs), return_last_output=return_last_output, - return_seq_2d=return_seq_2d, return_last_state=return_last_state, in_channels=in_channels, name=name + self.lstmcell = tl.ops.lstmcell( + weight_ih=self.weight_ih, weight_hh=self.weight_hh, bias_ih=self.bias_ih, bias_hh=self.bias_hh ) + def forward(self, inputs, states=None): + input_shape = tl.get_tensor_shape(inputs) + self.check_input(input_shape) + if states is not None: + h, c = states + else: + h = tl.zeros(shape=(input_shape[0], self.hidden_size), dtype=inputs.dtype) + c = tl.zeros(shape=(input_shape[0], self.hidden_size), dtype=inputs.dtype) + h_shape = tl.get_tensor_shape(h) + c_shape = tl.get_tensor_shape(c) + self.check_hidden(input_shape, h_shape, hidden_label='h') + self.check_hidden(input_shape, c_shape, hidden_label='c') + output, new_h, new_c = self.lstmcell(inputs, h, c) + return output, (new_h, new_c) -class GRURNN(RNN): - """ - The :class:`GRURNN` class is a fixed length recurrent layer for implementing RNN with GRU cell. + +class GRUCell(Module): + """A gated recurrent unit (GRU) cell. Parameters ---------- - units: int - Positive integer, the dimension of hidden space. - return_last_output : boolean - Whether return last output or all outputs in a sequence. - - If True, return the last output, "Sequence input and single output" - - If False, return all outputs, "Synced sequence input and output" - - In other word, if you want to stack more RNNs on this layer, set to False + input_size : int + The number of expected features in the input `x` + hidden_size : int + The number of features in the hidden state `h` + bias : bool + If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. Default: ``True`` + name : None or str + A unique layer name + -------------------------------------------------------- + inputs : tensor + A tensor with shape `[batch_size, input_size]`. + states : tensor or None + A tensor with shape `[batch_size, hidden_size]`. When states is None, zero state is used. Defaults: `None`. - In a dynamic model, `return_last_output` can be updated when it is called in customised forward(). - By default, `False`. - return_seq_2d : boolean - Only consider this argument when `return_last_output` is `False` - - If True, return 2D Tensor [batch_size * n_steps, n_hidden], for stacking Dense layer after it. - - If False, return 3D Tensor [batch_size, n_steps, n_hidden], for stacking multiple RNN after it. + Returns + ---------- + outputs : tensor + A tensor with shape `[batch_size, hidden_size]`. 
+ states : tensor + A tensor with shape `[batch_size, hidden_size]`. + Tensor containing the next hidden state for each element in the batch - In a dynamic model, `return_seq_2d` can be updated when it is called in customised forward(). - By default, `False`. - return_last_state: boolean - Whether to return the last state of the RNN cell. The state is a list of Tensor. - For GRU, last_state = [last_output] - - - If True, the layer will return outputs and the final state of the cell. - - If False, the layer will return outputs only. - - In a dynamic model, `return_last_state` can be updated when it is called in customised forward(). - By default, `False`. - in_channels: int - Optional, the number of channels of the previous layer which is normally the size of embedding. - If given, the layer will be built when init. - If None, it will be automatically detected when the layer is forwarded for the first time. - name : str - A unique layer name. - `**kwargs`: - Advanced arguments to configure the GRU cell. - Please check tf.keras.layers.GRUCell. Examples -------- + With TensorLayer - A simple regression model below. - - >>> inputs = tl.layers.Input([batch_size, num_steps, embedding_size]) - >>> rnn_out, lstm_state = tl.layers.GRURNN( - >>> units=hidden_size, dropout=0.1, # both units and dropout are used to configure the GRU cell. - >>> in_channels=embedding_size, - >>> return_last_output=True, return_last_state=True, name='grurnn' - >>> )(inputs) - >>> outputs = tl.layers.Dense(n_units=1)(rnn_out) - >>> rnn_model = tl.models.Model(inputs=inputs, outputs=[outputs, rnn_state[0]], name='rnn_model') - - Notes - ----- - Input dimension should be rank 3 : [batch_size, n_steps, n_features], if no, please see layer :class:`Reshape`. + >>> input = tl.layers.Input([4, 16], name='input') + >>> prev_h = tl.layers.Input([4,32]) + >>> cell = tl.layers.GRUCell(input_size=16, hidden_size=32, bias=True, name='grucell_1') + >>> y, h= cell(input, prev_h) + >>> print(y.shape) """ def __init__( self, - units, - return_last_output=False, - return_seq_2d=False, - return_last_state=True, - in_channels=None, - name=None, # 'grurnn' - **kwargs + input_size, + hidden_size, + bias=True, + name=None, ): - super(GRURNN, self).__init__( - cell=tf.keras.layers.GRUCell(units=units, **kwargs), return_last_output=return_last_output, - return_seq_2d=return_seq_2d, return_last_state=return_last_state, in_channels=in_channels, name=name + super(GRUCell, self).__init__(name) + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + self.build(None) + logging.info("GRUCell %s: input_size: %d hidden_size: %d " % (self.name, input_size, hidden_size)) + + def __repr__(self): + s = ('{classname}(input_size={input_size}, hidden_size={hidden_size}') + s += ', bias=True' if self.bias else ', bias=False' + if self.name is not None: + s += ', name=\'{name}\'' + s += ')' + return s.format(classname=self.__class__.__name__, **self.__dict__) + + def check_input(self, input_shape): + if input_shape[1] != self.input_size: + raise ValueError( + 'input should have consistent input_size. But got {}, expected {}'.format( + input_shape[1], self.input_size + ) + ) + + def check_hidden(self, input_shape, h_shape, hidden_label): + if input_shape[0] != h_shape[0]: + raise ValueError( + 'input batch size{} should match hidden{} batch size{}.'.format( + input_shape[0], hidden_label, h_shape[0] + ) + ) + if h_shape[1] != self.hidden_size: + raise ValueError( + 'hidden{} should have consistent hidden_size. 
But got {}, expected {}.'.format( + hidden_label, h_shape[1], self.hidden_size + ) + ) + + def build(self, inputs_shape): + stdv = 1.0 / np.sqrt(self.hidden_size) + _init = tl.initializers.RandomUniform(minval=-stdv, maxval=stdv) + self.weight_ih_shape = (3 * self.hidden_size, self.input_size) + self.weight_hh_shape = (3 * self.hidden_size, self.hidden_size) + self.weight_ih = self._get_weights("weight_ih", shape=self.weight_ih_shape, init=_init) + self.weight_hh = self._get_weights("weight_hh", shape=self.weight_hh_shape, init=_init) + + if self.bias: + self.bias_ih_shape = (3 * self.hidden_size, ) + self.bias_hh_shape = (3 * self.hidden_size, ) + self.bias_ih = self._get_weights('bias_ih', shape=self.bias_ih_shape, init=_init) + self.bias_hh = self._get_weights('bias_hh', shape=self.bias_hh_shape, init=_init) + else: + self.bias_ih = None + self.bias_hh = None + + self.grucell = tl.ops.grucell( + weight_ih=self.weight_ih, weight_hh=self.weight_hh, bias_ih=self.bias_ih, bias_hh=self.bias_hh ) + def forward(self, inputs, states=None): + input_shape = tl.get_tensor_shape(inputs) + self.check_input(input_shape) + if states is None: + states = tl.zeros(shape=(input_shape[0], self.hidden_size), dtype=inputs.dtype) + states_shape = tl.get_tensor_shape(states) + self.check_hidden(input_shape, states_shape, hidden_label='h') + output, states = self.grucell(inputs, states) + return output, states -class LSTMRNN(RNN): + +class RNNBase(Module): """ - The :class:`LSTMRNN` class is a fixed length recurrent layer for implementing RNN with LSTM cell. - - Parameters - ---------- - units: int - Positive integer, the dimension of hidden space. - return_last_output : boolean - Whether return last output or all outputs in a sequence. - - If True, return the last output, "Sequence input and single output" - - If False, return all outputs, "Synced sequence input and output" - - In other word, if you want to stack more RNNs on this layer, set to False - - In a dynamic model, `return_last_output` can be updated when it is called in customised forward(). - By default, `False`. - return_seq_2d : boolean - Only consider this argument when `return_last_output` is `False` - - If True, return 2D Tensor [batch_size * n_steps, n_hidden], for stacking Dense layer after it. - - If False, return 3D Tensor [batch_size, n_steps, n_hidden], for stacking multiple RNN after it. - - In a dynamic model, `return_seq_2d` can be updated when it is called in customised forward(). - By default, `False`. - return_last_state: boolean - Whether to return the last state of the RNN cell. The state is a list of Tensor. - For LSTM, last_state = [last_output, last_cell_state] - - - If True, the layer will return outputs and the final state of the cell. - - If False, the layer will return outputs only. - - In a dynamic model, `return_last_state` can be updated when it is called in customised forward(). - By default, `False`. - in_channels: int - Optional, the number of channels of the previous layer which is normally the size of embedding. - If given, the layer will be built when init. - If None, it will be automatically detected when the layer is forwarded for the first time. - name : str - A unique layer name. - `**kwargs`: - Advanced arguments to configure the LSTM cell. - Please check tf.keras.layers.LSTMCell. - - Examples - -------- - - A simple regression model below. 
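# --- sketch (not from this patch, shapes are hypothetical) ------------------
# The check_input()/check_hidden() helpers in the cells above validate shapes
# before stepping: a hidden state whose feature dimension differs from
# hidden_size raises a ValueError.
import tensorlayer as tl

x = tl.layers.Input([4, 16])
cell = tl.layers.GRUCell(input_size=16, hidden_size=32)

good_h = tl.layers.Input([4, 32])
y, h = cell(x, good_h)            # OK: y and h are both [4, 32]

bad_h = tl.layers.Input([4, 64])
try:
    cell(x, bad_h)                # rejected: 64 != hidden_size (32)
except ValueError as err:
    print(err)
# ---------------------------------------------------------------------------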
- - >>> inputs = tl.layers.Input([batch_size, num_steps, embedding_size]) - >>> rnn_out, lstm_state = tl.layers.LSTMRNN( - >>> units=hidden_size, dropout=0.1, # both units and dropout are used to configure the LSTM cell. - >>> in_channels=embedding_size, - >>> return_last_output=True, return_last_state=True, name='grurnn' - >>> )(inputs) - >>> outputs = tl.layers.Dense(n_units=1)(rnn_out) - >>> rnn_model = tl.models.Model(inputs=inputs, outputs=[outputs, rnn_state[0]], name='rnn_model') - - Notes - ----- - Input dimension should be rank 3 : [batch_size, n_steps, n_features], if no, please see layer :class:`Reshape`. - + RNNBase class for RNN networks. It provides `forward` and other common methods for RNN, LSTM and GRU. """ def __init__( self, - units, - return_last_output=False, - return_seq_2d=False, - return_last_state=True, - in_channels=None, - name=None, # 'lstmrnn' - **kwargs + mode, + input_size, + hidden_size, + num_layers=1, + bias=True, + batch_first=False, + dropout=0.0, + bidirectional=False, + name=None, ): - super(LSTMRNN, self).__init__( - cell=tf.keras.layers.LSTMCell(units=units, **kwargs), return_last_output=return_last_output, - return_seq_2d=return_seq_2d, return_last_state=return_last_state, in_channels=in_channels, name=name - ) - - -class BiRNN(Module): - """ - The :class:`BiRNN` class is a fixed length Bidirectional recurrent layer. - - Parameters - ---------- - fw_cell : TensorFlow cell function for forward direction - A RNN cell implemented by tf.keras, e.g. tf.keras.layers.SimpleRNNCell, tf.keras.layers.LSTMCell, tf.keras.layers.GRUCell. - Note TF2.0+, TF1.0+ and TF1.0- are different - bw_cell: TensorFlow cell function for backward direction similar with `fw_cell` - return_seq_2d : boolean. - If True, return 2D Tensor [batch_size * n_steps, n_hidden], for stacking Dense layer after it. - If False, return 3D Tensor [batch_size, n_steps, n_hidden], for stacking multiple RNN after it. - In a dynamic model, `return_seq_2d` can be updated when it is called in customised forward(). - By default, `False`. - return_last_state: boolean - Whether to return the last state of the two cells. The state is a list of Tensor. - - If True, the layer will return outputs, the final state of `fw_cell` and the final state of `bw_cell`. - - If False, the layer will return outputs only. - - In a dynamic model, `return_last_state` can be updated when it is called in customised forward(). - By default, `False`. - in_channels: int - Optional, the number of channels of the previous layer which is normally the size of embedding. - If given, the layer will be built when init. - If None, it will be automatically detected when the layer is forwarded for the first time. - name : str - A unique layer name. - - Examples - -------- - A simple regression model below. - - >>> inputs = tl.layers.Input([batch_size, num_steps, embedding_size]) - >>> # the fw_cell and bw_cell can be different - >>> rnnlayer = tl.layers.BiRNN( - >>> fw_cell=tf.keras.layers.SimpleRNNCell(units=hidden_size, dropout=0.1), - >>> bw_cell=tf.keras.layers.SimpleRNNCell(units=hidden_size + 1, dropout=0.1), - >>> return_seq_2d=True, return_last_state=True - >>> ) - >>> # if return_last_state=True, the final state of the two cells will be returned together with the outputs - >>> # if return_last_state=False, only the outputs will be returned - >>> rnn_out, rnn_fw_state, rnn_bw_state = rnnlayer(inputs) - >>> # if the BiRNN is followed by a Dense, return_seq_2d should be True. 
- >>> # if the BiRNN is followed by other RNN, return_seq_2d can be False. - >>> dense = tl.layers.Dense(n_units=1)(rnn_out) - >>> outputs = tl.layers.Reshape([-1, num_steps])(dense) - >>> rnn_model = tl.models.Model(inputs=inputs, outputs=[outputs, rnn_out, rnn_fw_state[0], rnn_bw_state[0]]) - - A stacked BiRNN model. - - >>> inputs = tl.layers.Input([batch_size, num_steps, embedding_size]) - >>> rnn_out1 = tl.layers.BiRNN( - >>> fw_cell=tf.keras.layers.SimpleRNNCell(units=hidden_size, dropout=0.1), - >>> bw_cell=tf.keras.layers.SimpleRNNCell(units=hidden_size + 1, dropout=0.1), - >>> return_seq_2d=False, return_last_state=False - >>> )(inputs) - >>> rnn_out2 = tl.layers.BiRNN( - >>> fw_cell=tf.keras.layers.SimpleRNNCell(units=hidden_size, dropout=0.1), - >>> bw_cell=tf.keras.layers.SimpleRNNCell(units=hidden_size + 1, dropout=0.1), - >>> return_seq_2d=True, return_last_state=False - >>> )(rnn_out1) - >>> dense = tl.layers.Dense(n_units=1)(rnn_out2) - >>> outputs = tl.layers.Reshape([-1, num_steps])(dense) - >>> rnn_model = tl.models.Model(inputs=inputs, outputs=outputs) - - Notes - ----- - Input dimension should be rank 3 : [batch_size, n_steps, n_features]. If not, please see layer :class:`Reshape`. - - """ - - def __init__( - self, - fw_cell, - bw_cell, - return_seq_2d=False, - return_last_state=False, - in_channels=None, - name=None, # 'birnn' - ): - super(BiRNN, self).__init__(name) - - self.fw_cell = fw_cell - self.bw_cell = bw_cell - self.return_seq_2d = return_seq_2d - self.return_last_state = return_last_state - - if in_channels is not None: - self.build((None, None, in_channels)) - self._built = True + super(RNNBase, self).__init__(name) + self.mode = mode + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.bias = bias + self.batch_first = batch_first + self.dropout = dropout + self.bidirectional = bidirectional + self.build(None) logging.info( - "BiRNN %s: fw_cell: %s, fw_n_units: %s, bw_cell: %s, bw_n_units: %s" % ( - self.name, self.fw_cell.__class__.__name__, self.fw_cell.units, self.bw_cell.__class__.__name__, - self.bw_cell.units - ) + "%s: %s: input_size: %d hidden_size: %d num_layers: %d " % + (self.mode, self.name, input_size, hidden_size, num_layers) ) def __repr__(self): s = ( - '{classname}(fw_cell={fw_cellname}, fw_n_units={fw_n_units}' - ', bw_cell={bw_cellname}, bw_n_units={bw_n_units}' + '{classname}(input_size={input_size}, hidden_size={hidden_size}, num_layers={num_layers}' + ', dropout={dropout}' ) - s += ', name=\'{name}\'' + s += ', bias=True' if self.bias else ', bias=False' + s += ', bidirectional=True' if self.bidirectional else ', bidirectional=False' + if self.name is not None: + s += ', name=\'{name}\'' s += ')' - return s.format( - classname=self.__class__.__name__, fw_cellname=self.fw_cell.__class__.__name__, - fw_n_units=self.fw_cell.units, bw_cellname=self.bw_cell.__class__.__name__, bw_n_units=self.bw_cell.units, - **self.__dict__ - ) + return s.format(classname=self.__class__.__name__, **self.__dict__) def build(self, inputs_shape): - """ - Parameters - ---------- - inputs_shape : tuple - the shape of inputs tensor - """ - # Input dimension should be rank 3 [batch_size, n_steps(max), n_features] - if len(inputs_shape) != 3: - raise Exception("RNN : Input dimension should be rank 3 : [batch_size, n_steps, n_features]") + if BACKEND == 'tensorflow': + bidirect = 2 if self.bidirectional else 1 + self.weights_fw = [] + self.bias_fw = [] + self.weights_bw = [] + self.bias_bw = [] + stdv = 1.0 / 
np.sqrt(self.hidden_size)
+            _init = tl.initializers.RandomUniform(minval=-stdv, maxval=stdv)
+            if self.mode == 'LSTM':
+                gate_size = 4 * self.hidden_size
+            elif self.mode == 'GRU':
+                gate_size = 3 * self.hidden_size
+            else:
+                gate_size = self.hidden_size
+            for layer in range(self.num_layers):
+                for direction in range(bidirect):
+                    layer_input_size = self.input_size if layer == 0 else self.hidden_size * bidirect
+                    if direction == 0:
+                        self.w_ih = self._get_weights(
+                            'weight_ih_l' + str(layer), shape=(gate_size, layer_input_size), init=_init
+                        )
+                        self.w_hh = self._get_weights(
+                            'weight_hh_l' + str(layer), shape=(gate_size, self.hidden_size), init=_init
+                        )
+                        self.weights_fw.append(self.w_ih)
+                        self.weights_fw.append(self.w_hh)
+                        if self.bias:
+                            self.b_ih = self._get_weights('bias_ih_l' + str(layer), shape=(gate_size, ), init=_init)
+                            self.b_hh = self._get_weights('bias_hh_l' + str(layer), shape=(gate_size, ), init=_init)
+                            self.bias_fw.append(self.b_ih)
+                            self.bias_fw.append(self.b_hh)
+                    else:
+                        self.w_ih = self._get_weights(
+                            'weight_ih_l' + str(layer) + '_reverse', shape=(gate_size, layer_input_size), init=_init
+                        )
+                        self.w_hh = self._get_weights(
+                            'weight_hh_l' + str(layer) + '_reverse', shape=(gate_size, self.hidden_size), init=_init
+                        )
+                        self.weights_bw.append(self.w_ih)
+                        self.weights_bw.append(self.w_hh)
+                        if self.bias:
+                            self.b_ih = self._get_weights(
+                                'bias_ih_l' + str(layer) + '_reverse', shape=(gate_size, ), init=_init
+                            )
+                            self.b_hh = self._get_weights(
+                                'bias_hh_l' + str(layer) + '_reverse', shape=(gate_size, ), init=_init
+                            )
+                            self.bias_bw.append(self.b_ih)
+                            self.bias_bw.append(self.b_hh)
-        with tf.name_scope(self.name) as scope:
-            self.fw_cell.build(tuple(inputs_shape))
-            self.bw_cell.build(tuple(inputs_shape))
-
-        if self._trainable_weights is None:
-            self._trainable_weights = list()
-        for var in self.fw_cell.trainable_variables:
-            self._trainable_weights.append(var)
-        for var in self.bw_cell.trainable_variables:
-            self._trainable_weights.append(var)
-
-    # @tf.function
-    def forward(self, inputs, fw_initial_state=None, bw_initial_state=None, **kwargs):
-        """
-        Parameters
-        ----------
-        inputs : input tensor
-            The input of a network
-        fw_initial_state : None or list of Tensor (RNN State)
-            If None, `fw_initial_state` is zero state.
-        bw_initial_state : None or list of Tensor (RNN State)
-            If None, `bw_initial_state` is zero state.
-        **kwargs: dict
-            Some attributes can be updated during forwarding
-            such as `return_last_output`, `return_seq_2d`, `return_last_state`.
- """ - - if kwargs: - for attr in kwargs: - if attr in self.__dict__: - setattr(self, attr, kwargs[attr]) - - fw_outputs = list() - bw_outputs = list() - - fw_states = fw_initial_state if fw_initial_state is not None else self.fw_cell.get_initial_state(inputs) - bw_states = bw_initial_state if bw_initial_state is not None else self.bw_cell.get_initial_state(inputs) - - if not isinstance(fw_states, list): - fw_states = [fw_states] - if not isinstance(bw_states, list): - bw_states = [bw_states] - - total_steps = inputs.get_shape().as_list()[1] - - self.fw_cell.reset_dropout_mask() - self.fw_cell.reset_recurrent_dropout_mask() - self.bw_cell.reset_dropout_mask() - self.bw_cell.reset_recurrent_dropout_mask() - - for time_step in range(total_steps): - fw_cell_output, fw_states = self.fw_cell.call(inputs[:, time_step, :], fw_states, training=self.is_train) - bw_cell_output, bw_states = self.bw_cell.call( - inputs[:, -time_step - 1, :], bw_states, training=self.is_train + self.rnn = tl.ops.rnnbase( + mode=self.mode, input_size=self.input_size, hidden_size=self.hidden_size, num_layers=self.num_layers, + bias=self.bias, batch_first=self.batch_first, dropout=self.dropout, bidirectional=self.bidirectional, + is_train=self.is_train, weights_fw=self.weights_fw, weights_bw=self.weights_bw, bias_fw=self.bias_fw, + bias_bw=self.bias_bw + ) + else: + self.rnn = tl.ops.rnnbase( + mode=self.mode, + input_size=self.input_size, + hidden_size=self.hidden_size, + num_layers=self.num_layers, + bias=self.bias, + batch_first=self.batch_first, + dropout=self.dropout, + bidirectional=self.bidirectional, + is_train=self.is_train, ) - fw_outputs.append(fw_cell_output) - bw_outputs.append(bw_cell_output) + def forward(self, input, states=None): - if self.return_seq_2d: - # PTB tutorial: stack dense layer after that, or compute the cost from the output - # 2D Tensor [batch_size * n_steps, n_hidden] - fw_outputs = tf.reshape(tf.concat(fw_outputs, 1), [-1, self.fw_cell.units]) - bw_outputs = tf.reshape(tf.concat(bw_outputs, 1), [-1, self.bw_cell.units]) - else: - # : stack more RNN layer after that - # 3D Tensor [batch_size, n_steps, n_hidden] - fw_outputs = tf.reshape(tf.concat(fw_outputs, 1), [-1, total_steps, self.fw_cell.units]) - bw_outputs = tf.reshape(tf.concat(bw_outputs, 1), [-1, total_steps, self.bw_cell.units]) - - outputs = tf.concat([fw_outputs, bw_outputs], -1) - - if self.return_last_state: - return outputs, fw_states, bw_states - else: - return outputs + output, new_states = self.rnn(input, states) + return output, new_states -''' -class ConvRNNCell(object): - """Abstract object representing an Convolutional RNN Cell.""" - - def __call__(self, inputs, state, scope=None): - """Run this RNN cell on inputs, starting from the given state.""" - raise NotImplementedError("Abstract method") - - @property - def state_size(self): - """size(s) of state(s) used by this cell.""" - raise NotImplementedError("Abstract method") - - @property - def output_size(self): - """Integer or TensorShape: size of outputs produced by this cell.""" - raise NotImplementedError("Abstract method") - - def zero_state(self, batch_size): #, dtype=LayersConfig.tf_dtype): - """Return zero-filled state tensor(s). - Args: - batch_size: int, float, or unit Tensor representing the batch size. 
- Returns: - tensor of shape '[batch_size x shape[0] x shape[1] x num_features] - filled with zeros - - """ - dtype = LayersConfig.tf_dtype - shape = self.shape - num_features = self.num_features - # TODO : TypeError: 'NoneType' object is not subscriptable - zeros = tf.zeros([batch_size, shape[0], shape[1], num_features * 2], dtype=dtype) - return zeros - - -class BasicConvLSTMCell(ConvRNNCell): - """Basic Conv LSTM recurrent network cell. - - Parameters - ----------- - shape : tuple of int - The height and width of the cell. - filter_size : tuple of int - The height and width of the filter - num_features : int - The hidden size of the cell - forget_bias : float - The bias added to forget gates (see above). - input_size : int - Deprecated and unused. - state_is_tuple : boolen - If True, accepted and returned states are 2-tuples of the `c_state` and `m_state`. - If False, they are concatenated along the column axis. The latter behavior will soon be deprecated. - act : activation function - The activation function of this layer, tanh as default. - - """ - - def __init__( - self, shape, filter_size, num_features, forget_bias=1.0, input_size=None, state_is_tuple=False, - act=tf.nn.tanh - ): - """Initialize the basic Conv LSTM cell.""" - # if not state_is_tuple: - # logging.warn("%s: Using a concatenated state is slower and will soon be " - # "deprecated. Use state_is_tuple=True.", self) - if input_size is not None: - logging.warn("%s: The input_size parameter is deprecated.", self) - self.shape = shape - self.filter_size = filter_size - self.num_features = num_features - self._forget_bias = forget_bias - self._state_is_tuple = state_is_tuple - self._activation = act - - @property - def state_size(self): - """State size of the LSTMStateTuple.""" - return (LSTMStateTuple(self._num_units, self._num_units) if self._state_is_tuple else 2 * self._num_units) - - @property - def output_size(self): - """Number of units in outputs.""" - return self._num_units - - def __call__(self, inputs, state, scope=None): - """Long short-term memory cell (LSTM).""" - with tf.compat.v1.variable_scope(scope or type(self).__name__): # "BasicLSTMCell" - # Parameters of gates are concatenated into one multiply for efficiency. - if self._state_is_tuple: - c, h = state - else: - # print state - # c, h = tf.split(3, 2, state) - c, h = tf.split(state, 2, 3) - concat = _conv_linear([inputs, h], self.filter_size, self.num_features * 4, True) - - # i = input_gate, j = new_input, f = forget_gate, o = output_gate - # i, j, f, o = tf.split(3, 4, concat) - i, j, f, o = tf.split(concat, 4, 3) - - new_c = (c * tf.nn.sigmoid(f + self._forget_bias) + tf.nn.sigmoid(i) * self._activation(j)) - new_h = self._activation(new_c) * tf.nn.sigmoid(o) - - if self._state_is_tuple: - new_state = LSTMStateTuple(new_c, new_h) - else: - new_state = tf.concat([new_c, new_h], 3) - return new_h, new_state - - -def _conv_linear(args, filter_size, num_features, bias, bias_start=0.0, scope=None): - """convolution: +class RNN(RNNBase): + """Multilayer Elman network(RNN). It takes input sequences and initial + states as inputs, and returns the output sequences and the final states. Parameters ---------- - args : tensor - 4D Tensor or a list of 4D, batch x n, Tensors. - filter_size : tuple of int - Filter height and width. - num_features : int - Nnumber of features. - bias_start : float - Starting value to initialize the bias; 0 by default. - scope : VariableScope - For the created subgraph; defaults to "Linear". 
+ input_size : int + The number of expected features in the input `x` + hidden_size : int + The number of features in the hidden state `h` + num_layers : int + Number of recurrent layers. Default: 1 + bias : bool + If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. Default: ``True`` + batch_first : bool + If ``True``, then the input and output tensors are provided as `[batch_size, seq, input_size]`, Default: ``False`` + dropout : float + If non-zero, introduces a `Dropout` layer on the outputs of each RNN layer except the last layer, + with dropout probability equal to `dropout`. Default: 0 + bidirectional : bool + If ``True``, becomes a bidirectional RNN. Default: ``False`` + act : activation function + The non-linearity to use. Can be either 'tanh' or 'relu'. Default: 'tanh' + name : None or str + A unique layer name + -------------------------------------------------------- + inputs : tensor + the input sequence. if `batch_first` is True, the shape is `[batch_size, seq, input_size]`, else, the shape is `[seq, batch_size, input_size]`. + initial_states : tensor or None + the initial states. The shape is `[num_layers * num_directions, batch_size, hidden_size]`.If initial_state is not given, zero initial states are used. + If the RNN is Bidirectional, num_directions should be 2, else it should be 1. Default: None. Returns - -------- - - A 4D Tensor with shape [batch h w num_features] - - Raises - ------- - - ValueError : if some of the arguments has unspecified or wrong shape. - - """ - # Calculate the total size of arguments on dimension 1. - total_arg_size_depth = 0 - shapes = [a.get_shape().as_list() for a in args] - for shape in shapes: - if len(shape) != 4: - raise ValueError("Linear is expecting 4D arguments: %s" % str(shapes)) - if not shape[3]: - raise ValueError("Linear expects shape[4] of arguments: %s" % str(shapes)) - else: - total_arg_size_depth += shape[3] - - dtype = [a.dtype for a in args][0] - - # Now the computation. - with tf.compat.v1.variable_scope(scope or "Conv"): - matrix = tf.compat.v1.get_variable( - "Matrix", [filter_size[0], filter_size[1], total_arg_size_depth, num_features], dtype=dtype - ) - if len(args) == 1: - res = tf.nn.conv2d(args[0], matrix, strides=[1, 1, 1, 1], padding='SAME') - else: - res = tf.nn.conv2d(tf.concat(args, 3), matrix, strides=[1, 1, 1, 1], padding='SAME') - if not bias: - return res - bias_term = tf.compat.v1.get_variable( - "Bias", [num_features], dtype=dtype, - initializer=tf.compat.v1.initializers.constant(bias_start, dtype=dtype) - ) - return res + bias_term - - -class ConvLSTM(Module): - """A fixed length Convolutional LSTM layer. - - See this `paper `__ . - - Parameters - ---------- - prev_layer : :class:`Module` - Previous layer - cell_shape : tuple of int - The shape of each cell width * height - filter_size : tuple of int - The size of filter width * height - cell_fn : a convolutional RNN cell - Cell function like :class:`BasicConvLSTMCell` - feature_map : int - The number of feature map in the layer. - initializer : initializer - The initializer for initializing the parameters. - n_steps : int - The sequence length. - initial_state : None or ConvLSTM State - If None, `initial_state` is zero state. - return_last : boolean - Whether return last output or all outputs in each step. - - If True, return the last output, "Sequence input and single output". - - If False, return all outputs, "Synced sequence input and output". - - In other word, if you want to stack more RNNs on this layer, set to False. 
- - return_seq_2d : boolean - Only consider this argument when `return_last_output` is `False` - - If True, return 2D Tensor [n_example, n_hidden], for stacking DenseLayer after it. - - If False, return 3D Tensor [n_example/n_steps, n_steps, n_hidden], for stacking multiple RNN after it. - - name : str - A unique layer name. - - Attributes ---------- outputs : tensor - The output of this RNN. return_last_output = False, outputs = all cell_output, which is the hidden state. - cell_output.get_shape() = (?, h, w, c]) + the output sequence. if `batch_first` is True, the shape is `[batch_size, seq, num_directions * hidden_size]`, + else, the shape is `[seq, batch_size, num_directions * hidden_size]`. + final_states : tensor + final states. The shape is `[num_layers * num_directions, batch_size, hidden_size]`. Note that if the RNN is Bidirectional, the forward states are (0,2,4,6,...) and + the backward states are (1,3,5,7,....). - final_state : tensor or StateTuple - The finial state of this layer. - - When state_is_tuple = False, it is the final hidden and cell states, - - When state_is_tuple = True, You can get the final state after each iteration during training, then feed it to the initial state of next iteration. + Examples + -------- + With TensorLayer - initial_state : tensor or StateTuple - It is the initial state of this ConvLSTM layer, you can use it to initialize - your state at the beginning of each epoch or iteration according to your - training procedure. - - batch_size : int or tensor - Is int, if able to compute the batch_size, otherwise, tensor for ``?``. + >>> input = tl.layers.Input([23, 32, 16], name='input') + >>> prev_h = tl.layers.Input([4, 32, 32]) + >>> cell = tl.layers.RNN(input_size=16, hidden_size=32, bias=True, num_layers=2, bidirectional = True, act='tanh', batch_first=False, dropout=0, name='rnn_1') + >>> y, h= cell(input, prev_h) + >>> print(y.shape) """ - @deprecated_alias(layer='prev_layer', end_support_version=1.9) # TODO remove this line for the 1.9 release def __init__( - self, - prev_layer, - cell_shape=None, - feature_map=1, - filter_size=(3, 3), - cell_fn=BasicConvLSTMCell, - initializer=tf.compat.v1.initializers.random_uniform(-0.1, 0.1), - n_steps=5, - initial_state=None, - return_last=False, - return_seq_2d=False, - name='convlstm', + self, + input_size, + hidden_size, + num_layers=1, + bias=True, + batch_first=False, + dropout=0.0, + bidirectional=False, + act='tanh', + name=None, ): - super(ConvLSTM, self).__init__(prev_layer=prev_layer, name=name) - - logging.info( - "ConvLSTM %s: feature_map: %d, n_steps: %d, " - "in_dim: %d %s, cell_fn: %s " % - (self.name, feature_map, n_steps, self.inputs.get_shape().ndims, self.inputs.get_shape(), cell_fn.__name__) - ) - # You can get the dimension by .get_shape() or ._shape, and check the - # dimension by .with_rank() as follow. 
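# --- sketch (not from this patch, shapes are illustrative) ------------------
# A small usage sketch of the RNN layer with batch_first=True, where inputs and
# outputs are laid out as [batch_size, seq, ...] as described in the parameters
# above.
import tensorlayer as tl

x = tl.layers.Input([8, 23, 16])    # [batch_size, seq, input_size]
rnn = tl.layers.RNN(input_size=16, hidden_size=32, num_layers=1, batch_first=True, act='relu')
y, h = rnn(x)
print(y.shape)                      # [8, 23, 32]
print(h.shape)                      # [1, 8, 32]
# ---------------------------------------------------------------------------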
- # self.inputs.get_shape().with_rank(2) - # self.inputs.get_shape().with_rank(3) - - # Input dimension should be rank 5 [batch_size, n_steps(max), h, w, c] - try: - self.inputs.get_shape().with_rank(5) - except Exception: - raise Exception( - "RNN : Input dimension should be rank 5 : [batch_size, n_steps, input_x, " - "input_y, feature_map]" - ) - - fixed_batch_size = self.inputs.get_shape().with_rank_at_least(1)[0] - - if fixed_batch_size.value: - batch_size = fixed_batch_size.value - logging.info(" RNN batch_size (concurrent processes): %d" % batch_size) - + if act == 'tanh': + mode = 'RNN_TANH' + elif act == 'relu': + mode = 'RNN_RELU' else: - batch_size = array_ops.shape(self.inputs)[0] - logging.info(" non specified batch_size, uses a tensor instead.") - self.batch_size = batch_size - outputs = [] - self.cell = cell = cell_fn(shape=cell_shape, filter_size=filter_size, num_features=feature_map) - - if initial_state is None: - self.initial_state = cell.zero_state(batch_size, dtype=LayersConfig.tf_dtype) - else: - self.initial_state = initial_state - - state = self.initial_state - - # with tf.variable_scope("model", reuse=None, initializer=initializer): - with tf.compat.v1.variable_scope(name, initializer=initializer) as vs: - for time_step in range(n_steps): - if time_step > 0: tf.compat.v1.get_variable_scope().reuse_variables() - (cell_output, state) = cell(self.inputs[:, time_step, :, :, :], state) - outputs.append(cell_output) - - # Retrieve just the RNN variables. - # rnn_variables = [v for v in tf.all_variables() if v.name.startswith(vs.name)] - rnn_variables = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.VARIABLES, scope=vs.name) - - logging.info(" n_params : %d" % (len(rnn_variables))) - - if return_last: - # 2D Tensor [batch_size, n_hidden] - self.outputs = outputs[-1] - else: - if return_seq_2d: - # PTB tutorial: stack dense layer after that, or compute the cost from the output - # 4D Tensor [n_example, h, w, c] - self.outputs = tf.reshape(tf.concat(outputs, 1), [-1, cell_shape[0] * cell_shape[1] * feature_map]) - else: - # : stack more RNN layer after that - # 5D Tensor [n_example/n_steps, n_steps, h, w, c] - self.outputs = tf.reshape( - tf.concat(outputs, 1), [-1, n_steps, cell_shape[0], cell_shape[1], feature_map] - ) - - self.final_state = state - - self._add_layers(self.outputs) - self._add_params(rnn_variables) - -''' + raise ValueError("act should be in ['tanh', 'relu'], but got {}.".format(act)) + super(RNN, self + ).__init__(mode, input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional, name) -# @tf.function -def retrieve_seq_length_op(data): - """An op to compute the length of a sequence from input shape of [batch_size, n_step(max), n_features], - it can be used when the features of padding (on right hand side) are all zeros. +class LSTM(RNNBase): + """Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence. Parameters - ----------- - data : tensor - [batch_size, n_step(max), n_features] with zero padding on right hand side. + ---------- + input_size : int + The number of expected features in the input `x` + hidden_size : int + The number of features in the hidden state `h` + num_layers : int + Number of recurrent layers. Default: 1 + bias : bool + If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. 
Default: ``True`` + batch_first : bool + If ``True``, then the input and output tensors are provided as `[batch_size, seq, input_size]`, Default: ``False`` + dropout : float + If non-zero, introduces a `Dropout` layer on the outputs of each LSTM layer except the last layer, + with dropout probability equal to `dropout`. Default: 0 + bidirectional : bool + If ``True``, becomes a bidirectional LSTM. Default: ``False`` + name : None or str + A unique layer name + -------------------------------------------------------- + inputs : tensor + the input sequence. if `batch_first` is True, the shape is `[batch_size, seq, input_size]`, else, the shape is `[seq, batch_size, input_size]`. + initial_states : tensor or None + the initial states. A tuple of tensor (h, c), the shape of each is `[num_layers * num_directions, batch_size, hidden_size]`.If initial_state is not given, zero initial states are used. + If the LSTM is Bidirectional, num_directions should be 2, else it should be 1. Default: None. + + Returns + ---------- + outputs : tensor + the output sequence. if `batch_first` is True, the shape is `[batch_size, seq, num_directions * hidden_size]`, + else, the shape is `[seq, batch_size, num_directions * hidden_size]`. + final_states : tensor + final states. A tuple of two tensor. The shape of each is `[num_layers * num_directions, batch_size, hidden_size]`. Note that if the LSTM is Bidirectional, the forward states are (0,2,4,6,...) and + the backward states are (1,3,5,7,....). Examples - ----------- - Single feature + -------- + With TensorLayer - >>> data = [[[1],[2],[0],[0],[0]], - >>> [[1],[2],[3],[0],[0]], - >>> [[1],[2],[6],[1],[0]]] - >>> data = tf.convert_to_tensor(data, dtype=tf.float32) - >>> length = tl.layers.retrieve_seq_length_op(data) - [2 3 4] - - Multiple features - - >>> data = [[[1,2],[2,2],[1,2],[1,2],[0,0]], - >>> [[2,3],[2,4],[3,2],[0,0],[0,0]], - >>> [[3,3],[2,2],[5,3],[1,2],[0,0]]] - >>> data = tf.convert_to_tensor(data, dtype=tf.float32) - >>> length = tl.layers.retrieve_seq_length_op(data) - [4 3 4] - - References - ------------ - Borrow from `TFlearn `__. + >>> input = tl.layers.Input([23, 32, 16], name='input') + >>> prev_h = tl.layers.Input([4, 32, 32]) + >>> prev_c = tl.layers.Input([4, 32, 32]) + >>> cell = tl.layers.LSTM(input_size=16, hidden_size=32, bias=True, num_layers=2, bidirectional = True, batch_first=False, dropout=0, name='lstm_1') + >>> y, (h, c)= cell(input, (prev_h, prev_c)) + >>> print(y.shape) """ - with tf.name_scope('GetLength'): - used = tf.sign(tf.reduce_max(input_tensor=tf.abs(data), axis=2)) - length = tf.reduce_sum(input_tensor=used, axis=1) - return tf.cast(length, tf.int32) + def __init__( + self, + input_size, + hidden_size, + num_layers=1, + bias=True, + batch_first=False, + dropout=0.0, + bidirectional=False, + name=None, + ): + super(LSTM, self + ).__init__('LSTM', input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional, name) -# @tf.function -def retrieve_seq_length_op2(data): - """An op to compute the length of a sequence, from input shape of [batch_size, n_step(max)], - it can be used when the features of padding (on right hand side) are all zeros. +class GRU(RNNBase): + """Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence. Parameters - ----------- - data : tensor - [batch_size, n_step(max)] with zero padding on right hand side. 
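# --- sketch (not from this patch, shapes are hypothetical) ------------------
# Passing explicit initial states to the LSTM layer above: h and c are stacked
# as [num_layers * num_directions, batch_size, hidden_size], and the final
# states interleave the forward/backward directions as described in the
# docstring.
import tensorlayer as tl

x = tl.layers.Input([23, 32, 16])    # [seq, batch_size, input_size], batch_first=False
h0 = tl.layers.Input([4, 32, 32])    # 2 layers * 2 directions = 4
c0 = tl.layers.Input([4, 32, 32])
lstm = tl.layers.LSTM(input_size=16, hidden_size=32, num_layers=2, bidirectional=True)
y, (hn, cn) = lstm(x, (h0, c0))      # y: [23, 32, 64], forward and backward concatenated
h_fw_last, h_bw_last = hn[2], hn[3]  # last layer: forward at even index, backward at odd index
# ---------------------------------------------------------------------------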
+ ---------- + input_size : int + The number of expected features in the input `x` + hidden_size : int + The number of features in the hidden state `h` + num_layers : int + Number of recurrent layers. Default: 1 + bias : bool + If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. Default: ``True`` + batch_first : bool + If ``True``, then the input and output tensors are provided as `[batch_size, seq, input_size]`, Default: ``False`` + dropout : float + If non-zero, introduces a `Dropout` layer on the outputs of each GRU layer except the last layer, + with dropout probability equal to `dropout`. Default: 0 + bidirectional : bool + If ``True``, becomes a bidirectional LSTM. Default: ``False`` + name : None or str + A unique layer name + -------------------------------------------------------- + inputs : tensor + the input sequence. if `batch_first` is True, the shape is `[batch_size, seq, input_size]`, else, the shape is `[seq, batch_size, input_size]`. + initial_states : tensor or None + the initial states. A tuple of tensor (h, c), the shape of each is `[num_layers * num_directions, batch_size, hidden_size]`.If initial_state is not given, zero initial states are used. + If the GRU is Bidirectional, num_directions should be 2, else it should be 1. Default: None. + + Returns + ---------- + outputs : tensor + the output sequence. if `batch_first` is True, the shape is `[batch_size, seq, num_directions * hidden_size]`, + else, the shape is `[seq, batch_size, num_directions * hidden_size]`. + final_states : tensor + final states. A tuple of two tensor. The shape of each is `[num_layers * num_directions, batch_size, hidden_size]`. Note that if the GRU is Bidirectional, the forward states are (0,2,4,6,...) and + the backward states are (1,3,5,7,....). Examples - ----------- - >>> data = [[1,2,0,0,0], - >>> [1,2,3,0,0], - >>> [1,2,6,1,0]] - >>> data = tf.convert_to_tensor(data, dtype=tf.float32) - >>> length = tl.layers.retrieve_seq_length_op2(data) - tensor([2 3 4]) + -------- + With TensorLayer + + >>> input = tl.layers.Input([23, 32, 16], name='input') + >>> prev_h = tl.layers.Input([4, 32, 32]) + >>> cell = tl.layers.GRU(input_size=16, hidden_size=32, bias=True, num_layers=2, bidirectional = True, batch_first=False, dropout=0, name='GRU_1') + >>> y, h= cell(input, prev_h) + >>> print(y.shape) """ - return tf.reduce_sum(input_tensor=tf.cast(tf.greater(data, tf.zeros_like(data)), tf.int32), axis=1) - -# @tf.function -def retrieve_seq_length_op3(data, pad_val=0): - """An op to compute the length of a sequence, the data shape can be [batch_size, n_step(max)] or - [batch_size, n_step(max), n_features]. - - If the data has type of tf.string and pad_val is assigned as empty string (''), this op will compute the - length of the string sequence. - - Parameters - ----------- - data : tensor - [batch_size, n_step(max)] or [batch_size, n_step(max), n_features] with zero padding on the right hand side. - pad_val: - By default 0. 
If the data is tf.string, please assign this as empty string ('') - - Examples - ----------- - >>> data = [[[1],[2],[0],[0],[0]], - >>> [[1],[2],[3],[0],[0]], - >>> [[1],[2],[6],[1],[0]]] - >>> data = tf.convert_to_tensor(data, dtype=tf.float32) - >>> length = tl.layers.retrieve_seq_length_op3(data) - tensor([2, 3, 4]) - >>> data = [[[1,2],[2,2],[1,2],[1,2],[0,0]], - >>> [[2,3],[2,4],[3,2],[0,0],[0,0]], - >>> [[3,3],[2,2],[5,3],[1,2],[0,0]]] - >>> data = tf.convert_to_tensor(data, dtype=tf.float32) - >>> length = tl.layers.retrieve_seq_length_op3(data) - tensor([4, 3, 4]) - >>> data = [[1,2,0,0,0], - >>> [1,2,3,0,0], - >>> [1,2,6,1,0]] - >>> data = tf.convert_to_tensor(data, dtype=tf.float32) - >>> length = tl.layers.retrieve_seq_length_op3(data) - tensor([2, 3, 4]) - >>> data = [['hello','world','','',''], - >>> ['hello','world','tensorlayer','',''], - >>> ['hello','world','tensorlayer','2.0','']] - >>> data = tf.convert_to_tensor(data, dtype=tf.string) - >>> length = tl.layers.retrieve_seq_length_op3(data, pad_val='') - tensor([2, 3, 4]) - - """ - data_shape_size = data.get_shape().ndims - if data_shape_size == 3: - return tf.reduce_sum( - input_tensor=tf.cast(tf.reduce_any(input_tensor=tf.not_equal(data, pad_val), axis=2), dtype=tf.int32), - axis=1 - ) - elif data_shape_size == 2: - return tf.reduce_sum(input_tensor=tf.cast(tf.not_equal(data, pad_val), dtype=tf.int32), axis=1) - elif data_shape_size == 1: - raise ValueError("retrieve_seq_length_op3: data has wrong shape! Shape got ", data.get_shape().as_list()) - else: - raise ValueError( - "retrieve_seq_length_op3: handling data with num of dims %s hasn't been implemented!" % (data_shape_size) - ) - - -def target_mask_op(data, pad_val=0): - """ Return the mask of the input sequence data based on the padding values. - - Parameters - ----------- - data : tf.Tensor - A tensor with 2 or 3 dimensions. - pad_val: int, float, string, etc - The value that represent padding. By default, 0. For tf.string, you may use empty string. - - Examples - ----------- - >>> data = [['hello', 'world', '', '', ''], - >>> ['hello', 'world', 'tensorlayer', '', ''], - >>> ['hello', 'world', 'tensorlayer', '2.0', '']] - >>> data = tf.convert_to_tensor(data, dtype=tf.string) - >>> mask = tl.layers.target_mask_op(data, pad_val='') - >>> print(mask) - tf.Tensor( - [[1 1 0 0 0] - [1 1 1 0 0] - [1 1 1 1 0]], shape=(3, 5), dtype=int32) - >>> data = [[[1], [0], [0], [0], [0]], - >>> [[1], [2], [3], [0], [0]], - >>> [[1], [2], [0], [1], [0]]] - >>> data = tf.convert_to_tensor(data, dtype=tf.float32) - >>> mask = tl.layers.target_mask_op(data) - >>> print(mask) - tf.Tensor( - [[1 0 0 0 0] - [1 1 1 0 0] - [1 1 0 1 0]], shape=(3, 5), dtype=int32) - >>> data = [[[0,0],[2,2],[1,2],[1,2],[0,0]], - >>> [[2,3],[2,4],[3,2],[1,0],[0,0]], - >>> [[3,3],[0,1],[5,3],[1,2],[0,0]]] - >>> data = tf.convert_to_tensor(data, dtype=tf.float32) - >>> mask = tl.layers.target_mask_op(data) - >>> print(mask) - tf.Tensor( - [[0 1 1 1 0] - [1 1 1 1 0] - [1 1 1 1 0]], shape=(3, 5), dtype=int32) - """ - - if not isinstance(data, tf.Tensor): - raise AttributeError("target_mask_op: the type of input data should be tf.Tensor but got %s." 
% type(data)) - data_shape_size = data.get_shape().ndims - if data_shape_size == 3: - return tf.cast(tf.reduce_any(input_tensor=tf.not_equal(data, pad_val), axis=2), dtype=tf.int32) - elif data_shape_size == 2: - return tf.cast(tf.not_equal(data, pad_val), dtype=tf.int32) - elif data_shape_size == 1: - raise ValueError( - "target_mask_op: data_shape %s is not supported. " - "The shape of data should have 2 or 3 dims." % (data.get_shape()) - ) - else: - raise ValueError( - "target_mask_op: handling data_shape %s hasn't been implemented! " - "The shape of data should have 2 or 3 dims" % (data.get_shape()) - ) + def __init__( + self, + input_size, + hidden_size, + num_layers=1, + bias=True, + batch_first=False, + dropout=0.0, + bidirectional=False, + name=None, + ): + super(GRU, self + ).__init__('GRU', input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional, name)
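# --- sketch (not from this patch; shapes and the layer name are illustrative) --
# The bidirectional GRU layer concatenates forward and backward features in its
# output, and returns a single stacked hidden state following the
# layer/direction ordering described in the docstring above.
import tensorlayer as tl

x = tl.layers.Input([23, 32, 16])    # [seq, batch_size, input_size]
gru = tl.layers.GRU(input_size=16, hidden_size=32, num_layers=2, bidirectional=True, name='gru_sketch')
y, hn = gru(x)
print(y.shape)                       # [23, 32, 64]   (num_directions * hidden_size)
print(hn.shape)                      # [4, 32, 32]    (num_layers * num_directions, batch_size, hidden_size)
# ---------------------------------------------------------------------------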