ADD file via upload
This commit is contained in:
parent
729c8bd89e
commit
822825d5c7
|
@ -0,0 +1,30 @@
|
|||
import tensorflow as tf
|
||||
import tensorlayer as tl
|
||||
from tensorlayer.layers import Module
|
||||
|
||||
class RoiPoolingConv(Module):
|
||||
def __init__(self,pool_size,**kwargs):
|
||||
self.pool_size = pool_size
|
||||
super(RoiPoolingConv,self).__init__(**kwargs)
|
||||
def build(self,input_shape):
|
||||
self.nb_channels = input_shape[0][3]
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
input_shape2 = input_shape[1]
|
||||
return None, input_shape2[1], self.pool_size, self.pool_size, self.nb_channels
|
||||
|
||||
def call(self, x, mask=None):
|
||||
assert (len(x) == 2)
|
||||
img = x[0]
|
||||
rois = x[1]
|
||||
num_rois = tf.shape(rois)[1]
|
||||
batch_size = tf.shape(rois)[0]
|
||||
|
||||
box_index = tf.expand_dims(tf.range(0, batch_size), 1)
|
||||
box_index = tf.tile(box_index, (1, num_rois))
|
||||
box_index = tf.reshape(box_index, [-1])
|
||||
|
||||
rs = tf.image.crop_and_resize(img, tf.reshape(rois, [-1, 4]), box_index, (self.pool_size, self.pool_size))
|
||||
|
||||
final_output = tf.reshape(rs, (batch_size, num_rois, self.pool_size, self.pool_size, self.nb_channels))
|
||||
return final_output
|
Loading…
Reference in New Issue