#! /usr/bin/python
# -*- coding: utf-8 -*-

from __future__ import absolute_import, division, print_function

import itertools

import mindspore as ms
import mindspore.ops as P
from mindspore import context
from mindspore.nn.cell import Cell
from mindspore._checkparam import Rel
from mindspore.ops import functional as F
from mindspore.communication import management
from mindspore.ops.operations import _inner_ops as inner
from mindspore._extends import cell_attr_register
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore._checkparam import Validator as validator
from mindspore.communication.management import get_group_size, get_rank

# Shared name of the sync-BN communication group. BatchNorm below declares
# `global SYNC_BN_GROUP_NAME`, so it must exist at module level.
SYNC_BN_GROUP_NAME = ""
def padding_format(padding):
    """
    Checks that the padding format is one of the supported values.

    Parameters
    ----------
    padding : str
        Must be one of the following: "same", "SAME", "VALID", "valid"

    Returns
    -------
    str "same" or "valid", or None if padding is None
    """

    if padding in ["SAME", "same"]:
        padding = "same"
    elif padding in ["VALID", "valid"]:
        padding = "valid"
    elif padding is None:
        padding = None
    else:
        raise Exception("Unsupported padding: " + str(padding))
    return padding


def preprocess_1d_format(data_format, padding):
    """
    Checks that the 1-D data format is one of the supported values.

    Parameters
    ----------
    data_format : str
        Must be one of the following: "channels_last", "NWC", "NCW", "channels_first"
    padding : str
        Must be one of the following: "same", "valid", "SAME", "VALID"

    Returns
    -------
    str "NWC" or "NCW" and "same" or "valid"
    """

    if data_format in ["channels_last", "NWC"]:
        data_format = "NWC"
    elif data_format in ["channels_first", "NCW"]:
        data_format = "NCW"
    elif data_format is None:
        data_format = None
    else:
        raise Exception("Unsupported data format: " + str(data_format))
    padding = padding_format(padding)
    return data_format, padding
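

# Usage sketch for the format helpers (illustrative only; the argument values
# below are assumptions, not part of this module's API):
#
#   data_format, padding = preprocess_1d_format('NCW', 'VALID')
#   # -> ('NCW', 'valid')
#   data_format, padding = preprocess_2d_format('channels_last', 'same')
#   # -> ('NHWC', 'same')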


def preprocess_2d_format(data_format, padding):
    """
    Checks that the 2-D data format is one of the supported values.

    Parameters
    ----------
    data_format : str
        Must be one of the following: "channels_last", "NHWC", "NCHW", "channels_first"
    padding : str
        Must be one of the following: "same", "valid", "SAME", "VALID"

    Returns
    -------
    str "NHWC" or "NCHW" and "same" or "valid"
    """

    if data_format in ["channels_last", "NHWC", "nhwc"]:
        data_format = "NHWC"
    elif data_format in ["channels_first", "NCHW", "nchw"]:
        data_format = "NCHW"
    elif data_format is None:
        data_format = None
    else:
        raise Exception("Unsupported data format: " + str(data_format))
    padding = padding_format(padding)
    return data_format, padding


def preprocess_3d_format(data_format, padding):
    """
    Checks that the 3-D data format is one of the supported values.

    Parameters
    ----------
    data_format : str
        Must be one of the following: "channels_last", "NDHWC", "NCDHW", "channels_first"
    padding : str
        Must be one of the following: "same", "valid", "SAME", "VALID"

    Returns
    -------
    str "NDHWC" or "NCDHW" and "same" or "valid"
    """

    if data_format in ['channels_last', 'NDHWC']:
        data_format = 'NDHWC'
    elif data_format in ['channels_first', 'NCDHW']:
        data_format = 'NCDHW'
    elif data_format is None:
        data_format = None
    else:
        raise Exception("Unsupported data format: " + str(data_format))
    padding = padding_format(padding)
    return data_format, padding


def nchw_to_nhwc(x):
    """
    Channels first to channels last

    Parameters
    ----------
    x : tensor
        channels first tensor data

    Returns
    -------
    channels last tensor data
    """

    if len(P.Shape()(x)) == 3:
        x = P.Transpose()(x, (0, 2, 1))
    elif len(P.Shape()(x)) == 4:
        x = P.Transpose()(x, (0, 2, 3, 1))
    elif len(P.Shape()(x)) == 5:
        x = P.Transpose()(x, (0, 2, 3, 4, 1))
    # else:
    #     raise Exception("Unsupported dimensions")
    return x


def nhwc_to_nchw(x):
    """
    Channels last to channels first

    Parameters
    ----------
    x : tensor
        channels last tensor data

    Returns
    -------
    channels first tensor data
    """

    if len(P.Shape()(x)) == 3:
        x = P.Transpose()(x, (0, 2, 1))
    elif len(P.Shape()(x)) == 4:
        x = P.Transpose()(x, (0, 3, 1, 2))
    elif len(P.Shape()(x)) == 5:
        x = P.Transpose()(x, (0, 4, 1, 2, 3))
    # else:
    #     raise Exception("Unsupported dimensions")
    return x
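

# Round-trip sketch for the layout converters (illustrative; shapes are
# assumptions):
#
#   import numpy as np
#   x = ms.Tensor(np.zeros((2, 4, 8, 8), np.float32))   # NCHW
#   y = nchw_to_nhwc(x)                                 # shape (2, 8, 8, 4)
#   z = nhwc_to_nchw(y)                                 # back to (2, 4, 8, 8)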


class ReLU(Cell):

    def __init__(self):
        super(ReLU, self).__init__()
        self.relu = P.ReLU()

    def construct(self, x):
        return self.relu(x)


def relu(x):
    """
    Computes rectified linear: max(features, 0).

    Parameters
    ----------
    x : tensor
        Must be one of the following types: float32, float64, int32, uint8, int16,
        int8, int64, bfloat16, uint16, half, uint32, uint64, qint8.

    Returns
    -------
    A Tensor. Has the same type as features.
    """
    outputs = P.ReLU()
    return outputs(x)


class ReLU6(Cell):

    def __init__(self):
        super(ReLU6, self).__init__()
        self.relu6 = P.ReLU6()

    def construct(self, x):
        return self.relu6(x)


def relu6(x):
    """
    Computes Rectified Linear 6: min(max(features, 0), 6).

    Parameters
    ----------
    x : tensor
        Must be one of the following types: float32, float64, int32, uint8, int16,
        int8, int64, bfloat16, uint16, half, uint32, uint64, qint8.

    Returns
    -------
    A Tensor with the same type as features.
    """
    outputs = P.ReLU6()
    return outputs(x)


class LeakyReLU(Cell):

    def __init__(self, alpha=0.2):
        super(LeakyReLU, self).__init__()
        self.leakyrelu = ms.nn.LeakyReLU(alpha=alpha)

    def construct(self, x):
        return self.leakyrelu(x)


def leaky_relu(x, alpha=0.2):
    """
    Compute the Leaky ReLU activation function.

    Parameters
    ----------
    x : tensor
        representing preactivation values. Must be one of the following types:
        float16, float32, float64, int32, int64.

    Returns
    -------
    The activation value.
    """

    leaky_relu = LeakyReLU(alpha=alpha)
    output = leaky_relu(x)
    # The original returned the LeakyReLU cell itself instead of the result.
    return output


class Softplus(Cell):

    def __init__(self):
        super(Softplus, self).__init__()
        self.softplus = P.Softplus()

    def construct(self, x):
        return self.softplus(x)


def softplus(x):
    """
    Computes softplus: log(exp(features) + 1).

    Parameters
    ----------
    x : tensor
        Must be one of the following types: half, bfloat16, float32, float64.

    Returns
    -------
    A Tensor. Has the same type as features.
    """

    obj = Softplus()
    return obj(x)


class Tanh(Cell):

    def __init__(self):
        super(Tanh, self).__init__()
        self.tanh = P.Tanh()

    def construct(self, x):
        return self.tanh(x)


def tanh(x):
    """
    Computes hyperbolic tangent of x element-wise.

    Parameters
    ----------
    x : tensor
        Must be one of the following types: bfloat16, half, float32, float64, complex64, complex128.

    Returns
    -------
    A Tensor. Has the same type as x.
    """

    _tanh = Tanh()
    return _tanh(x)


class Sigmoid(Cell):

    def __init__(self):
        super(Sigmoid, self).__init__()
        self.sigmoid = P.Sigmoid()

    def construct(self, x):
        return self.sigmoid(x)


def sigmoid(x):
    """
    Computes sigmoid of x element-wise.

    Parameters
    ----------
    x : tensor
        A Tensor with type float16, float32, float64, complex64, or complex128.

    Returns
    -------
    A Tensor with the same type as x.
    """
    outputs = P.Sigmoid()
    return outputs(x)


class Softmax(Cell):

    def __init__(self):
        super(Softmax, self).__init__()
        self.softmax = P.Softmax()

    def construct(self, x):
        return self.softmax(x)


def softmax(logits, axis=-1):
    """
    Computes softmax activations.

    Parameters
    ----------
    logits : tensor
        Must be one of the following types: half, float32, float64.
    axis : int
        The dimension softmax would be performed on. The default is -1 which indicates the last dimension.

    Returns
    -------
    A Tensor. Has the same type and shape as logits.
    """
    outputs = P.Softmax(axis)
    return outputs(logits)
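

# Quick sketch of the functional activations (illustrative; the input values
# are assumptions):
#
#   import numpy as np
#   x = ms.Tensor(np.array([[-1.0, 0.0, 2.0]], np.float32))
#   relu(x)          # [[0., 0., 2.]]
#   relu6(x * 4)     # clips at 6: [[0., 0., 6.]]
#   softmax(x)       # each row sums to 1 along the last axis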


class Dropout(Cell):

    def __init__(self, keep, seed=0):
        super(Dropout, self).__init__()
        self.dropout = P.Dropout(keep_prob=keep)
        self.is_gpu = context.get_context('device_target') in ["GPU"]
        self.get_shape = P.Shape()
        self.dropout_gen_mask = P.DropoutGenMask(Seed0=seed, Seed1=0)
        self.dropout_do_mask = P.DropoutDoMask()
        self.cast = P.Cast()
        self.keep_prob = keep

    def construct(self, inputs):
        if self.is_gpu:
            outputs, _ = self.dropout(inputs)
            return outputs
        if self.keep_prob == 1:
            return inputs
        shape = self.get_shape(inputs)
        dtype = P.DType()(inputs)
        if self._is_float_dtype(dtype):
            keep_prob = self.cast(self.keep_prob, dtype=dtype)
        else:
            keep_prob = self.cast(self.keep_prob, ms.float16)
        output = self.dropout_gen_mask(shape, keep_prob)
        return self.dropout_do_mask(inputs, output, keep_prob)

    @staticmethod
    def _is_float_dtype(dtype):
        # The original definition was missing `self`, so the bound call
        # `self._is_float_dtype(dtype)` raised a TypeError; a staticmethod
        # matches the call site.
        return dtype in [ms.float32, ms.float16]
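

# Usage sketch (illustrative; shape and keep ratio are assumptions):
#
#   import numpy as np
#   drop = Dropout(keep=0.8, seed=1)
#   y = drop(ms.Tensor(np.ones((2, 4), np.float32)))
#   # surviving entries are scaled by 1/keep, the rest are zeroed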


class BiasAdd(Cell):
    """
    Adds bias to value.

    Parameters
    ----------
    x : tensor
        A Tensor with type float, double, int64, int32, uint8, int16, int8, complex64, or complex128.
    bias : tensor
        Must be the same type as value unless value is a quantized type,
        in which case a different quantized type may be used.

    Returns
    -------
    A Tensor with the same type as value.
    """

    def __init__(self, data_format='channels_first'):
        super(BiasAdd, self).__init__()
        self.bias_add = P.BiasAdd()
        if data_format in ['channels_first', 'NCW', 'NCHW', 'NCDHW']:
            self.data_format = 'channels_first'
        elif data_format in ['channels_last', 'NWC', 'NHWC', 'NDHWC']:
            self.data_format = 'channels_last'
        else:
            # The original raised a bare string, which is itself a TypeError.
            raise Exception("Unsupported data format: " + str(data_format))

    def construct(self, x, bias):
        if self.data_format == 'channels_last':
            x = nhwc_to_nchw(x)
        outputs = self.bias_add(x, bias)
        if self.data_format == 'channels_last':
            outputs = nchw_to_nhwc(outputs)
        return outputs


def bias_add(x, bias):
    """
    Adds bias to value.

    Parameters
    ----------
    x : tensor
        A Tensor with type float, double, int64, int32, uint8, int16, int8, complex64, or complex128.
    bias : tensor
        Must be the same type as value unless value is a quantized type,
        in which case a different quantized type may be used.

    Returns
    -------
    A Tensor with the same type as value.
    """
    raise NotImplementedError


class Conv1D(Cell):

    def __init__(self, stride, padding, data_format='NWC', dilations=None, out_channel=None, k_size=None):
        super(Conv1D, self).__init__()
        self.data_format, self.padding = preprocess_1d_format(data_format, padding)
        self.stride = (1, stride)
        self.dilations = (1, dilations)
        self.k_size = (1, k_size)
        self.out_channel = out_channel

        self.conv2d = P.Conv2D(
            out_channel=self.out_channel, kernel_size=self.k_size, pad_mode=self.padding, stride=self.stride,
            dilation=self.dilations, mode=1, group=1
        )

        self.expand_dims = P.ExpandDims()
        self.squeeze = P.Squeeze(2)

    def construct(self, x, filters):
        if self.data_format == 'NWC':
            x = nhwc_to_nchw(x)

        x = self.expand_dims(x, 2)
        filters = self.expand_dims(filters, 2)

        output = self.conv2d(x, filters)
        output = self.squeeze(output)

        if self.data_format == 'NWC':
            output = nchw_to_nhwc(output)
        return output
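

# Usage sketch: the 1-D convolution is run as a 2-D convolution over a dummy
# height axis (illustrative; shapes are assumptions):
#
#   import numpy as np
#   conv = Conv1D(stride=1, padding='SAME', data_format='NWC',
#                 dilations=1, out_channel=8, k_size=3)
#   x = ms.Tensor(np.zeros((2, 16, 4), np.float32))   # [batch, width, in_ch]
#   w = ms.Tensor(np.zeros((8, 4, 3), np.float32))    # [out_ch, in_ch, k]
#   y = conv(x, w)                                    # [2, 16, 8]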


def conv1d(input, filters, stride, padding, data_format='NWC', dilations=None, name=None):
    """
    Computes a 1-D convolution given 3-D input and filter tensors.

    Parameters
    ----------
    input : tensor
        A 3D Tensor. Must be of type float16, float32, or float64
    filters : tensor
        A 3D Tensor. Must have the same type as input.
    stride : int or list
        An int or list of ints that has length 1 or 3. The number of entries by which the filter is moved right at each step.
    padding : string
        'SAME' or 'VALID'
    data_format : string
        An optional string from "NWC", "NCW". Defaults to "NWC", the data is stored in the order of
        [batch, in_width, in_channels]. The "NCW" format stores data as [batch, in_channels, in_width].
    dilations : int or list
        An int or list of ints that has length 1 or 3 which defaults to 1.
        The dilation factor for each dimension of input. If set to k > 1,
        there will be k-1 skipped cells between each filter element on that dimension.
        Dilations in the batch and depth dimensions must be 1.
    name : string
        A name for the operation (optional).

    Returns
    -------
    A Tensor. Has the same type as input.
    """

    pass


class Conv2D(Cell):

    def __init__(self, strides, padding, data_format='NHWC', dilations=None, out_channel=None, k_size=None):
        super(Conv2D, self).__init__()
        self.data_format, self.padding = preprocess_2d_format(data_format, padding)

        if self.data_format == 'NHWC':
            self.ms_stride = strides[1]
            self.ms_dilation = dilations[1]
        elif self.data_format == 'NCHW':
            self.ms_stride = strides[2]
            self.ms_dilation = dilations[2]

        self.conv2d = P.Conv2D(
            out_channel=out_channel, kernel_size=k_size, pad_mode=self.padding, stride=self.ms_stride,
            dilation=self.ms_dilation, mode=1, group=1, data_format=self.data_format
        )

    def construct(self, inputs, filters):
        outputs = self.conv2d(inputs, filters)
        return outputs
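

# Usage sketch (illustrative; shapes are assumptions). P.Conv2D expects the
# filter laid out as [out_ch, in_ch, kh, kw]:
#
#   import numpy as np
#   conv = Conv2D(strides=(1, 1, 1, 1), padding='SAME', data_format='NCHW',
#                 dilations=(1, 1, 1, 1), out_channel=8, k_size=(3, 3))
#   x = ms.Tensor(np.zeros((2, 4, 16, 16), np.float32))
#   w = ms.Tensor(np.zeros((8, 4, 3, 3), np.float32))
#   y = conv(x, w)   # [2, 8, 16, 16]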


def conv2d(input, filters, strides, padding, data_format='NCHW', dilations=None):
    """
    Computes a 2-D convolution given 4-D input and filters tensors.

    Parameters
    ----------
    input : tensor
        Must be one of the following types: half, bfloat16, float32, float64. A 4-D tensor.
        The dimension order is interpreted according to the value of data_format, see below for details.
    filters : tensor
        Must have the same type as input. A 4-D tensor of shape [filter_height, filter_width, in_channels, out_channels]
    strides : int or list
        The stride of the sliding window for each dimension of input. If a single value is given it is replicated in the H and W dimension.
        By default the N and C dimensions are set to 1. The dimension order is determined by the value of data_format, see below for details.
    padding : string
        "SAME" or "VALID"
    data_format : string
        "NHWC", "NCHW". Defaults to "NCHW".
    dilations : list or ints
        list of ints that has length 1, 2 or 4, defaults to 1. The dilation factor for each dimension of input.

    Returns
    -------
    A Tensor. Has the same type as input.
    """
    raise NotImplementedError


class Conv3D(Cell):

    def __init__(self, strides, padding, data_format='NDHWC', dilations=None, out_channel=None, k_size=None):
        super(Conv3D, self).__init__()
        self.data_format, self.padding = preprocess_3d_format(data_format, padding)

        if self.data_format == 'NDHWC':
            raise NotImplementedError("Unsupported data format: currently only 'NCDHW' is supported.")
        elif self.data_format == 'NCDHW':
            self.ms_stride = strides[2]
            self.ms_dilation = dilations[2]

        self.conv3d = P.Conv3D(
            out_channel=out_channel, kernel_size=k_size, pad_mode=self.padding, stride=self.ms_stride,
            dilation=self.ms_dilation, data_format=self.data_format
        )

    def construct(self, input, filters):
        outputs = self.conv3d(input, filters)
        return outputs


def conv3d(input, filters, strides, padding, data_format='NDHWC', dilations=None, name=None):
    """
    Computes a 3-D convolution given 5-D input and filters tensors.

    Parameters
    ----------
    input : tensor
        Must be one of the following types: half, bfloat16, float32, float64.
        Shape [batch, in_depth, in_height, in_width, in_channels].
    filters : tensor
        Must have the same type as input. Shape [filter_depth, filter_height, filter_width, in_channels, out_channels].
        in_channels must match between input and filters.
    strides : list of ints
        A list of ints that has length >= 5. 1-D tensor of length 5.
        The stride of the sliding window for each dimension of input.
        Must have strides[0] = strides[4] = 1.
    padding : string
        A string from: "SAME", "VALID". The type of padding algorithm to use.
    data_format : string
        An optional string from: "NDHWC", "NCDHW". Defaults to "NDHWC". The data format of the input and output data.
        With the default format "NDHWC", the data is stored in the order of: [batch, in_depth, in_height, in_width, in_channels].
        Alternatively, the format could be "NCDHW", the data storage order is: [batch, in_channels, in_depth, in_height, in_width].
    dilations : list of ints
        Defaults to [1, 1, 1, 1, 1]. 1-D tensor of length 5. The dilation factor for each dimension of input.
        If set to k > 1, there will be k-1 skipped cells between each filter element on that dimension.
        The dimension order is determined by the value of data_format, see above for details.
        Dilations in the batch and depth dimensions must be 1.
    name : string
        A name for the operation (optional).

    Returns
    -------
    A Tensor. Has the same type as input.
    """

    raise NotImplementedError


def lrn(inputs, depth_radius, bias, alpha, beta):
    """
    Local Response Normalization.

    Parameters
    ----------
    inputs : tensor
        Must be one of the following types: half, bfloat16, float32. 4-D.
    depth_radius : int
        Defaults to 5. 0-D. Half-width of the 1-D normalization window.
    bias : float
        Defaults to 1. An offset (usually positive to avoid dividing by 0).
    alpha : float
        Defaults to 1. A scale factor, usually positive.
    beta : float
        Defaults to 0.5. An exponent.

    Returns
    -------
    A Tensor. Has the same type as input.
    """
    pass


def moments(x, axes, shift=None, keepdims=False):
    """
    Calculates the mean and variance of x.

    Parameters
    ----------
    x : tensor
        A Tensor
    axes : ints
        Axes along which to compute mean and variance.
    shift : int
        Not used in the current implementation.
    keepdims : bool
        produce moments with the same dimensionality as the input.

    Returns
    -------
    Two Tensor objects: mean and variance.
    """

    pass


class MaxPool1d(Cell):

    def __init__(self, ksize, strides, padding, data_format=None):
        super(MaxPool1d, self).__init__()
        self.data_format, padding = preprocess_1d_format(data_format=data_format, padding=padding)
        self.expand = P.ExpandDims()
        _strides = (1, strides[0])
        _ksize = (1, ksize[0])
        if self.data_format == 'NWC':
            self.squeeze = P.Squeeze(1)
            _data_format = 'NHWC'
        if self.data_format == 'NCW':
            self.squeeze = P.Squeeze(2)
            _data_format = 'NCHW'

        self.max_pool = P.MaxPool(kernel_size=_ksize, strides=_strides, pad_mode=padding, data_format=_data_format)

    def construct(self, inputs):
        if self.data_format == 'NWC':
            x = self.expand(inputs, 1)
        if self.data_format == 'NCW':
            x = self.expand(inputs, 2)
        output = self.max_pool(x)
        output = self.squeeze(output)
        return output


class MaxPool(Cell):

    def __init__(self, ksize, strides, padding, data_format=None):
        super(MaxPool, self).__init__()
        data_format, padding = preprocess_2d_format(data_format=data_format, padding=padding)

        if data_format == 'NHWC':
            _strides = (strides[1], strides[2])
        if data_format == 'NCHW':
            _strides = (strides[2], strides[3])

        self.maxpool = P.MaxPool(kernel_size=ksize, strides=_strides, pad_mode=padding, data_format=data_format)

    def construct(self, inputs):
        outputs = self.maxpool(inputs)
        return outputs
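

# Usage sketch (illustrative; shapes are assumptions). `ksize` is passed to
# P.MaxPool unchanged, so give it the 2-D window directly:
#
#   import numpy as np
#   pool = MaxPool(ksize=(2, 2), strides=(1, 1, 2, 2), padding='VALID',
#                  data_format='NCHW')
#   y = pool(ms.Tensor(np.zeros((2, 4, 16, 16), np.float32)))   # [2, 4, 8, 8]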


def max_pool(input, ksize, strides, padding, data_format=None):
    """
    Performs the max pooling on the input.

    Parameters
    ----------
    input : tensor
        Tensor of rank N+2, of shape [batch_size] + input_spatial_shape + [num_channels] if data_format does not start
        with "NC" (default), or [batch_size, num_channels] + input_spatial_shape if data_format starts with "NC".
        Pooling happens over the spatial dimensions only.
    ksize : int or list of ints
        An int or list of ints that has length 1, N or N+2.
        The size of the window for each dimension of the input tensor.
    strides : int or list of ints
        An int or list of ints that has length 1, N or N+2.
        The stride of the sliding window for each dimension of the input tensor.
    padding : string
        'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.ops.convolution for details.

    Returns
    -------
    A Tensor of format specified by data_format. The max pooled output tensor.
    """
    data_format, padding = preprocess_2d_format(data_format=data_format, padding=padding)
    if data_format == 'NHWC':
        _strides = (strides[1], strides[2])
    if data_format == 'NCHW':
        _strides = (strides[2], strides[3])
    outputs = P.MaxPool(kernel_size=ksize, strides=_strides, pad_mode=padding, data_format=data_format)(input)
    return outputs


class AvgPool1d(Cell):

    def __init__(self, ksize, strides, padding, data_format=None):
        super(AvgPool1d, self).__init__()
        self.data_format, self.padding = preprocess_1d_format(data_format=data_format, padding=padding)
        self.kernel_size = (1, ksize[0])
        self.stride = (1, strides[0])

        if self.data_format == 'NWC':
            _data_format = 'NHWC'
            self.squeeze = P.Squeeze(1)
        if self.data_format == 'NCW':
            _data_format = 'NCHW'
            self.squeeze = P.Squeeze(2)

        self.avg_pool = P.AvgPool(
            kernel_size=self.kernel_size, strides=self.stride, pad_mode=self.padding, data_format=_data_format
        )
        self.reduce_mean = P.ReduceMean(keep_dims=True)
        self.slice = P.Slice()
        self.expand = P.ExpandDims()
        self.shape = P.Shape()

    def construct(self, inputs):
        x = inputs
        batch, channel, width = self.shape(inputs)
        if width == self.kernel_size[1]:
            x = self.reduce_mean(x, 2)
        elif width - self.kernel_size[1] < self.stride[1]:
            x = self.slice(x, (0, 0, 0), (batch, channel, self.kernel_size[1]))
            x = self.reduce_mean(x, 2)
        else:
            if self.data_format == 'NCW':
                x = self.expand(x, 2)
            if self.data_format == 'NWC':
                x = self.expand(x, 1)
            x = self.avg_pool(x)
            x = self.squeeze(x)
        return x


class AvgPool(Cell):

    def __init__(self, ksize, strides, padding, data_format=None):
        super(AvgPool, self).__init__()
        self.data_format, self.padding = preprocess_2d_format(data_format=data_format, padding=padding)
        ms_ksize = ksize[1]
        ms_strides = strides[1]
        # P.AvgPool takes `kernel_size`/`pad_mode`; the original keywords
        # (`ksize`/`padding`) do not exist on this primitive.
        self.avgpool = P.AvgPool(
            kernel_size=ms_ksize, strides=ms_strides, pad_mode=self.padding, data_format=self.data_format
        )

    def construct(self, inputs):
        outputs = self.avgpool(inputs)
        return outputs


def avg_pool(input, ksize, strides, padding):
    """
    Performs the avg pooling on the input.

    Parameters
    ----------
    input : tensor
        Tensor of rank N+2, of shape [batch_size] + input_spatial_shape + [num_channels]
        if data_format does not start with "NC" (default), or [batch_size, num_channels] + input_spatial_shape
        if data_format starts with "NC". Pooling happens over the spatial dimensions only.
    ksize : int or list of ints
        An int or list of ints that has length 1, N or N+2.
        The size of the window for each dimension of the input tensor.
    strides : int or list of ints
        An int or list of ints that has length 1, N or N+2.
        The stride of the sliding window for each dimension of the input tensor.
    padding : string
        'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.ops.convolution for details.

    Returns
    -------
    A Tensor of format specified by data_format. The average pooled output tensor.
    """
    padding = padding_format(padding)
    ms_ksize = ksize[0]
    ms_strides = strides[1]
    # As above, P.AvgPool takes `kernel_size`/`pad_mode`.
    outputs = P.AvgPool(kernel_size=ms_ksize, strides=ms_strides, pad_mode=padding)
    return outputs(input)


class MaxPool3d(Cell):

    def __init__(self, ksize, strides, padding, data_format=None):
        super(MaxPool3d, self).__init__()
        self.data_format, self.padding = preprocess_3d_format(data_format, padding)
        if self.data_format == 'NDHWC':
            _strides = (strides[1], strides[2], strides[3])
        if self.data_format == 'NCDHW':
            _strides = (strides[2], strides[3], strides[4])
        # P.MaxPool3D takes `pad_mode`, not `padding`; use the normalized value.
        self.max_pool3d = P.MaxPool3D(
            kernel_size=ksize, strides=_strides, pad_mode=self.padding, data_format=self.data_format
        )

    def construct(self, inputs):
        outputs = self.max_pool3d(inputs)
        return outputs


def max_pool3d(input, ksize, strides, padding, data_format=None, name=None):
    """
    Performs the max pooling on the input.

    Parameters
    ----------
    input : tensor
        A 5-D Tensor of the format specified by data_format.
    ksize : int or list of ints
        An int or list of ints that has length 1, 3 or 5.
        The size of the window for each dimension of the input tensor.
    strides : int or list of ints
        An int or list of ints that has length 1, 3 or 5.
        The stride of the sliding window for each dimension of the input tensor.
    padding : string
        'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.ops.convolution for details.
    data_format : string
        "NDHWC", "NCDHW". Defaults to "NDHWC". The data format of the input and output data.
        With the default format "NDHWC", the data is stored in the order of: [batch, in_depth, in_height, in_width, in_channels].
        Alternatively, the format could be "NCDHW", the data storage order is: [batch, in_channels, in_depth, in_height, in_width].
    name : string
        A name for the operation (optional).

    Returns
    -------
    A Tensor of format specified by data_format. The max pooled output tensor.
    """
    pass


class AvgPool3d(Cell):

    def __init__(self, ksize, strides, padding, data_format=None):
        super(AvgPool3d, self).__init__()
        self.data_format, self.padding = preprocess_3d_format(data_format, padding)
        if self.data_format == 'NDHWC':
            _strides = (strides[1], strides[2], strides[3])
        if self.data_format == 'NCDHW':
            _strides = (strides[2], strides[3], strides[4])
        raise NotImplementedError

    def __call__(self, inputs):
        pass


def avg_pool3d(input, ksize, strides, padding, data_format=None, name=None):
    """
    Performs the average pooling on the input.

    Parameters
    ----------
    input : tensor
        A 5-D Tensor of shape [batch, height, width, channels] and type float32, float64, qint8, quint8, or qint32.
    ksize : int or list of ints
        An int or list of ints that has length 1, 3 or 5. The size of the window for each dimension of the input tensor.
    strides : int or list of ints
        An int or list of ints that has length 1, 3 or 5.
        The stride of the sliding window for each dimension of the input tensor.
    padding : string
        'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.ops.convolution for details.
    data_format : string
        'NDHWC' and 'NCDHW' are supported.
    name : string
        Optional name for the operation.

    Returns
    -------
    A Tensor with the same type as value. The average pooled output tensor.
    """
    pass


def pool(input, window_shape, pooling_type, strides=None, padding='VALID', data_format=None, dilations=None, name=None):
    """
    Performs an N-D pooling operation.

    Parameters
    ----------
    input : tensor
        Tensor of rank N+2, of shape [batch_size] + input_spatial_shape + [num_channels]
        if data_format does not start with "NC" (default), or [batch_size, num_channels] + input_spatial_shape
        if data_format starts with "NC". Pooling happens over the spatial dimensions only.
    window_shape : int
        Sequence of N ints >= 1.
    pooling_type : string
        Specifies pooling operation, must be "AVG" or "MAX".
    strides : ints
        Sequence of N ints >= 1. Defaults to [1]*N. If any value of strides is > 1, then all values of dilation_rate must be 1.
    padding : string
        The padding algorithm, must be "SAME" or "VALID". Defaults to "SAME".
        See the "returns" section of tf.ops.convolution for details.
    data_format : string
        Specifies whether the channel dimension of the input and output is the last dimension (default, or if data_format does not start with "NC"),
        or the second dimension (if data_format starts with "NC").
        For N=1, the valid values are "NWC" (default) and "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW".
        For N=3, the valid values are "NDHWC" (default) and "NCDHW".
    dilations : list of ints
        Dilation rate. List of N ints >= 1. Defaults to [1]*N. If any value of dilation_rate is > 1, then all values of strides must be 1.
    name : string
        Optional. Name of the op.

    Returns
    -------
    Tensor of rank N+2, of shape [batch_size] + output_spatial_shape + [num_channels]
    """
    pass


class DepthwiseConv2d(Cell):

    def __init__(self, strides, padding, data_format=None, dilations=None, ksize=None, channel_multiplier=1):
        super(DepthwiseConv2d, self).__init__()
        self.data_format, self.padding = preprocess_2d_format(data_format, padding)
        self.ms_stride = strides[1]
        self.ms_dilation = dilations[1]
        # Pass the normalized padding through; the original dropped it, so the
        # primitive silently fell back to its default pad mode.
        self.depthwise_conv2d = P.DepthwiseConv2dNative(
            channel_multiplier=channel_multiplier, kernel_size=ksize, pad_mode=self.padding, stride=self.ms_stride,
            dilation=self.ms_dilation
        )

    def construct(self, input, filter):
        if self.data_format == 'NHWC':
            input = nhwc_to_nchw(input)
        outputs = self.depthwise_conv2d(input, filter)
        if self.data_format == 'NHWC':
            outputs = nchw_to_nhwc(outputs)
        return outputs


def depthwise_conv2d(input, filter, strides, padding, data_format=None, dilations=None, name=None):
    """
    Depthwise 2-D convolution.

    Parameters
    ----------
    input : tensor
        4-D with shape according to data_format.
    filter : tensor
        4-D with shape [filter_height, filter_width, in_channels, channel_multiplier].
    strides : list
        1-D of size 4. The stride of the sliding window for each dimension of input.
    padding : string
        'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.ops.convolution for details.
    data_format : string
        The data format for input. Either "NHWC" (default) or "NCHW".
    dilations : list
        1-D of size 2. The dilation rate in which we sample input values across the height and width dimensions in atrous convolution.
        If it is greater than 1, then all values of strides must be 1.
    name : string
        A name for this operation (optional).

    Returns
    -------
    A 4-D Tensor with shape according to data_format.
    E.g., for "NHWC" format, shape is [batch, out_height, out_width, in_channels * channel_multiplier].
    """

    pass


class Conv1d_transpose(Cell):

    def __init__(self, stride, padding, data_format, dilations=None, out_channel=None, k_size=None, in_channels=None):
        super(Conv1d_transpose, self).__init__()
        self.data_format, self.padding = preprocess_1d_format(data_format, padding)
        self.in_channels = in_channels
        self.out_channel = out_channel
        self.stride = (1, stride)
        self.dilations = (1, dilations)
        self.k_size = (1, k_size)
        if self.data_format == 'NWC':
            self.data_format = 'NHWC'
            self.h_axis = 1
        else:
            self.data_format = 'NCHW'
            self.h_axis = 2
        self.conv2d_transpose = P.Conv2DBackpropInput(
            out_channel=self.in_channels, kernel_size=self.k_size, pad_mode=self.padding, stride=self.stride,
            dilation=self.dilations, mode=1, group=1, data_format=self.data_format
        )
        self.shape = P.Shape()
        self.expand_dims = P.ExpandDims()
        self.squeeze = P.Squeeze(self.h_axis)

    def _deconv_output_length(self, input_length, filter_size, stride_size, dilation_size):
        length = 0
        filter_size = filter_size + (filter_size - 1) * (dilation_size - 1)

        if self.padding == 'same':
            length = input_length * stride_size
        elif self.padding == 'valid':
            length = input_length * stride_size + max(filter_size - stride_size, 0)

        return length

    def construct(self, x, filters):
        x = self.expand_dims(x, self.h_axis)
        filters = self.expand_dims(filters, self.h_axis)
        if self.data_format == 'NCHW':
            n, _, h, w = self.shape(x)
        else:
            n, h, w, _ = self.shape(x)
        h_out = self._deconv_output_length(h, self.k_size[0], self.stride[0], self.dilations[0])
        w_out = self._deconv_output_length(w, self.k_size[1], self.stride[1], self.dilations[1])
        if self.data_format == 'NCHW':
            output_size = (n, self.out_channel, h_out, w_out)
        else:
            output_size = (n, h_out, w_out, self.out_channel)
        output = self.conv2d_transpose(x, filters, output_size)
        output = self.squeeze(output)

        return output


def conv1d_transpose(
    input, filters, output_shape, strides, padding='SAME', data_format='NWC', dilations=None, name=None
):
    """
    The transpose of conv1d.

    Parameters
    ----------
    input : tensor
        A 3-D Tensor of type float and shape [batch, in_width, in_channels]
        for NWC data format or [batch, in_channels, in_width] for NCW data format.
    filters : tensor
        A 3-D Tensor with the same type as value and shape [filter_width, output_channels, in_channels].
        filter's in_channels dimension must match that of value.
    output_shape : tensor
        A 1-D Tensor, containing three elements, representing the output shape of the deconvolution op.
    strides : list
        An int or list of ints that has length 1 or 3. The number of entries by which the filter is moved right at each step.
    padding : string
        'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.ops.convolution for details.
    data_format : string
        'NWC' and 'NCW' are supported.
    dilations : list
        An int or list of ints that has length 1 or 3 which defaults to 1.
        The dilation factor for each dimension of input. If set to k > 1,
        there will be k-1 skipped cells between each filter element on that dimension.
        Dilations in the batch and depth dimensions must be 1.
    name : string
        Optional name for the returned tensor.

    Returns
    -------
    A Tensor with the same type as value.
    """
    pass


class Conv2d_transpose(Cell):

    def __init__(self, strides, padding, data_format, dilations=None, out_channel=None, k_size=None, in_channels=None):
        super(Conv2d_transpose, self).__init__()
        self.data_format, self.padding = preprocess_2d_format(data_format, padding)
        self.in_channels = in_channels
        self.out_channel = out_channel
        self.k_size = k_size
        self.strides = strides
        self.dilations = dilations

        self.conv2d_transpose = P.Conv2DBackpropInput(
            out_channel=self.in_channels, kernel_size=self.k_size, pad_mode=self.padding, stride=self.strides,
            dilation=self.dilations, mode=1, group=1, data_format=self.data_format
        )
        self.shape = P.Shape()

    def _deconv_output_length(self, input_length, filter_size, stride_size, dilation_size):
        length = 0
        filter_size = filter_size + (filter_size - 1) * (dilation_size - 1)

        if self.padding == 'same':
            length = input_length * stride_size
        elif self.padding == 'valid':
            length = input_length * stride_size + max(filter_size - stride_size, 0)

        return length
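
    # Worked example (comment only): with padding 'valid', input_length=5,
    # filter_size=3, stride=2, dilation=1 the effective filter stays 3 and
    # length = 5 * 2 + max(3 - 2, 0) = 11; with 'same' it is simply 5 * 2 = 10.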

    def construct(self, x, filters):
        if self.data_format == 'NHWC':
            h_axis, w_axis = 1, 2
            n, h, w, _ = self.shape(x)
        else:
            h_axis, w_axis = 2, 3
            n, _, h, w = self.shape(x)

        if isinstance(self.strides, int):
            strides_h = self.strides
            strides_w = self.strides
        else:
            strides_list = list(self.strides)
            if len(strides_list) == 2:
                strides_h = strides_list[0]
                strides_w = strides_list[1]
            elif len(strides_list) == 4:
                strides_h = strides_list[h_axis]
                strides_w = strides_list[w_axis]

        # Default to a dilation of 1 when none is given; the original left
        # dilations_h/dilations_w unbound in that case.
        dilations_h = 1
        dilations_w = 1
        if self.dilations is not None:
            if isinstance(self.dilations, int):
                dilations_h = self.dilations
                dilations_w = self.dilations
            else:
                dilations_list = list(self.dilations)
                if len(dilations_list) == 2:
                    dilations_h = dilations_list[0]
                    dilations_w = dilations_list[1]
                elif len(dilations_list) == 4:
                    dilations_h = dilations_list[h_axis]
                    dilations_w = dilations_list[w_axis]

        h_out = self._deconv_output_length(h, self.k_size[0], strides_h, dilations_h)
        w_out = self._deconv_output_length(w, self.k_size[1], strides_w, dilations_w)

        if self.data_format == 'NCHW':
            output_size = (n, self.out_channel, h_out, w_out)
        else:
            output_size = (n, h_out, w_out, self.out_channel)
        output = self.conv2d_transpose(x, filters, output_size)

        return output


def conv2d_transpose(
    input, filters, output_shape, strides, padding='SAME', data_format='NHWC', dilations=None, name=None
):
    """
    The transpose of conv2d.

    Parameters
    ----------
    input : tensor
        A 4-D Tensor of type float and shape [batch, height, width, in_channels]
        for NHWC data format or [batch, in_channels, height, width] for NCHW data format.
    filters : tensor
        A 4-D Tensor with the same type as input and shape [height, width,
        output_channels, in_channels]. filter's in_channels dimension must match that of input.
    output_shape : tensor
        A 1-D Tensor representing the output shape of the deconvolution op.
    strides : list
        An int or list of ints that has length 1, 2 or 4. The stride of the sliding window for each dimension of input.
        If a single value is given it is replicated in the H and W dimension.
        By default the N and C dimensions are set to 1.
        The dimension order is determined by the value of data_format, see below for details.
    padding : string
        'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.ops.convolution for details.
    data_format : string
        'NHWC' and 'NCHW' are supported.
    dilations : list
        An int or list of ints that has length 1, 2 or 4, defaults to 1.
    name : string
        Optional name for the returned tensor.

    Returns
    -------
    A Tensor with the same type as input.
    """
    pass


class Conv3d_transpose(Cell):

    def __init__(
        self, strides, padding, data_format='NDHWC', dilations=None, name=None, out_channel=None, k_size=None,
        in_channels=None
    ):
        super(Conv3d_transpose, self).__init__()
        self.data_format, self.padding = preprocess_3d_format(data_format, padding)

        self.conv3d_transpose = P.Conv3DTranspose(
            in_channel=in_channels, out_channel=out_channel, kernel_size=k_size, mode=1, pad_mode=self.padding,
            stride=strides, dilation=dilations, data_format=self.data_format
        )

    def construct(self, input, filters):
        output = self.conv3d_transpose(input, filters)
        return output


def conv3d_transpose(
    input, filters, output_shape, strides, padding='SAME', data_format='NDHWC', dilations=None, name=None
):
    """
    The transpose of conv3d.

    Parameters
    ----------
    input : tensor
        A 5-D Tensor of type float and shape [batch, depth, height, width, in_channels] for
        NDHWC data format or [batch, in_channels, depth, height, width] for NCDHW data format.
    filters : tensor
        A 5-D Tensor with the same type as value and shape [depth, height, width, output_channels, in_channels].
        filter's in_channels dimension must match that of value.
    output_shape : tensor
        A 1-D Tensor representing the output shape of the deconvolution op.
    strides : list
        An int or list of ints that has length 1, 3 or 5.
    padding : string
        'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.ops.convolution for details.
    data_format : string
        'NDHWC' and 'NCDHW' are supported.
    dilations : list of ints
        An int or list of ints that has length 1, 3 or 5, defaults to 1.
    name : string
        Optional name for the returned tensor.

    Returns
    -------
    A Tensor with the same type as value.
    """

    pass


class BatchNorm(Cell):
    """Batch Normalization base class."""

    @cell_attr_register
    def __init__(
        self, num_features, epsilon=1e-5, decay=0.9, gamma=None, beta=None, moving_mean=None, moving_var=None,
        is_train=None, device_num_each_group=1, process_groups=0, data_format='NCHW'
    ):
        super(BatchNorm, self).__init__()
        if data_format in ["channels_last", "NHWC", "nhwc"]:
            data_format = "NHWC"
        elif data_format in ["channels_first", "NCHW", "nchw"]:
            data_format = "NCHW"
        validator.check_value_type('num_features', num_features, [int], self.cls_name)
        if num_features < 1:
            raise ValueError("num_features must be at least 1")

        if decay < 0 or decay > 1:
            raise ValueError("momentum should be a number in range [0, 1], but got {}".format(decay))
        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
        if context.get_context("device_target") != "GPU" and self.format == "NHWC":
            raise ValueError("NHWC format is only supported on the GPU target.")
        self.use_batch_statistics = is_train
        self.num_features = num_features
        self.eps = epsilon
        self.moving_mean = moving_mean
        self.moving_variance = moving_var
        self.gamma = gamma
        self.beta = beta
        self.group_device_num = validator.check_positive_int(device_num_each_group)
        self.process_groups = process_groups
        self.is_global = False
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        global SYNC_BN_GROUP_NAME
        # for GlobalBatchNorm
        if self.group_device_num != 1:
            self.rank_id = get_rank()
            self.rank_size = get_group_size()
            self.device_list = [i for i in range(0, self.rank_size)]
            self.rank_list = self.list_group(self.device_list, self.group_device_num)
            self.rank_list_idx = len(self.rank_list)
            for i in range(self.rank_list_idx):
                if self.rank_id in self.rank_list[i]:
                    self.is_global = True
                    if SYNC_BN_GROUP_NAME == "":
                        SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i)
                        management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i])
        # for SyncBatchNorm
        if self.process_groups != 0:
            self.rank_id = get_rank()
            self.rank_size = get_group_size()
            if self.process_groups is not None:
                validator.check_isinstance("process_groups", self.process_groups, list)
                self._check_rank_ids(self.process_groups, self.rank_size)
                for i in range(len(self.process_groups)):
                    validator.check_isinstance("process_groups[" + str(i) + "]", self.process_groups[i], list)
                    self.group_device_num = len(self.process_groups[i])
                    if self.rank_id in self.process_groups[i] and self.group_device_num > 1:
                        self.is_global = True
                        if SYNC_BN_GROUP_NAME == "":
                            SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i)
                            management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i])
            elif self.rank_size > 1:
                self.is_global = True
                self.group_device_num = self.rank_size
                self.device_list = [i for i in range(0, self.rank_size)]
                if SYNC_BN_GROUP_NAME == "":
                    SYNC_BN_GROUP_NAME = "sync_bn_group0"
                    management.create_group(SYNC_BN_GROUP_NAME, self.device_list)

        self.shape = P.Shape()
        self.reduce_mean = P.ReduceMean(keep_dims=True)
        self.square = P.Square()
        self.sqrt = P.Sqrt()
        self.cast = P.Cast()
        self.dtype = P.DType()
        self.reshape = P.Reshape()
        self._target = context.get_context("device_target")
        self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
        self.momentum = 1.0 - decay
        if context.get_context("enable_ge"):
            self.is_ge_backend = True
        else:
            self.is_ge_backend = False

        self.bn_train = P.BatchNorm(is_training=True, epsilon=self.eps, momentum=self.momentum, data_format=self.format)
        if self.is_global:
            self.bn_train = inner.SyncBatchNorm(
                epsilon=self.eps, momentum=self.momentum, group=SYNC_BN_GROUP_NAME, device_num=self.group_device_num
            )

        self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)

        data_parallel_strategy = ((1,), (1,))
        data_parallel_strategy_one = ((1,), ())
        self.sub_mean = P.Sub().shard(data_parallel_strategy)
        self.sub_var = P.Sub().shard(data_parallel_strategy)
        self.mul_mean = P.Mul().shard(data_parallel_strategy_one)
        self.mul_var = P.Mul().shard(data_parallel_strategy_one)
        self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
        self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)

    def list_group(self, world_rank, group_size):
        if group_size > get_group_size():
            raise ValueError(
                "group size can not be greater than local rank size, group size is {}, "
                "local_rank_size is {}".format(group_size, get_group_size())
            )
        if len(world_rank) % group_size != 0:
            raise ValueError("please make your group size correct.")
        world_rank_list = zip(*(iter(world_rank),) * group_size)
        group_list = [list(i) for i in world_rank_list]
        return group_list

    def _check_rank_ids(self, process_groups, rank_size):
        seen = set()
        for rid in itertools.chain(*process_groups):
            validator.check_int_range(rid, 0, rank_size, Rel.INC_LEFT, "rank id in process_groups")
            if rid in seen:
                raise ValueError("rank id in process_groups should not be duplicated.")
            seen.add(rid)

    def construct(self, inputs):
        x_shape = F.shape(inputs)
        if len(x_shape) == 5:
            inputs = self.reshape(inputs, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4]))

        flag = self.use_batch_statistics

        if flag:
            output = self.bn_train(inputs, self.gamma, self.beta, self.moving_mean, self.moving_variance)[0]
            if len(x_shape) == 5:
                output = self.reshape(output, x_shape)
            return output

        output = self.bn_infer(inputs, self.gamma, self.beta, self.moving_mean, self.moving_variance)[0]
        if len(x_shape) == 5:
            output = self.reshape(output, x_shape)
        return output

    def extend_repr(self):
        return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format(
            self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance
        )
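

# Usage sketch (illustrative; the Parameter construction is an assumption —
# the surrounding framework normally supplies gamma/beta/moving statistics):
#
#   import numpy as np
#   from mindspore import Parameter
#   bn = BatchNorm(
#       num_features=4, is_train=True,
#       gamma=Parameter(ms.Tensor(np.ones(4, np.float32))),
#       beta=Parameter(ms.Tensor(np.zeros(4, np.float32))),
#       moving_mean=Parameter(ms.Tensor(np.zeros(4, np.float32))),
#       moving_var=Parameter(ms.Tensor(np.ones(4, np.float32))),
#   )
#   y = bn(ms.Tensor(np.zeros((2, 4, 8, 8), np.float32)))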


class GroupConv2D(Cell):

    def __init__(self, strides, padding, data_format, dilations, out_channel, k_size, groups):
        super(GroupConv2D, self).__init__()
        self.data_format, self.padding = preprocess_2d_format(data_format, padding)

        if self.data_format == 'NHWC':
            self.ms_stride = strides[1]
            self.ms_dilation = dilations[1]
        elif self.data_format == 'NCHW':
            self.ms_stride = strides[2]
            self.ms_dilation = dilations[2]

        self.conv2d = P.Conv2D(
            out_channel=out_channel, kernel_size=k_size, pad_mode=self.padding, stride=self.ms_stride,
            dilation=self.ms_dilation, mode=1, group=groups, data_format=self.data_format
        )

    def construct(self, inputs, filters):
        outputs = self.conv2d(inputs, filters)
        return outputs


class SeparableConv1D(Cell):

    def __init__(self, stride, padding, data_format, dilations, out_channel, k_size, in_channel, depth_multiplier):
        super(SeparableConv1D, self).__init__()
        self.data_format, self.padding = preprocess_1d_format(data_format, padding)
        self.stride = (1, stride)
        self.dilations = (1, dilations)
        self.k_size = (1, k_size)
        self.out_channel = out_channel
        self.in_channel = in_channel
        self.depth_multiplier = depth_multiplier
        self.depthwise_conv = P.Conv2D(
            out_channel=self.in_channel * self.depth_multiplier, kernel_size=self.k_size, pad_mode=self.padding,
            stride=self.stride, dilation=self.dilations, mode=1, group=self.in_channel
        )

        self.pointwise_conv = P.Conv2D(
            out_channel=self.out_channel, kernel_size=(1, 1), pad_mode=self.padding, stride=(1, 1), dilation=(1, 1),
            mode=1, group=1
        )

        self.expand_dims = P.ExpandDims()
        self.squeeze = P.Squeeze(2)

    def construct(self, x, depthwise_filters, pointwise_filters):
        if self.data_format == 'NWC':
            x = nhwc_to_nchw(x)

        x = self.expand_dims(x, 2)
        depthwise_filters = self.expand_dims(depthwise_filters, 2)
        pointwise_filters = self.expand_dims(pointwise_filters, 2)

        outputs = self.depthwise_conv(x, depthwise_filters)
        outputs = self.pointwise_conv(outputs, pointwise_filters)

        outputs = self.squeeze(outputs)

        if self.data_format == 'NWC':
            outputs = nchw_to_nhwc(outputs)
        return outputs


class SeparableConv2D(Cell):

    def __init__(self, strides, padding, data_format, dilations, out_channel, k_size, in_channel, depth_multiplier):
        super(SeparableConv2D, self).__init__()
        self.data_format, self.padding = preprocess_2d_format(data_format, padding)
        self.k_size = k_size
        self.out_channel = out_channel
        self.in_channel = in_channel
        self.depth_multiplier = depth_multiplier

        if self.data_format == 'NHWC':
            self.ms_stride = strides[1]
            self.ms_dilation = dilations[1]
        elif self.data_format == 'NCHW':
            self.ms_stride = strides[2]
            self.ms_dilation = dilations[2]

        self.depthwise_conv = P.Conv2D(
            out_channel=self.in_channel * self.depth_multiplier, kernel_size=self.k_size, pad_mode=self.padding,
            stride=self.ms_stride, dilation=self.ms_dilation, mode=1, group=self.in_channel,
            data_format=self.data_format
        )

        self.pointwise_conv = P.Conv2D(
            out_channel=self.out_channel, kernel_size=(1, 1), pad_mode=self.padding, stride=(1, 1), dilation=(1, 1),
            mode=1, group=1, data_format=self.data_format
        )

    def construct(self, x, depthwise_filters, pointwise_filters):
        outputs = self.depthwise_conv(x, depthwise_filters)
        outputs = self.pointwise_conv(outputs, pointwise_filters)
        return outputs


class AdaptiveMeanPool1D(Cell):

    def __init__(self, output_size, data_format):
        super(AdaptiveMeanPool1D, self).__init__()
        self.data_format, _ = preprocess_1d_format(data_format, None)
        self.output_size = output_size
        if self.data_format == 'NWC':
            self.data_format = 'NHWC'
            self.h_axis = 1
        else:
            self.data_format = 'NCHW'
            self.h_axis = 2
        self.expand_dims = P.ExpandDims()
        self.squeeze = P.Squeeze(self.h_axis)
        self.shape = P.Shape()

    def construct(self, inputs):
        if self.data_format == 'NHWC':
            n, w, c = self.shape(inputs)
        else:
            n, c, w = self.shape(inputs)
        inputs = self.expand_dims(inputs, self.h_axis)
        stride = (1, w // self.output_size)
        kernel = (1, w - (self.output_size - 1) * stride[1])
        outputs = P.AvgPool(kernel_size=kernel, strides=stride, pad_mode='VALID', data_format=self.data_format)(inputs)
        outputs = self.squeeze(outputs)

        return outputs
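

# How the adaptive pools pick their window (comment sketch): for input width w
# and target output_size o, stride = w // o and kernel = w - (o - 1) * stride.
# E.g. w = 10, o = 4 gives stride 2 and kernel 4, producing exactly 4 bins.
# The 2-D variants below apply the same derivation per spatial axis.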


class AdaptiveMeanPool2D(Cell):

    def __init__(self, output_size, data_format):
        super(AdaptiveMeanPool2D, self).__init__()
        self.data_format, _ = preprocess_2d_format(data_format, None)
        self.output_size = output_size
        if self.data_format == 'NHWC':
            self.h_axis = 1
        else:
            self.h_axis = 2
        self.shape = P.Shape()

    def construct(self, inputs):
        if self.data_format == 'NHWC':
            n, h, w, c = self.shape(inputs)
        else:
            n, c, h, w = self.shape(inputs)

        out_h, out_w = self.output_size
        stride_h = h // out_h
        kernel_h = h - (out_h - 1) * stride_h
        stride_w = w // out_w
        kernel_w = w - (out_w - 1) * stride_w
        outputs = P.AvgPool(
            kernel_size=(kernel_h, kernel_w), strides=(stride_h, stride_w), pad_mode='VALID',
            data_format=self.data_format
        )(inputs)

        return outputs


class AdaptiveMeanPool3D(Cell):

    def __init__(self, output_size, data_format):
        pass

    def __call__(self, inputs):
        raise NotImplementedError


class AdaptiveMaxPool1D(Cell):

    def __init__(self, output_size, data_format):
        super(AdaptiveMaxPool1D, self).__init__()
        self.data_format, _ = preprocess_1d_format(data_format, None)
        self.output_size = output_size
        if self.data_format == 'NWC':
            self.data_format = 'NHWC'
            self.h_axis = 1
        else:
            self.data_format = 'NCHW'
            self.h_axis = 2
        self.expand_dims = P.ExpandDims()
        self.squeeze = P.Squeeze(self.h_axis)
        self.shape = P.Shape()

    def construct(self, inputs):
        if self.data_format == 'NHWC':
            n, w, c = self.shape(inputs)
        else:
            n, c, w = self.shape(inputs)
        inputs = self.expand_dims(inputs, self.h_axis)
        stride = (1, w // self.output_size)
        kernel = (1, w - (self.output_size - 1) * stride[1])
        outputs = P.MaxPool(kernel_size=kernel, strides=stride, pad_mode='VALID', data_format=self.data_format)(inputs)
        outputs = self.squeeze(outputs)

        return outputs


class AdaptiveMaxPool2D(Cell):

    def __init__(self, output_size, data_format):
        super(AdaptiveMaxPool2D, self).__init__()
        self.data_format, _ = preprocess_2d_format(data_format, None)
        self.output_size = output_size
        if self.data_format == 'NHWC':
            self.h_axis = 1
        else:
            self.h_axis = 2
        self.shape = P.Shape()

    def construct(self, inputs):
        if self.data_format == 'NHWC':
            n, h, w, c = self.shape(inputs)
        else:
            n, c, h, w = self.shape(inputs)
        out_h, out_w = self.output_size
        stride_h = h // out_h
        kernel_h = h - (out_h - 1) * stride_h
        stride_w = w // out_w
        kernel_w = w - (out_w - 1) * stride_w
        outputs = P.MaxPool(
            kernel_size=(kernel_h, kernel_w), strides=(stride_h, stride_w), pad_mode='VALID',
            data_format=self.data_format
        )(inputs)

        return outputs


class AdaptiveMaxPool3D(Cell):

    def __init__(self, output_size, data_format):
        pass

    def __call__(self, inputs):
        raise NotImplementedError


class BinaryConv2D(Cell):

    def __init__(self, strides, padding, data_format, dilations, out_channel, k_size, in_channel):
        super(BinaryConv2D, self).__init__()
        self.data_format, self.padding = preprocess_2d_format(data_format, padding)
        if self.data_format == 'NHWC':
            self.ms_stride = strides[1]
            self.ms_dilation = dilations[1]
        elif self.data_format == 'NCHW':
            self.ms_stride = strides[2]
            self.ms_dilation = dilations[2]

        self.conv2d = P.Conv2D(
            out_channel=out_channel, kernel_size=k_size, pad_mode=self.padding, stride=self.ms_stride,
            dilation=self.ms_dilation, mode=1, group=1, data_format=self.data_format
        )

        # Registered at construction time: a straight-through estimator that
        # clips the incoming gradient of Sign to [-1, 1].
        @bprop_getters.register(P.Sign)
        def get_bprop_Sign(self):

            def bprop(x, out, dout):
                grad = P.clip_by_value(dout, -1, 1)
                return (grad,)

            return bprop

        self.sign = P.Sign()

    def construct(self, inputs, filters):
        filters = self.sign(filters)
        outputs = self.conv2d(inputs, filters)
        return outputs


class DorefaConv2D(Cell):

    def __init__(self, bitW, bitA, strides, padding, data_format, dilations, out_channel, k_size, in_channel):
        super(DorefaConv2D, self).__init__()
        self.data_format, self.padding = preprocess_2d_format(data_format, padding)
        self.bitW = ms.Tensor(bitW)
        self.bitA = ms.Tensor(bitA)
        if self.data_format == 'NHWC':
            self.ms_stride = strides[1]
            self.ms_dilation = dilations[1]
        elif self.data_format == 'NCHW':
            self.ms_stride = strides[2]
            self.ms_dilation = dilations[2]

        self.conv2d = P.Conv2D(
            out_channel=out_channel, kernel_size=k_size, pad_mode=self.padding, stride=self.ms_stride,
            dilation=self.ms_dilation, mode=1, group=1
        )

        # Straight-through estimators: Round and Sign pass gradients through.
        @bprop_getters.register(P.Round)
        def get_bprop_Round(self):

            def bprop(x, out, dout):
                return (dout,)

            return bprop

        @bprop_getters.register(P.Sign)
        def get_bprop_Sign(self):

            def bprop(x, out, dout):
                return (dout,)

            return bprop

        self.minimum = P.Minimum()
        self.abs = P.Abs()
        self.round = P.Round()
        self.reducemean = P.ReduceMean()
        self.sign = P.Sign()
        self.pow = P.Pow()
        self.sub = P.Sub()
        self.oneslike = P.OnesLike()

    def cabs(self, inputs):
        a = P.stop_gradient(self.oneslike(inputs))
        return self.minimum(self.abs(inputs), a)

    def _quantize_dorefa(self, x, k):
        n = self.sub(self.pow(2.0, k), 1)
        return self.round(x * n) / n

    def quantize_active(self, x, bitA):
        if bitA == 32:
            return x
        return self._quantize_dorefa(x, bitA)

    def quantize_weight(self, x, bitW, force_quantization=False):
        if bitW == 32 and not force_quantization:
            return x

        if bitW == 1:
            E = P.stop_gradient(self.reducemean(self.abs(x)))
            return self.sign(x / E) * E

        x = P.clip_by_value(x * 0.5 + 0.5, 0.0, 1.0)

        return 2 * self._quantize_dorefa(x, bitW) - 1

    def construct(self, inputs, filters):
        if self.data_format == 'NHWC':
            inputs = nhwc_to_nchw(inputs)

        inputs = self.quantize_active(self.cabs(inputs), self.bitA)

        filters = self.quantize_weight(filters, self.bitW)

        outputs = self.conv2d(inputs, filters)

        if self.data_format == 'NHWC':
            outputs = nchw_to_nhwc(outputs)

        return outputs
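

# Worked quantization example (comment only): _quantize_dorefa with k = 2 bits
# gives n = 2**2 - 1 = 3, so x = 0.4 maps to round(0.4 * 3) / 3 = 1/3 ≈ 0.333.
# quantize_weight first squashes weights into [0, 1] via x * 0.5 + 0.5, then
# rescales the quantized value back to [-1, 1] with 2 * q - 1.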


class rnncell(Cell):

    def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, act):
        super(rnncell, self).__init__()
        self.weight_ih = weight_ih
        self.weight_hh = weight_hh
        self.bias_ih = bias_ih
        self.bias_hh = bias_hh
        self.act_fn = P.ReLU() if act == 'relu' else P.Tanh()
        self.transpose = P.Transpose()

    def construct(self, input, h):
        # Transpose into locals; the original wrote the transposed weights back
        # to self, flipping them on every call.
        weight_ih = self.transpose(self.weight_ih, (1, 0))
        i2h = P.matmul(input, weight_ih)
        if self.bias_ih is not None:
            i2h += self.bias_ih
        weight_hh = self.transpose(self.weight_hh, (1, 0))
        h2h = P.matmul(h, weight_hh)
        if self.bias_hh is not None:
            h2h += self.bias_hh
        h = self.act_fn(i2h + h2h)
        return h, h


class lstmcell(Cell):

    def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh):
        super(lstmcell, self).__init__()
        self.weight_ih = weight_ih
        self.weight_hh = weight_hh
        self.bias_ih = bias_ih
        self.bias_hh = bias_hh
        self.gate_act_fn = P.Sigmoid()
        self.act_fn = P.Tanh()
        self.transpose = P.Transpose()
        self.split = P.Split(axis=-1, output_num=4)

    def construct(self, input, h, c):
        # As in rnncell, keep the transposed weights in locals rather than
        # mutating the stored parameters.
        weight_ih = self.transpose(self.weight_ih, (1, 0))
        gates = P.matmul(input, weight_ih)
        if self.bias_ih is not None:
            gates += self.bias_ih
        weight_hh = self.transpose(self.weight_hh, (1, 0))
        gates += P.matmul(h, weight_hh)
        if self.bias_hh is not None:
            gates += self.bias_hh

        # Gate layout follows the i, f, g, o convention.
        gate_slices = self.split(gates)
        i = self.gate_act_fn(gate_slices[0])
        f = self.gate_act_fn(gate_slices[1])
        o = self.gate_act_fn(gate_slices[3])
        c = f * c + i * self.act_fn(gate_slices[2])
        h = o * self.act_fn(c)

        return h, h, c


class grucell(Cell):

    def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh):
        super(grucell, self).__init__()
        self.weight_ih = weight_ih
        self.weight_hh = weight_hh
        self.bias_ih = bias_ih
        self.bias_hh = bias_hh
        self.gate_act_fn = P.Sigmoid()
        self.act_fn = P.Tanh()
        self.transpose = P.Transpose()
        self.split = P.Split(axis=-1, output_num=3)

    def construct(self, input, h):
        weight_ih = self.transpose(self.weight_ih, (1, 0))
        x_gates = P.matmul(input, weight_ih)
        if self.bias_ih is not None:
            x_gates += self.bias_ih
        weight_hh = self.transpose(self.weight_hh, (1, 0))
        h_gates = P.matmul(h, weight_hh)
        if self.bias_hh is not None:
            h_gates += self.bias_hh

        x_r, x_z, x_c = self.split(x_gates)
        h_r, h_z, h_c = self.split(h_gates)

        r = self.gate_act_fn(x_r + h_r)
        # The update gate must use x_z; the original summed x_r + h_z.
        z = self.gate_act_fn(x_z + h_z)
        c = self.act_fn(x_c + r * h_c)
        h = (h - c) * z + c

        return h, h
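

# The cell above implements the standard GRU update (comment sketch):
#   r_t = sigmoid(W_ir x_t + b_ir + W_hr h_{t-1} + b_hr)
#   z_t = sigmoid(W_iz x_t + b_iz + W_hz h_{t-1} + b_hz)
#   n_t = tanh(W_in x_t + b_in + r_t * (W_hn h_{t-1} + b_hn))
#   h_t = (1 - z_t) * n_t + z_t * h_{t-1}
# which matches `h = (h - c) * z + c` with c = n_t and z = z_t.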


class rnnbase(Cell):

    def __init__(
        self,
        mode,
        input_size,
        hidden_size,
        num_layers,
        bias,
        batch_first,
        dropout,
        bidirectional,
        is_train,
    ):
        super(rnnbase, self).__init__()
        self.mode = mode
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirect = 2 if bidirectional else 1
        self.batch_first = batch_first
        if mode == 'LSTM':
            self.lstm = ms.nn.LSTM(
                input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, has_bias=bias,
                batch_first=batch_first, dropout=dropout, bidirectional=bidirectional
            )
        elif mode == 'GRU':
            raise NotImplementedError
        elif mode == 'RNN_TANH':
            raise NotImplementedError
        elif mode == 'RNN_RELU':
            raise NotImplementedError

        self.zeros = P.Zeros()

    def construct(self, input, states):
        input_shape = input.shape
        input_dtype = input.dtype
        if self.mode == 'LSTM':
            if self.batch_first:
                batch_size = input_shape[0]
            else:
                batch_size = input_shape[1]
            if states is None:
                h = self.zeros((self.bidirect * self.num_layers, batch_size, self.hidden_size), input_dtype)
                c = self.zeros((self.bidirect * self.num_layers, batch_size, self.hidden_size), input_dtype)
                states = (h, c)
            output, (h, c) = self.lstm(input, states)
            return output, (h, c)