# Source: EMSim_plus/EM Prediction/generator.py (Python, 108 lines, 4.7 KiB)

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, MaxPooling2D, UpSampling2D, Conv2DTranspose, Concatenate
from tensorflow.keras import Model
class encoder(Model):
    """Three-stage convolutional encoder that also returns its skip features.

    Every stage is a ReLU ``Conv2D`` followed by a ``MaxPooling2D`` with pool
    size 2 and ``'same'`` padding. The convolutions use 64, 32 and 16 filters
    with kernel sizes 3, 3 and 5 respectively.
    """

    def __init__(self):
        super(encoder, self).__init__()
        # L2 kernel/bias regularisation was sketched here in earlier
        # revisions but is not enabled.
        self.conv1 = Conv2D(filters=64, kernel_size=3,
                            activation='relu', padding='SAME')
        self.max1 = MaxPooling2D(pool_size=2, padding='same')
        self.conv2 = Conv2D(filters=32, kernel_size=3,
                            activation='relu', padding='SAME')
        self.max2 = MaxPooling2D(pool_size=2, padding='same')
        self.conv3 = Conv2D(filters=16, kernel_size=5,
                            activation='relu', padding='SAME')
        self.max3 = MaxPooling2D(pool_size=2, padding='same')

    @tf.autograph.experimental.do_not_convert
    def call(self, x, **kwargs):
        """Encode ``x`` and return every intermediate feature map.

        Args:
            x: Input tensor, fed directly to the first ``Conv2D``.
            **kwargs: Accepted for call-signature compatibility; unused.

        Returns:
            A 4-tuple: the outputs of ``conv1``, ``conv2`` and ``conv3``
            (each taken *before* its pooling layer), followed by the output
            of the final pooling layer.
        """
        skip0 = self.conv1(x)
        skip1 = self.conv2(self.max1(skip0))
        skip2 = self.conv3(self.max2(skip1))
        bottleneck = self.max3(skip2)
        return skip0, skip1, skip2, bottleneck
class decoder(Model):
    """Transposed-convolution decoder that merges three skip tensors.

    The decoder runs ``Conv2DTranspose`` -> ``UpSampling2D(2)`` three times.
    After each upsampling the result is cropped to the shape of the matching
    skip tensor and concatenated with it on the last axis. A final
    single-filter ``Conv2DTranspose`` with ReLU produces the output.
    """

    def __init__(self):
        super(decoder, self).__init__()
        # NOTE: the ``max*`` attributes hold UpSampling2D layers, not pooling
        # layers; the names are left unchanged so existing references to
        # them keep working.
        self.conv0 = Conv2DTranspose(16, 7, activation='relu', padding='SAME')
        self.max1 = UpSampling2D(2)
        self.conv1 = Conv2DTranspose(32, 7, activation='relu', padding='SAME')
        self.max2 = UpSampling2D(2)
        self.conv2 = Conv2DTranspose(64, 3, activation='relu', padding='SAME')
        self.max3 = UpSampling2D(2)
        self.conv3 = Conv2DTranspose(1, 3, activation='relu', padding='SAME')

    @staticmethod
    def _merge_skip(upsampled, skip):
        """Crop ``upsampled`` to ``skip``'s shape, then concatenate the two.

        The crop starts at the origin and takes ``tf.shape(skip)`` elements
        along *every* axis, so ``upsampled`` must be at least as large as
        ``skip`` in each dimension (batch and channels included).
        Presumably this trims the overshoot left by 2x upsampling when an
        odd-sized map was pooled with 'same' padding -- confirm with caller.
        """
        target_shape = tf.shape(skip)
        # zeros_like keeps the slice origin rank-agnostic instead of
        # hard-coding a 4-element literal.
        cropped = tf.slice(upsampled, tf.zeros_like(target_shape), target_shape)
        # tf.concat on the last axis matches Concatenate()'s default without
        # instantiating a new Keras layer on every forward pass.
        return tf.concat([cropped, skip], axis=-1)

    @tf.autograph.experimental.do_not_convert
    def call(self, vals, **kwargs):
        """Decode a latent tensor using three skip tensors.

        Args:
            vals: Indexable of four tensors. ``vals[3]`` is the tensor that
                is decoded; ``vals[2]``, ``vals[1]`` and ``vals[0]`` are the
                skip tensors merged after the first, second and third
                upsampling step respectively.
            **kwargs: Accepted for call-signature compatibility; unused.

        Returns:
            The output of the final single-filter ``Conv2DTranspose``.
        """
        skip0, skip1, skip2, latent = vals[0], vals[1], vals[2], vals[3]
        x1 = self._merge_skip(self.max1(self.conv0(latent)), skip2)
        x2 = self._merge_skip(self.max2(self.conv1(x1)), skip1)
        x3 = self._merge_skip(self.max3(self.conv2(x2)), skip0)
        return self.conv3(x3)
class ls_layer(Model):
    """Dense block that fuses a flattened feature map with a second input.

    The first input is flattened and passed through two dense layers
    (256 -> 128 units). The second input goes through three 64-unit dense
    layers. Both branches are concatenated on the last axis and projected
    through 256 and then 576 units. All dense layers use ReLU.
    """

    def __init__(self):
        super(ls_layer, self).__init__()
        self.fl = tf.keras.layers.Flatten()
        # Branch for the flattened feature map.
        self.fc1 = tf.keras.layers.Dense(256, activation='relu', use_bias=True)
        self.fc2 = tf.keras.layers.Dense(128, activation='relu', use_bias=True)
        # Layers applied after the two branches are joined.
        self.fc3 = tf.keras.layers.Dense(256, activation='relu', use_bias=True)
        # 576 units; Generator reshapes this output to [-1, 6, 6, 16].
        self.fc4 = tf.keras.layers.Dense(576, activation='relu', use_bias=True)
        # Branch for the second input.
        self.t_fc1 = tf.keras.layers.Dense(64, activation='relu', use_bias=True)
        self.t_fc2 = tf.keras.layers.Dense(64, activation='relu', use_bias=True)
        self.t_fc3 = tf.keras.layers.Dense(64, activation='relu', use_bias=True)

    @tf.autograph.experimental.do_not_convert
    def call(self, vals, **kwargs):
        """Fuse the two inputs into one 576-wide vector per sample.

        Args:
            vals: Indexable of two tensors. ``vals[0]`` is flattened before
                the dense layers; ``vals[1]`` is fed to the dense layers
                as-is (assumed to be rank 2 so it can be concatenated with
                the flattened branch -- TODO confirm against callers).
            **kwargs: Accepted for call-signature compatibility; unused.

        Returns:
            The output of ``fc4`` (last dimension 576).
        """
        x = vals[0]
        t = vals[1]

        x = self.fl(x)
        x = self.fc1(x)
        x = self.fc2(x)

        t = self.t_fc1(t)
        t = self.t_fc2(t)
        t = self.t_fc3(t)

        # tf.concat on the last axis matches Concatenate()'s default without
        # instantiating a new Keras layer on every forward pass.
        fused = tf.concat([x, t], axis=-1)
        fused = self.fc3(fused)
        return self.fc4(fused)
class Generator(Model):
    """Encoder -> dense fusion -> decoder pipeline.

    The encoder's last output and a second input are fused by ``ls_layer``.
    The fused vector is reshaped to ``[-1, 6, 6, 16]`` and decoded together
    with the encoder's three skip outputs.
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.ae = encoder()
        self.de = decoder()
        self.ls = ls_layer()

    @tf.autograph.experimental.do_not_convert
    def call(self, vals, **kwargs):
        """Run the full generator.

        Args:
            vals: Indexable of two tensors. ``vals[0]`` is passed to the
                encoder; ``vals[1]`` is passed to ``ls_layer`` alongside the
                encoder's last output.
            **kwargs: Accepted for call-signature compatibility; unused.

        Returns:
            The decoder output.
        """
        spatial_in = vals[0]
        aux_in = vals[1]

        skip0, skip1, skip2, bottleneck = self.ae(spatial_in)

        latent = self.ls((bottleneck, aux_in))
        # 6 * 6 * 16 == 576, the width of ls_layer's final dense layer.
        latent = tf.reshape(latent, [-1, 6, 6, 16])

        # The reshaped latent replaces the encoder bottleneck as the fourth
        # element handed to the decoder.
        return self.de((skip0, skip1, skip2, latent))