Update TensorLayer3

This commit is contained in:
Eric_lai 2021-08-09 09:34:23 +08:00
parent 0e8a2ae701
commit 6c053306ed
11 changed files with 1653 additions and 1202 deletions

View File

@ -0,0 +1,74 @@
#! /usr/bin/python
# -*- coding: utf-8 -*-
# The same set of code can switch the backend with one line
import os
os.environ['TL_BACKEND'] = 'tensorflow'
# os.environ['TL_BACKEND'] = 'mindspore'
# os.environ['TL_BACKEND'] = 'paddle'
import tensorlayer as tl
from tensorlayer.layers import Module
from tensorlayer.layers import Dense, LSTM, Embedding
from tensorlayer.dataflow import Dataset
import numpy as np
X_train, y_train, X_test, y_test = tl.files.load_imdb_dataset('data', nb_words=20000, test_split=0.2)
Seq_Len = 200
vocab_size = len(X_train) + 1
class imdbdataset(Dataset):
def __init__(self, X, y):
self.X = X
self.y = y
def __getitem__(self, index):
data = self.X[index]
data = np.concatenate([data[:Seq_Len], [0] * (Seq_Len - len(data))]).astype('int64') # set
label = self.y[index].astype('int64')
return data, label
def __len__(self):
return len(self.y)
class ImdbNet(Module):
def __init__(self):
super(ImdbNet, self).__init__()
self.embedding = Embedding(vocabulary_size=vocab_size, embedding_size=64)
self.lstm = LSTM(input_size=64, hidden_size=64)
self.dense1 = Dense(in_channels=64, n_units=64, act=tl.ReLU)
self.dense2 = Dense(in_channels=64, n_units=2)
def forward(self, x):
x = self.embedding(x)
x, _ = self.lstm(x)
x = tl.ops.reduce_mean(x, axis=1)
x = self.dense1(x)
x = self.dense2(x)
return x
n_epoch = 5
batch_size = 64
print_freq = 2
train_dataset = imdbdataset(X=X_train, y=y_train)
train_dataset = tl.dataflow.FromGenerator(
train_dataset, output_types=[tl.int64, tl.int64], column_names=['data', 'label']
)
train_loader = tl.dataflow.Dataloader(train_dataset, batch_size=batch_size, shuffle=True)
net = ImdbNet()
train_weights = net.trainable_weights
optimizer = tl.optimizers.Adam(1e-3)
metric = tl.metric.Accuracy()
loss_fn = tl.cost.softmax_cross_entropy_with_logits
model = tl.models.Model(network=net, loss_fn=loss_fn, optimizer=optimizer, metrics=metric)
model.train(n_epoch=n_epoch, train_dataset=train_loader, print_freq=print_freq, print_train_batch=True)

View File

@ -141,3 +141,7 @@ from .load_backend import Maximum
from .load_backend import Meshgrid
from .load_backend import BatchToSpace
from .load_backend import DepthToSpace
from .load_backend import rnncell
from .load_backend import lstmcell
from .load_backend import grucell
from .load_backend import rnnbase

View File

@ -720,15 +720,26 @@ def reduce_min(input_tensor, axis=None):
class Pad(Cell):
def __init__(self, paddings, mode="REFLECT"):
def __init__(self, paddings, mode="REFLECT", constant_values=0):
super(Pad, self).__init__()
if mode not in ["REFLECT", "SYMMETRIC"]:
if mode not in ['CONSTANT', 'REFLECT', 'SYMMETRIC']:
raise Exception("Unsupported mode: {}".format(mode))
self.pad = P.MirrorPad(mode=mode)
self.paddings = Tensor(paddings)
if mode == 'CONSTANT':
self.pad = P.Pad(paddings)
if constant_values-0 == 0:
pass
else:
raise NotImplementedError("constant_values can only be equal to 0.")
else:
self.pad = P.MirrorPad(mode=mode)
self.paddings = Tensor(np.array(self.paddings))
self.mode = mode
def construct(self, x):
return self.pad(x, self.paddings)
if self.mode == 'CONSTANT':
return self.pad(x)
else:
return self.pad(x, self.paddings)
def pad(tensor, paddings, mode='CONSTANT', constant_values=0):

View File

@ -1800,3 +1800,152 @@ class DorefaConv2D(Cell):
outputs = nchw_to_nhwc(outputs)
return outputs
class rnncell(Cell):
def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, act):
super(rnncell, self).__init__()
self.weight_ih = weight_ih
self.weight_hh = weight_hh
self.bias_ih = bias_ih
self.bias_hh = bias_hh
self.act_fn = P.ReLU() if act == 'relu' else P.Tanh()
self.transpose = P.Transpose()
def construct(self, input, h):
self.weight_ih = self.transpose(self.weight_ih, (1, 0))
i2h = P.matmul(input, self.weight_ih)
if self.bias_ih is not None:
i2h += self.bias_ih
self.weight_hh = self.transpose(self.weight_hh, (1, 0))
h2h = P.matmul(h, self.weight_hh)
if self.bias_hh is not None:
h2h += self.bias_hh
h = self.act_fn(i2h + h2h)
return h, h
class lstmcell(Cell):
def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh):
super(lstmcell, self).__init__()
self.weight_ih = weight_ih
self.weight_hh = weight_hh
self.bias_ih = bias_ih
self.bias_hh = bias_hh
self.gate_act_fn = P.Sigmoid()
self.act_fn = P.Tanh()
self.transpose = P.Transpose()
self.split = P.Split(axis=-1, output_num=4)
def construct(self, input, h, c):
self.weight_ih = self.transpose(self.weight_ih, (1, 0))
gates = P.matmul(input, self.weight_ih)
if self.bias_ih is not None:
gates += self.bias_ih
self.weight_hh = self.transpose(self.weight_hh, (1, 0))
gates += P.matmul(h, self.weight_hh)
if self.bias_hh is not None:
gates += self.bias_hh
gate_slices = self.split(gates)
i = self.gate_act_fn(gate_slices[0])
f = self.gate_act_fn(gate_slices[1])
o = self.gate_act_fn(gate_slices[3])
c = f * c + i * self.act_fn(gate_slices[2])
h = o * self.act_fn(c)
return h, h, c
class grucell(Cell):
def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh):
super(grucell, self).__init__()
self.weight_ih = weight_ih
self.weight_hh = weight_hh
self.bias_ih = bias_ih
self.bias_hh = bias_hh
self.gate_act_fn = P.Sigmoid()
self.act_fn = P.Tanh()
self.transpose = P.Transpose()
self.split = P.Split(axis=-1, output_num=3)
def construct(self, input, h):
self.weight_ih = self.transpose(self.weight_ih, (1, 0))
x_gates = P.matmul(input, self.weight_ih)
if self.bias_ih is not None:
x_gates += self.bias_ih
self.weight_hh = self.transpose(self.weight_hh, (1, 0))
h_gates = P.matmul(h, self.weight_hh)
if self.bias_hh is not None:
h_gates += self.bias_hh
x_r, x_z, x_c = self.split(x_gates)
h_r, h_z, h_c = self.split(h_gates)
r = self.gate_act_fn(x_r + h_r)
z = self.gate_act_fn(x_r + h_z)
c = self.act_fn(x_c + r * h_c)
h = (h - c) * z + c
return h, h
class rnnbase(Cell):
def __init__(
self,
mode,
input_size,
hidden_size,
num_layers,
bias,
batch_first,
dropout,
bidirectional,
is_train,
):
super(rnnbase, self).__init__()
self.mode = mode
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bidirect = 2 if bidirectional else 1
self.batch_first = batch_first
if mode == 'LSTM':
self.lstm = ms.nn.LSTM(
input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, has_bias=bias,
batch_first=batch_first, dropout=dropout, bidirectional=bidirectional
)
elif mode == 'GRU':
raise NotImplementedError
elif mode == 'RNN_TANH':
raise NotImplementedError
elif mode == 'RNN_RELU':
raise NotImplementedError
self.zeros = P.Zeros()
def construct(self, input, states):
input_shape = input.shape
input_dtype = input.dtype
if self.mode == 'LSTM':
if self.batch_first:
batch_size = input_shape[0]
else:
batch_size = input_shape[1]
if states is None:
h = self.zeros((self.bidirect * self.num_layers, batch_size, self.hidden_size), input_dtype)
c = self.zeros((self.bidirect * self.num_layers, batch_size, self.hidden_size), input_dtype)
states = (h, c)
output, (h, c) = self.lstm(input, states)
return output, (h, c)

View File

@ -4,6 +4,7 @@
from __future__ import absolute_import, division, print_function
import paddle as pd
import paddle.nn as nn
import numpy as np
_dtypeDict = ["float16", "float32", "float64", "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"]
# TODO NotImplemented
@ -325,7 +326,7 @@ class Reshape(object):
self.shape = shape
def __call__(self, tensor):
raise NotImplementedError
return pd.reshape(tensor, shape=self.shape)
def reshape(tensor, shape):
@ -352,7 +353,7 @@ class Concat(object):
self.axis = axis
def __call__(self, values):
raise NotImplementedError
return pd.concat(values, axis=self.axis)
def concat(values, axis):
@ -369,7 +370,7 @@ def concat(values, axis):
-------
A Tensor resulting from concatenation of the input tensors.
"""
raise NotImplementedError
return pd.concat(values, axis)
def convert_to_tensor(value, dtype=float32):
@ -407,16 +408,16 @@ def sqrt(x):
-------
A Tensor. Has the same type as x.
"""
raise NotImplementedError
return pd.sqrt(x)
class ReduceSum(object):
def __init__(self, axis):
pass
self.axis = axis
def construct(self, input):
pass
return pd.sum(input, axis=self.axis)
class ReduceMean(object):
@ -447,7 +448,7 @@ def reduce_mean(input_tensor, axis=None):
The reduced tensor.
"""
raise NotImplementedError
return pd.mean(input_tensor, axis)
class ReduceMax(object):
@ -478,7 +479,7 @@ def reduce_max(input_tensor, axis=None):
The reduced tensor.
"""
raise NotImplementedError
return pd.max(input_tensor, axis)
def reduce_min(input_tensor, axis=None):
@ -499,21 +500,47 @@ def reduce_min(input_tensor, axis=None):
-------
The reduced tensor.
"""
raise NotImplementedError
return pd.min(input_tensor, axis)
class Pad(object):
def __init__(self, paddings, mode="REFLECT"):
def __init__(self, paddings, mode="REFLECT", constant_values=0):
if mode not in ['CONSTANT', 'REFLECT', 'SYMMETRIC']:
raise Exception("Unsupported mode: {}".format(mode))
if mode == 'SYMMETRIC':
mode = 'EDGE'
raise NotImplementedError
self.paddings = paddings
self.mode = mode
self.mode = mode.lower()
self.constant_values = constant_values
def __call__(self, x):
raise NotImplementedError
if len(x.shape) == 3:
data_format = 'NLC'
self.paddings = self.correct_paddings(len(x.shape), self.paddings, data_format)
elif len(x.shape) == 4:
data_format = 'NHWC'
self.paddings = self.correct_paddings(len(x.shape), self.paddings, data_format)
elif len(x.shape) == 5:
data_format = 'NDHWC'
self.paddings = self.correct_paddings(len(x.shape), self.paddings, data_format)
else:
raise NotImplementedError('Please check the input shape.')
return pd.nn.functional.pad(x, self.paddings, self.mode, value=self.constant_values, data_format=data_format)
def correct_paddings(self, in_shape, paddings, data_format):
if in_shape == 3 and data_format == 'NLC':
correct_output = [paddings[1][0], paddings[1][1]]
elif in_shape == 4 and data_format == 'NHWC':
correct_output = [paddings[2][0], paddings[2][1],
paddings[1][0], paddings[1][1]]
elif in_shape == 5 and data_format == 'NDHWC':
correct_output = [paddings[3][0], paddings[3][1],
paddings[2][0], paddings[2][1],
paddings[1][0], paddings[1][1]]
else:
raise NotImplementedError('Does not support channels first')
return correct_output
def pad(tensor, paddings, mode='CONSTANT', constant_values=0):
@ -535,7 +562,7 @@ def pad(tensor, paddings, mode='CONSTANT', constant_values=0):
-------
A Tensor. Has the same type as tensor.
"""
raise NotImplementedError
return Pad(paddings, mode, constant_values)(tensor)
class Unstack(object):
@ -545,7 +572,7 @@ class Unstack(object):
self.num = num
def __call__(self, values):
raise NotImplementedError
return pd.unstack(values, self.axis, self.num)
class Stack(object):
@ -554,7 +581,7 @@ class Stack(object):
self.axis = axis
def __call__(self, values):
raise NotImplementedError
return pd.stack(values, self.axis)
def stack(values, axis=0):
@ -563,7 +590,7 @@ def stack(values, axis=0):
Parameters
----------
values : list
values : list or tuple
A list of Tensor objects with the same shape and type.
axis : int
An int. The axis to stack along. Defaults to the first dimension.
@ -573,7 +600,7 @@ def stack(values, axis=0):
-------
A stacked Tensor with the same type as values.
"""
raise NotImplementedError
return pd.stack(values, axis=axis)
class Meshgrid(object):
@ -583,10 +610,10 @@ class Meshgrid(object):
self.index = indexing
def __call__(self, inputs):
pass
return pd.meshgrid(inputs)
def meshgrid(x, y):
def meshgrid(*args, **kwargs):
"""
Broadcasts parameters for evaluation on an N-D grid.
@ -602,7 +629,7 @@ def meshgrid(x, y):
A list of N Tensors with rank N.
"""
pass
return pd.meshgrid(*args, **kwargs)
def range(start, limit=None, delta=1, dtype=None):
@ -626,16 +653,19 @@ def range(start, limit=None, delta=1, dtype=None):
-------
An 1-D Tensor of type dtype.
"""
raise NotImplementedError
return pd.arange(start, step=delta)
class ExpandDims(object):
def __init__(self, axis):
pass
self.axis = axis
def construct(self, input):
pass
input = convert_to_numpy(input)
output = np.expand_dims(input, axis=self.axis)
output = convert_to_tensor(output)
return output
def expand_dims(input, axis):
@ -655,7 +685,10 @@ def expand_dims(input, axis):
A Tensor with the same data as input, but its shape has an additional dimension of size 1 added.
"""
raise NotImplementedError
input = convert_to_numpy(input)
output = np.expand_dims(input, axis=axis)
output = convert_to_tensor(output)
return output
class Tile(object):
@ -664,7 +697,7 @@ class Tile(object):
pass
def __call__(self, input, multiples):
raise NotImplementedError
return pd.tile(input, multiples)
def tile(input, multiples):
@ -683,16 +716,16 @@ def tile(input, multiples):
-------
A Tensor. Has the same type as input.
"""
raise NotImplementedError
return pd.tile(input, multiples)
class Cast(object):
def __init__(self, dtype):
pass
self.dtype = dtype
def __call__(self, input):
pass
return pd.cast(input, self.dtype)
def cast(x, dtype):
@ -711,7 +744,7 @@ def cast(x, dtype):
-------
A Tensor or SparseTensor or IndexedSlices with same shape as x and same type as dtype.
"""
raise NotImplementedError
return pd.cast(x, dtype)
class Transpose(object):
@ -722,7 +755,7 @@ class Transpose(object):
raise ("The conjugate Parameters not supported")
def __call__(self, a):
raise NotImplementedError
return pd.transpose(a, self.perm)
def transpose(a, perm=None, conjugate=False):
@ -743,7 +776,7 @@ def transpose(a, perm=None, conjugate=False):
A transposed Tensor.
"""
raise NotImplementedError
return pd.transpose(a, perm)
def gather_nd(params, indices, batch_dims=0):
@ -764,7 +797,7 @@ def gather_nd(params, indices, batch_dims=0):
A Tensor. Has the same type as params.
"""
pass
return pd.gather_nd(params, indices)
def clip_by_value(t, clip_value_min, clip_value_max):
@ -785,7 +818,7 @@ def clip_by_value(t, clip_value_min, clip_value_max):
A clipped Tensor or IndexedSlices.
"""
pass
return pd.clip(t, clip_value_min, clip_value_max)
def split(value, num_or_size_splits, axis=0, num=None):
@ -796,7 +829,7 @@ def split(value, num_or_size_splits, axis=0, num=None):
----------
value : tensor
The Tensor to split.
num_or_size_splits : list
num_or_size_splits : list or tuple
Either an integer indicating the number of splits along split_dim or a 1-D integer Tensor or
Python list containing the sizes of each output tensor along split_dim.
axis : int
@ -808,33 +841,33 @@ def split(value, num_or_size_splits, axis=0, num=None):
-------
Tensor objects resulting from splitting value.
"""
pass
pd.split(value, num_or_size_splits, axis)
class Floor(object):
def __call__(self, *args, **kwargs):
raise NotImplementedError
def __call__(self, x):
return pd.floor(x)
def floor(x):
raise NotImplementedError
return pd.floor(x)
def gather(params, indices):
raise NotImplementedError
return pd.gather(params, indices)
def linspace(start, stop, num):
raise NotImplementedError
return pd.linspace(start, stop, num)
def slice(inputs, starts, sizes):
raise NotImplementedError
return pd.slice(inputs, starts=starts, ends=sizes)
def add_n(inputs):
raise NotImplementedError
return pd.add_n(inputs)
class OneHot(object):
@ -844,17 +877,19 @@ class OneHot(object):
self.dtype = dtype
def __call__(self, indices):
raise NotImplementedError
output = pd.nn.functional.one_hot(indices, self.depth)
return output
class L2Normalize(object):
def __init__(self, axis=None, epsilon=1e-12):
super(L2Normalize, self).__init__()
pass
self.axis = axis
self.epsilon = epsilon
def __call__(self, input, *args, **kwargs):
pass
def __call__(self, input):
return pd.nn.functional.normalize(x=input, p=2, axis=self.axis, epsilon=self.epsilon)
class EmbeddingLookup(object):
@ -862,7 +897,7 @@ class EmbeddingLookup(object):
def __init__(self, max_norm=None):
self.max_norm = max_norm
def __call__(self, params, ids, *args, **kwargs):
def __call__(self, params, ids):
pass

View File

@ -3,6 +3,12 @@
import paddle as pd
import paddle.nn.functional as F
import numpy as np
import paddle.fluid as fluid
from paddle.nn import initializer as I
from paddle.fluid.layers.utils import map_structure, flatten, pack_sequence_as
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.dygraph import Layer
def padding_format(padding):
@ -1308,3 +1314,386 @@ class DorefaConv2D(object):
def __call__(self, inputs, filters):
raise NotImplementedError
class rnncell(object):
def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, bias, act):
self.weight_ih = weight_ih
self.weight_hh = weight_hh
self.bias_ih = bias_ih
self.bias_hh = bias_hh
self.bias = bias
self.act_fn = F.relu if act == 'relu' else F.tanh
def __call__(self, input, h):
i2h = pd.matmul(input, self.weight_ih, transpose_y=True)
if self.bias_ih is not None:
i2h += self.bias_ih
h2h = pd.matmul(h, self.weight_hh, transpose_y=True)
if self.bias_hh is not None:
h2h += self.bias_hh
h = self.act_fn(i2h + h2h)
return h, h
class lstmcell(object):
def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, bias):
self.weight_ih = weight_ih
self.weight_hh = weight_hh
self.bias_ih = bias_ih
self.bias_hh = bias_hh
self.bias = bias
self.gate_act_fn = F.sigmoid
self.act_fn = F.tanh
def __call__(self, inputs, h, c):
gates = pd.matmul(inputs, self.weight_ih, transpose_y=True)
if self.bias_ih is not None:
gates += self.bias_ih
gates += pd.matmul(h, self.weight_hh, transpose_y=True)
if self.bias_hh is not None:
gates += self.bias_hh
gates_slices = pd.split(gates, num_or_sections=4, axis=-1)
i = self.gate_act_fn(gates_slices[0])
f = self.gate_act_fn(gates_slices[1])
o = self.gate_act_fn(gates_slices[3])
c = f * c + i * self.act_fn(gates_slices[2])
h = o * self.act_fn(c)
return h, h, c
class grucell(object):
def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, bias):
self.weight_ih = weight_ih
self.weight_hh = weight_hh
self.bias_ih = bias_ih
self.bias_hh = bias_hh
self.bias = bias
self.gate_act_fn = F.sigmoid
self.act_fn = F.tanh
def __call__(self, input, h):
x_gates = pd.matmul(input, self.weight_ih, transpose_y=True)
if self.bias_ih is not None:
x_gates = x_gates + self.bias_ih
h_gates = pd.matmul(h, self.weight_hh, transpose_y=True)
if self.bias_hh is not None:
h_gates = h_gates + self.bias_hh
x_r, x_z, x_c = pd.split(x_gates, num_or_sections=3, axis=-1)
h_r, h_z, h_c = pd.split(h_gates, num_or_sections=3, axis=-1)
r = self.gate_act_fn(x_r + h_r)
z = self.gate_act_fn(x_z + h_z)
c = self.act_fn(x_c + r * h_c) # apply reset gate after mm
h = (h - c) * z + c
return h, h
def split_states(states, bidirectional=False, state_components=1):
r"""
Split states of RNN network into possibly nested list or tuple of
states of each RNN cells of the RNN network.
Parameters:
states (Tensor|tuple|list): the concatenated states for RNN network.
When `state_components` is 1, states in a Tensor with shape
`(L*D, N, C)` where `L` is the number of layers of the RNN
network, `D` is the number of directions of the RNN network(1
for unidirectional RNNs and 2 for bidirectional RNNs), `N` is
the batch size of the input to the RNN network, `C` is the
hidden size of the RNN network.
When `state_components` is larger than 1, `states` is a tuple of
`state_components` Tensors that meet the requirements described
above.
For SimpleRNNs and GRUs, `state_components` is 1, and for LSTMs,
`state_components` is 2.
bidirectional (bool): whether the state is of a bidirectional RNN
network. Defaults to False.
state_components (int): the number of the components of the states. see
`states` above. Defaults to 1.
Returns:
A nested list or tuple of RNN cell states.
If `bidirectional` is True, it can be indexed twice to get an RNN
cell state. The first index indicates the layer, the second index
indicates the direction.
If `bidirectional` is False, it can be indexed once to get an RNN
cell state. The index indicates the layer.
Note that if `state_components` is larger than 1, an RNN cell state
can be indexed one more time to get a tensor of shape(N, C), where
`N` is the batch size of the input to the RNN cell, and `C` is the
hidden size of the RNN cell.
"""
if state_components == 1:
states = pd.unstack(states)
if not bidirectional:
return states
else:
return list(zip(states[::2], states[1::2]))
else:
assert len(states) == state_components
states = tuple([pd.unstack(item) for item in states])
if not bidirectional:
return list(zip(*states))
else:
states = list(zip(*states))
return list(zip(states[::2], states[1::2]))
def concat_states(states, bidirectional=False, state_components=1):
r"""
Concatenate a possibly nested list or tuple of RNN cell states into a
compact form.
Parameters:
states (list|tuple): a possibly nested list or tuple of RNN cell
states.
If `bidirectional` is True, it can be indexed twice to get an
RNN cell state. The first index indicates the layer, the second
index indicates the direction.
If `bidirectional` is False, it can be indexed once to get an RNN
cell state. The index indicates the layer.
Note that if `state_components` is larger than 1, an RNN cell
state can be indexed one more time to get a tensor of shape(N, C),
where `N` is the batch size of the input to the RNN cell, and
`C` is the hidden size of the RNN cell.
bidirectional (bool): whether the state is of a bidirectional RNN
network. Defaults to False.
state_components (int): the number of the components of the states. see
`states` above. Defaults to 1.
Returns:
Concatenated states for RNN network.
When `state_components` is 1, states in a Tensor with shape
`(L\*D, N, C)` where `L` is the number of layers of the RNN
network, `D` is the number of directions of the RNN network(1 for
unidirectional RNNs and 2 for bidirectional RNNs), `N` is the batch
size of the input to the RNN network, `C` is the hidden size of the
RNN network.
"""
if state_components == 1:
return pd.stack(flatten(states))
else:
states = flatten(states)
componnets = []
for i in range(state_components):
componnets.append(states[i::state_components])
return tuple([pd.stack(item) for item in componnets])
class rnnbase(Layer):
def __init__(
self,
mode,
input_size,
hidden_size,
num_layers,
bias,
batch_first,
dropout,
bidirectional,
is_train,
):
super(rnnbase, self).__init__()
self.mode = mode
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.time_major = False if batch_first else True
self.dropout = dropout
self.bidirect = 2 if bidirectional else 1
self.state_components = 2 if mode == 'LSTM' else 1
self.rnn = pd.nn.LayerList()
RNN = pd.nn.RNN
BiRNN = pd.nn.BiRNN
weight_ih_attr = None
weight_hh_attr = None
if bias:
bias_ih_attr = None
bias_hh_attr = None
else:
bias_ih_attr = False
bias_hh_attr = False
kwargs = {
"weight_ih_attr": weight_ih_attr,
"weight_hh_attr": weight_hh_attr,
"bias_ih_attr": bias_ih_attr,
"bias_hh_attr": bias_hh_attr
}
if mode == "LSTM":
rnn_cls = pd.nn.LSTMCell
elif mode == "GRU":
rnn_cls = pd.nn.GRUCell
elif mode == 'RNN_TANH':
rnn_cls = pd.nn.SimpleRNNCell
kwargs["activation"] = 'tanh'
elif mode == 'RNN_RELU':
rnn_cls = pd.nn.SimpleRNNCell
kwargs["activation"] = 'relu'
if not bidirectional:
is_reverse = False
cell = rnn_cls(input_size, hidden_size, **kwargs)
self.rnn.append(RNN(cell, is_reverse, self.time_major))
for i in range(1, num_layers):
cell = rnn_cls(hidden_size, hidden_size, **kwargs)
self.rnn.append(RNN(cell, is_reverse, self.time_major))
else:
cell_fw = rnn_cls(input_size, hidden_size, **kwargs)
cell_bw = rnn_cls(input_size, hidden_size, **kwargs)
self.rnn.append(BiRNN(cell_fw, cell_bw, self.time_major))
for i in range(1, num_layers):
cell_fw = rnn_cls(2 * hidden_size, hidden_size, **kwargs)
cell_bw = rnn_cls(2 * hidden_size, hidden_size, **kwargs)
self.rnn.append(BiRNN(cell_fw, cell_bw, self.time_major))
self.could_use_cudnn = True
self.could_use_cudnn &= len(self.rnn.parameters()) == num_layers * 4 * self.bidirect
param_names = []
for layer in range(self.num_layers):
for direction in range(self.bidirect):
suffix = '_reverse' if direction == 1 else ''
param_names.extend(['weight_ih_l{}{}', 'weight_hh_l{}{}'])
if bias_ih_attr != False: param_names.append('bias_ih_l{}{}')
if bias_hh_attr != False: param_names.append('bias_hh_l{}{}')
param_names = [x.format(layer, suffix) for x in param_names]
for name, param in zip(param_names, self.rnn.parameters()):
setattr(self.rnn, name, param)
self.flatten_parameters()
def flatten_parameters(self):
"""
Resets parameter data pointer to address in continuous memory block for
cudnn usage.
"""
if self.could_use_cudnn:
# layer.parameters() is depth first and ordered
# for i in layer: for j in direct: w_ih, w_hh, b_ih, b_hh
# need to reorganize to cudnn param layout:
# all bias following all weights
params = self.rnn.parameters(include_sublayers=False)
shape = [np.prod(param.shape) for param in params]
self._all_weights = [None] * len(params)
for i, param in enumerate(params):
offset = 0 if i % 4 < 2 else (2 * self.num_layers * self.bidirect)
layer_idx = i // 4
self._all_weights[offset + layer_idx * 2 + i % 2] = param
# Wrap using a list to avoid registed into params and saving, maybe
# need a better way to handle this later. Use `create_parameter` to
# add both to main_program and startup_program for static-graph.
# Use Constant initializer to avoid make effect on random generator.
self._flat_weight = [
self.rnn.create_parameter(
shape=[np.sum(shape)], dtype=params[0].dtype, default_initializer=I.Constant(0.0)
)
]
# dropout state may also can be hided and avoid saving
# should dropout state be persistable for static-graph
self._dropout_state = self.rnn.create_variable(dtype=fluid.core.VarDesc.VarType.UINT8)
# for static-graph, append coalesce_tensor into startup program
with fluid.program_guard(fluid.default_startup_program(), fluid.default_startup_program()):
with pd.framework.no_grad():
self.rnn._helper.append_op(
type="coalesce_tensor", inputs={"Input": self._all_weights}, outputs={
"Output": self._all_weights,
"FusedOutput": self._flat_weight
}, attrs={
"copy_data": True,
"use_align": False,
"dtype": params[0].dtype
}
)
def _cudnn_impl(self, inputs, initial_states, sequence_length):
if not self.time_major:
inputs = pd.tensor.transpose(inputs, [1, 0, 2])
out = self.rnn._helper.create_variable_for_type_inference(inputs.dtype)
state = [
self.rnn._helper.create_variable_for_type_inference(inputs.dtype) for i in range(self.state_components)
]
reserve = self.rnn._helper.create_variable_for_type_inference(
dtype=fluid.core.VarDesc.VarType.UINT8, stop_gradient=True
)
inputs = {
'Input': inputs,
'WeightList': self._all_weights,
'PreState': initial_states,
'SequenceLength': sequence_length
}
attrs = {
'dropout_prob': self.dropout,
'is_bidirec': self.bidirect == 2,
'input_size': self.input_size,
'hidden_size': self.hidden_size,
'num_layers': self.num_layers,
'mode': self.mode,
'is_test': not self.rnn.training
}
outputs = {
'Out': out,
'State': state,
'Reserve': reserve,
'DropoutState': self._dropout_state,
}
self.rnn._helper.append_op(type="rnn", inputs=inputs, outputs=outputs, attrs=attrs)
out = pd.tensor.transpose(out, [1, 0, 2]) if not self.time_major else out
return out, tuple(state) if len(state) > 1 else state[0]
def forward(self, inputs, initial_states=None, sequence_length=None):
batch_index = 1 if self.time_major else 0
dtype = inputs.dtype
if initial_states is None:
state_shape = [self.num_layers * self.bidirect, -1, self.hidden_size]
if self.state_components == 1:
initial_states = fluid.layers.fill_constant_batch_size_like(
inputs, state_shape, dtype, 0, batch_index, 1
)
else:
initial_states = tuple(
[
fluid.layers.fill_constant_batch_size_like(inputs, state_shape, dtype, 0, batch_index, 1)
for _ in range(self.state_components)
]
)
if self.could_use_cudnn:
# Add CPU kernel and dispatch in backend later
return self._cudnn_impl(inputs, initial_states, sequence_length)
states = split_states(initial_states, self.bidirect == 2, self.state_components)
final_states = []
for i, rnn_layer in enumerate(self.rnn):
if i > 0:
inputs = F.dropout(inputs, self.dropout, training=self.rnn.training, mode="upscale_in_train")
outputs, final_state = rnn_layer(inputs, states[i], sequence_length)
final_states.append(final_state)
inputs = outputs
final_states = concat_states(final_states, self.bidirect == 2, self.state_components)
return outputs, final_states

View File

@ -531,14 +531,15 @@ def reduce_min(input_tensor, axis=None):
class Pad(object):
def __init__(self, paddings, mode="REFLECT"):
def __init__(self, paddings, mode="REFLECT", constant_values=0):
if mode not in ['CONSTANT', 'REFLECT', 'SYMMETRIC']:
raise Exception("Unsupported mode: {}".format(mode))
self.paddings = paddings
self.mode = mode
self.constant_values = constant_values
def __call__(self, x):
outputs = tf.pad(x, self.paddings, mode=self.mode, constant_values=0)
outputs = tf.pad(x, self.paddings, mode=self.mode, constant_values=self.constant_values)
return outputs
@ -884,7 +885,7 @@ class OneHot(object):
self.axis = axis
self.dtype = dtype
def __call__(self, inputs, *args, **kwargs):
def __call__(self, inputs):
outputs = tf.one_hot(
inputs, self.depth, on_value=self.on_value, off_value=self.off_value, axis=self.axis, dtype=self.dtype
)
@ -907,7 +908,7 @@ class EmbeddingLookup(object):
def __init__(self, max_norm=None):
self.max_norm = max_norm
def __call__(self, params, ids, *args, **kwargs):
def __call__(self, params, ids):
outputs = tf.nn.embedding_lookup(params=params, ids=ids, max_norm=self.max_norm)
return outputs

View File

@ -6,6 +6,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import moving_averages
from math import floor, ceil
import numpy as np
# loss function
sparse_softmax_cross_entropy_with_logits = tf.nn.sparse_softmax_cross_entropy_with_logits
sigmoid_cross_entropy_with_logits = tf.nn.sigmoid_cross_entropy_with_logits
@ -1913,3 +1914,342 @@ class DorefaConv2D(object):
)
return outputs
class rnncell(object):
def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, act):
self.weight_ih = weight_ih
self.weight_hh = weight_hh
self.bias_ih = bias_ih
self.bias_hh = bias_hh
self.act_fn = tf.nn.relu if act == 'relu' else tf.nn.tanh
def __call__(self, input, h, c=None):
i2h = tf.matmul(input, self.weight_ih, transpose_b=True)
if self.bias_ih is not None:
i2h += self.bias_ih
h2h = tf.matmul(h, self.weight_hh, transpose_b=True)
if self.bias_hh is not None:
h2h += self.bias_hh
h = self.act_fn(i2h + h2h)
return h, h
class lstmcell(object):
def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, act=None):
self.weight_ih = weight_ih
self.weight_hh = weight_hh
self.bias_ih = bias_ih
self.bias_hh = bias_hh
self.gate_act_fn = tf.sigmoid
self.act_fn = tf.tanh
def __call__(self, input, h, c):
gates = tf.matmul(input, self.weight_ih, transpose_b=True)
if self.bias_ih is not None:
gates = gates + self.bias_ih
gates += tf.matmul(h, self.weight_hh, transpose_b=True)
if self.bias_hh is not None:
gates += self.bias_hh
gate_slices = tf.split(gates, num_or_size_splits=4, axis=-1)
i = self.gate_act_fn(gate_slices[0])
f = self.gate_act_fn(gate_slices[1])
o = self.gate_act_fn(gate_slices[3])
c = f * c + i * self.act_fn(gate_slices[2])
h = o * self.act_fn(c)
return h, h, c
class grucell(object):
def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, act=None):
self.weight_ih = weight_ih
self.weight_hh = weight_hh
self.bias_ih = bias_ih
self.bias_hh = bias_hh
self.gate_act_fn = tf.sigmoid
self.act_fn = tf.tanh
def __call__(self, input, h, c=None):
x_gates = tf.matmul(input, self.weight_ih, transpose_b=True)
if self.bias_ih is not None:
x_gates = x_gates + self.bias_ih
h_gates = tf.matmul(h, self.weight_hh, transpose_b=True)
if self.bias_hh is not None:
h_gates = h_gates + self.bias_hh
x_r, x_z, x_c = tf.split(x_gates, num_or_size_splits=3, axis=-1)
h_r, h_z, h_c = tf.split(h_gates, num_or_size_splits=3, axis=-1)
r = self.gate_act_fn(x_r + h_r)
z = self.gate_act_fn(x_r + h_z)
c = self.act_fn(x_c + r * h_c)
h = (h - c) * z + c
return h, h
class rnnbase(object):
def __init__(
self,
mode,
input_size,
hidden_size,
num_layers,
bias,
batch_first,
dropout,
bidirectional,
is_train,
weights_fw,
weights_bw,
bias_fw,
bias_bw,
):
self.mode = mode
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bias = bias
self.batch_first = batch_first
self.dropout = float(dropout)
self.train = is_train
if not 0 <= dropout < 1:
raise ValueError("dropout should be a number in range [0, 1).")
if dropout > 0 and num_layers == 1:
raise ValueError(
"dropout option adds dropout after all but last "
"recurrent layer, so non-zero dropout expects "
"num_layers greater than 1, but got dropout={} and "
"num_layers={}".format(dropout, num_layers)
)
self.bidirect = 2 if bidirectional else 1
self.weights_fw = weights_fw
self.bias_fw = bias_fw
self.weights_bw = weights_bw
self.bias_bw = bias_bw
# stdv = 1.0 / np.sqrt(self.hidden_size)
# _init = tf.random_uniform_initializer(minval=-stdv, maxval=stdv)
self.act_fn = None
if mode == 'LSTM':
# gate_size = 4 * hidden_size
self.rnn_cell = lstmcell
elif mode == 'GRU':
# gate_size = 3 * hidden_size
self.rnn_cell = grucell
elif mode == 'RNN_TANH':
# gate_size = hidden_size
self.rnn_cell = rnncell
self.act_fn = 'tanh'
elif mode == 'RNN_RELU':
# gate_size = hidden_size
self.rnn_cell = rnncell
self.act_fn = 'relu'
# for layer in range(num_layers):
# for direction in range(self.bidirect):
# layer_input_size = input_size if layer==0 else hidden_size*self.bidirect
# if direction == 0:
# self.w_ih = tf.Variable(initial_value= _init(shape=(gate_size, layer_input_size)),name = 'weight_ih_l'+str(layer), trainable=True)
# self.w_hh = tf.Variable(initial_value=_init(shape=(gate_size, hidden_size)),
# name='weight_hh_l'+str(layer), trainable=True)
# # self.w_ih = self.weights_init('weight_ih_l'+str(layer), shape = (gate_size, layer_input_size), init = _init)
# # self.w_hh = self.weights_init('weight_ih_l' + str(layer), shape=(gate_size, hidden_size),
# # init=_init)
# self.weights_fw.append(self.w_ih)
# self.weights_fw.append(self.w_hh)
# if bias:
# self.b_ih = tf.Variable(initial_value=_init(shape=(gate_size,)),
# name='bias_ih_l'+str(layer), trainable=True)
# self.b_hh = tf.Variable(initial_value=_init(shape=(gate_size,)),
# name='bias_hh_l'+str(layer), trainable=True)
# # self.b_ih = self.weights_init('bias_ih_l'+str(layer), shape=(gate_size,), init=_init)
# # self.b_hh = self.weights_init('bias_hh_l'+str(layer), shape=(gate_size,), init=_init)
# self.bias_fw.append(self.b_ih)
# self.bias_fw.append(self.b_hh)
# else:
# self.w_ih = tf.Variable(initial_value= _init(shape=(gate_size, layer_input_size)),name = 'weight_ih_l'+str(layer)+'_reverse', trainable=True)
# self.w_hh = tf.Variable(initial_value=_init(shape=(gate_size, hidden_size)),
# name='weight_hh_l'+str(layer)+'_reverse', trainable=True)
# # self.w_ih = self.weights_init('weight_ih_l'+str(layer)+'_reverse', shape = (gate_size, layer_input_size), init = _init)
# # self.w_hh = self.weights_init('weight_hh_l'+str(layer)+'_reverse', shape=(gate_size, hidden_size),
# # init=_init)
# self.weights_bw.append(self.w_ih)
# self.weights_bw.append(self.w_hh)
# if bias:
# self.b_ih = tf.Variable(initial_value=_init(shape=(gate_size,)),
# name='bias_ih_l'+str(layer)+'_reverse', trainable=True)
# self.b_hh = tf.Variable(initial_value=_init(shape=(gate_size,)),
# name='bias_hh_l'+str(layer)+'_reverse', trainable=True)
# # self.b_ih = self.weights_init('bias_ih_l'+str(layer)+'_reverse', shape=(gate_size,), init=_init)
# # self.b_hh = self.weights_init('bias_hh_l'+str(layer)+'_reverse', shape=(gate_size,), init=_init)
# self.bias_bw.append(self.b_ih)
# self.bias_bw.append(self.b_hh)
def _bi_rnn_forward(self, x, h, c=None):
time_step, batch_size, input_size = x.shape
h_out = []
c_out = []
y = []
pre_layer = x
for i in range(self.num_layers):
weight_ih_fw = self.weights_fw[2 * i]
weight_hh_fw = self.weights_fw[2 * i + 1]
weight_ih_bw = self.weights_bw[2 * i]
weight_hh_bw = self.weights_bw[2 * i + 1]
if self.bias:
bias_ih_fw = self.bias_fw[2 * i]
bias_hh_fw = self.bias_fw[2 * i + 1]
bias_ih_bw = self.bias_bw[2 * i]
bias_hh_bw = self.bias_bw[2 * i + 1]
else:
bias_ih_fw = None
bias_hh_fw = None
bias_ih_bw = None
bias_hh_bw = None
h_i_fw = h[i, :, :]
h_i_bw = h[i + 1, :, :]
if i != 0 and self.train:
pre_layer = tf.nn.dropout(pre_layer, rate=self.dropout)
if c is not None:
c_i_fw = c[i, :, :]
c_i_bw = c[i + 1, :, :]
for j in range(time_step):
input = pre_layer[j, :, :]
cell_fw = self.rnn_cell(weight_ih_fw, weight_hh_fw, bias_ih_fw, bias_hh_fw, self.act_fn)
cell_bw = self.rnn_cell(weight_ih_bw, weight_hh_bw, bias_ih_bw, bias_hh_bw, self.act_fn)
bw_input = tf.reverse(input, axis=[0])
step_out_fw, h_i_fw, c_i_fw = cell_fw(input, h_i_fw, c_i_fw)
step_out_bw, h_i_bw, c_i_bw = cell_bw(bw_input, h_i_bw, c_i_bw)
step_out_bw = tf.reverse(step_out_bw, axis=[0])
step_out = tf.concat([step_out_fw, step_out_bw], axis=-1)
y.append(step_out)
h_out.append(h_i_fw)
h_out.append(h_i_bw)
c_out.append(c_i_fw)
c_out.append(c_i_bw)
pre_layer = tf.stack(y)
y = []
else:
for j in range(time_step):
input = pre_layer[j, :, :]
cell_fw = self.rnn_cell(weight_ih_fw, weight_hh_fw, bias_ih_fw, bias_hh_fw, self.act_fn)
cell_bw = self.rnn_cell(weight_ih_bw, weight_hh_bw, bias_ih_bw, bias_hh_bw, self.act_fn)
bw_input = tf.reverse(input, axis=[0])
step_out_fw, h_i_fw = cell_fw(input, h_i_fw)
step_out_bw, h_i_bw = cell_bw(bw_input, h_i_bw)
step_out_bw = tf.reverse(step_out_bw, axis=[0])
step_out = tf.concat([step_out_fw, step_out_bw], axis=-1)
y.append(step_out)
h_out.append(h_i_fw)
h_out.append(h_i_bw)
pre_layer = tf.stack(y)
y = []
h_out = tf.stack(h_out)
c_out = tf.stack(c_out) if c is not None else None
return pre_layer, h_out, c_out
def _rnn_forward(self, x, h, c=None):
pre_layer = x
h_out = []
c_out = []
y = []
time_step, batch_size, input_size = x.shape
for i in range(self.num_layers):
weight_ih = self.weights_fw[2 * i]
weight_hh = self.weights_fw[2 * i + 1]
if self.bias:
bias_ih = self.bias_fw[2 * i]
bias_hh = self.bias_fw[2 * i + 1]
else:
bias_ih = None
bias_hh = None
h_i = h[i, :, :]
if i != 0 and self.train:
pre_layer = tf.nn.dropout(pre_layer, rate=self.dropout)
if c is not None:
c_i = c[i, :, :]
for j in range(time_step):
input = pre_layer[j, :, :]
cell = self.rnn_cell(weight_ih, weight_hh, bias_ih, bias_hh, self.act_fn)
step_out, h_i, c_i = cell(input, h_i, c_i)
y.append(step_out)
h_out.append(h_i)
c_out.append(c_i)
pre_layer = tf.stack(y)
y = []
else:
for j in range(time_step):
input = pre_layer[j, :, :]
cell = self.rnn_cell(weight_hh, weight_ih, bias_ih, bias_hh, self.act_fn)
step_out, h_i = cell(input, h_i)
y.append(step_out)
h_out.append(h_i)
pre_layer = tf.stack(y)
y = []
h_out = tf.stack(h_out)
c_out = tf.stack(c_out) if c is not None else None
return pre_layer, h_out, c_out
def check_input(self, input_shape):
if len(input_shape) != 3:
raise ValueError("input must have 3 dimensions. But got {}.".format(len(input_shape)))
if self.input_size != input_shape[-1]:
raise ValueError(
"The last dimension of input should be equal to input_size {}.But got {}".format(
self.input_size, input_shape[-1]
)
)
def check_hidden(self, h, batch_size):
expected_hidden_size = (self.num_layers * self.bidirect, batch_size, self.hidden_size)
if h.shape != expected_hidden_size:
raise ValueError('Expected hidden size {}, got {}.'.format(expected_hidden_size, h.shape))
def __call__(self, input, states):
if self.batch_first:
input = tf.transpose(input, perm=(1, 0, 2))
input_dtype = input.dtype
input_shape = input.shape
time_step, batch_size, input_size = input_shape
self.check_input(input_shape)
if self.mode == "LSTM":
if states is not None:
h, c = states
self.check_hidden(h, batch_size)
self.check_hidden(c, batch_size)
else:
h = tf.zeros(shape=(self.num_layers * self.bidirect, batch_size, self.hidden_size), dtype=input_dtype)
c = tf.zeros(shape=(self.num_layers * self.bidirect, batch_size, self.hidden_size), dtype=input_dtype)
if self.bidirect == 1:
y, new_h, new_c = self._rnn_forward(input, h, c)
else:
y, new_h, new_c = self._bi_rnn_forward(input, h, c)
new_states = (new_h, new_c)
else:
if states is not None:
h = states
self.check_hidden(h, batch_size)
else:
h = tf.zeros(shape=(self.num_layers * self.bidirect, batch_size, self.hidden_size), dtype=input_dtype)
if self.bidirect == 1:
y, new_h, _ = self._rnn_forward(input, h)
else:
y, new_h, _ = self._bi_rnn_forward(input, h)
new_states = new_h
if self.batch_first:
y = tf.transpose(y, perm=(1, 0, 2))
return y, new_states

View File

@ -1963,6 +1963,8 @@ def save_npz(save_list=None, name='model.npz'):
save_list_var = tf_variables_to_numpy(save_list)
elif tl.BACKEND == 'mindspore':
save_list_var = ms_variables_to_numpy(save_list)
elif tl.BACKEND == 'paddle':
save_list_var = pd_variables_to_numpy(save_list)
else:
raise NotImplementedError("This backend is not supported")
# print(name, save_list_var)
@ -2050,6 +2052,11 @@ def assign_weights(weights, network):
# net = Assign_net(network.all_weights[idx])
# net(assign_param)
Assign()(network.all_weights[idx], assign_param)
elif tl.BACKEND == 'paddle':
for idx, param in enumerate(weights):
assign_pd_variable(network.all_weights[idx], param)
else:
raise NotImplementedError ("This backend is not supported")
return ops

View File

@ -41,11 +41,13 @@ class PadLayer(Module):
self,
padding=None,
mode='CONSTANT',
constant_values=0,
name=None, # 'pad_layer',
):
super().__init__(name)
self.padding = padding
self.mode = mode
self.constant_values = constant_values
logging.info("PadLayer %s: padding: %s mode: %s" % (self.name, self.padding, self.mode))
@ -65,7 +67,7 @@ class PadLayer(Module):
return s.format(classname=self.__class__.__name__, **self.__dict__)
def build(self, inputs_shape=None):
self.pad = tl.ops.Pad(paddings=self.padding, mode=self.mode)
self.pad = tl.ops.Pad(paddings=self.padding, mode=self.mode, constant_values=self.constant_values)
def forward(self, inputs):
outputs = self.pad(inputs)

File diff suppressed because it is too large Load Diff