# JGAN/models/esrgan/models.py
#
# 117 lines
# 4.6 KiB
# Python

import jittor as jt
from jittor import init
from jittor import nn
from jittor.models import vgg19
import math
def weights_init_normal(m):
    """Re-initialise the parameters of a convolution or batch-norm module.

    Dispatch is by class name. A module whose class name contains ``"Conv"``
    has its weight and bias drawn from a Gaussian (mean 0.0, std 0.02). A
    module whose class name contains ``"BatchNorm"`` has its weight drawn from
    a Gaussian (mean 1.0, std 0.02) and its bias set to the constant 0.0.
    Any other module is left untouched.

    A parameter that is absent or ``None`` (for example the bias of a
    convolution created with ``bias=False``) is skipped rather than handed to
    the initialiser, so the function is safe to apply to every item of
    ``model.modules()``.

    Args:
        m: The module to initialise.
    """
    classname = m.__class__.__name__
    # getattr with a default: container classes whose name merely contains
    # "Conv"/"BatchNorm" may not own weight/bias attributes at all.
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv" in classname:
        if weight is not None:
            init.gauss_(weight, 0.0, 0.02)
        if bias is not None:
            init.gauss_(bias, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            init.gauss_(weight, 1.0, 0.02)
        if bias is not None:
            init.constant_(bias, 0.0)
class FeatureExtractor(nn.Module):
    """Truncated VGG-19 that maps an image batch to intermediate feature maps.

    Only the first 35 layers of the VGG-19 ``features`` stack are kept
    (the attribute name ``vgg19_54`` suggests the conv5_4 output — TODO
    confirm against the Jittor VGG layer layout).

    Args:
        pretrained: Forwarded to ``vgg19``. When false (the default, which
            reproduces the previous behaviour) the network is built without
            pretrained weights and every sub-module is re-initialised with
            ``weights_init_normal``. When true the random re-initialisation
            is skipped so the loaded weights are not overwritten.
    """

    def __init__(self, pretrained=False):
        super(FeatureExtractor, self).__init__()
        vgg19_model = vgg19(pretrained=pretrained)
        self.vgg19_54 = nn.Sequential(*list(vgg19_model.features.children())[:35])
        # Random init would destroy pretrained weights, so apply it only when
        # the backbone was created without them.
        if not pretrained:
            for m in self.modules():
                weights_init_normal(m)

    def execute(self, img):
        """Return the output of the truncated VGG-19 stack for ``img``."""
        return self.vgg19_54(img)
class DenseResidualBlock(nn.Module):
    """Five densely connected 3x3 convolutions wrapped in a scaled residual.

    Stage ``k`` (1-based) receives ``k * filters`` channels — the block input
    concatenated with the outputs of all earlier stages — and emits
    ``filters`` channels. Stages 1-4 are followed by a LeakyReLU; stage 5 is
    a bare convolution.

    Args:
        filters: Channel count of the block input and of every stage output.
        res_scale: Factor applied to the last stage's output before it is
            added back to the block input.
    """

    def __init__(self, filters, res_scale=0.2):
        super(DenseResidualBlock, self).__init__()
        self.res_scale = res_scale

        def block(in_features, non_linearity=True):
            # One stage: 3x3 conv, optionally followed by LeakyReLU.
            conv = nn.Conv(in_features, filters, 3, stride=1, padding=1, bias=True)
            if not non_linearity:
                return nn.Sequential(conv)
            return nn.Sequential(conv, nn.LeakyReLU())

        # Stages are created in order 1..5; only the last one is linear.
        stages = [
            block(in_features=width * filters, non_linearity=(width < 5))
            for width in range(1, 6)
        ]
        self.b1, self.b2, self.b3, self.b4, self.b5 = stages
        self.blocks = stages

    def execute(self, x):
        """Run the dense stages and return ``last_output * res_scale + x``."""
        stacked = x
        for stage in self.blocks:
            out = stage(stacked)
            # Grow the channel dimension so the next stage sees everything
            # produced so far.
            stacked = jt.contrib.concat([stacked, out], dim=1)
        return out * self.res_scale + x
class ResidualInResidualDenseBlock(nn.Module):
    """Three ``DenseResidualBlock`` modules in sequence with an outer residual.

    Args:
        filters: Channel count passed to each inner dense block.
        res_scale: Factor applied to the output of the inner chain before it
            is added to the block input. It is not forwarded to the inner
            dense blocks, which are built with their own default.
    """

    def __init__(self, filters, res_scale=0.2):
        super(ResidualInResidualDenseBlock, self).__init__()
        self.res_scale = res_scale
        inner = [DenseResidualBlock(filters) for _ in range(3)]
        self.dense_blocks = nn.Sequential(*inner)

    def execute(self, x):
        """Return ``dense_blocks(x) * res_scale + x``."""
        refined = self.dense_blocks(x)
        return refined * self.res_scale + x
class GeneratorRRDB(nn.Module):
    """Generator built from a chain of residual-in-residual dense blocks.

    Pipeline: input conv -> ``num_res_blocks`` RRDBs -> conv, added to the
    input-conv output -> ``num_upsample`` upsampling stages -> output convs.

    Args:
        channels: Channel count of both the input and the generated image.
        filters: Channel count used inside the network.
        num_res_blocks: Number of ``ResidualInResidualDenseBlock`` modules.
        num_upsample: Number of upsampling stages; each stage uses
            ``PixelShuffle`` with ``upscale_factor=2``.
    """

    def __init__(self, channels, filters=64, num_res_blocks=16, num_upsample=2):
        super(GeneratorRRDB, self).__init__()

        # Input convolution.
        self.conv1 = nn.Conv(channels, filters, 3, stride=1, padding=1)

        # Trunk of RRDBs followed by a convolution.
        trunk = [ResidualInResidualDenseBlock(filters) for _ in range(num_res_blocks)]
        self.res_blocks = nn.Sequential(*trunk)
        self.conv2 = nn.Conv(filters, filters, 3, stride=1, padding=1)

        # Each stage: conv to 4x channels, LeakyReLU, pixel shuffle by 2.
        upsample_layers = []
        for _ in range(num_upsample):
            upsample_layers.append(nn.Conv(filters, filters * 4, 3, stride=1, padding=1))
            upsample_layers.append(nn.LeakyReLU())
            upsample_layers.append(nn.PixelShuffle(upscale_factor=2))
        self.upsampling = nn.Sequential(*upsample_layers)

        # Output head mapping back to the image channel count.
        self.conv3 = nn.Sequential(
            nn.Conv(filters, filters, 3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.Conv(filters, channels, 3, stride=1, padding=1),
        )

        for m in self.modules():
            weights_init_normal(m)

    def execute(self, x):
        """Generate the upsampled image for the input batch ``x``."""
        shallow = self.conv1(x)
        deep = self.conv2(self.res_blocks(shallow))
        # Long skip connection around the whole RRDB trunk.
        merged = shallow + deep
        return self.conv3(self.upsampling(merged))
class Discriminator(nn.Module):
    """Convolutional discriminator that emits a one-channel score map.

    The network has four stages with 64, 128, 256 and 512 filters. Each stage
    contains one stride-2 convolution, so ``output_shape`` is computed as
    ``(1, int(height / 2**4), int(width / 2**4))``.

    Args:
        input_shape: ``(channels, height, width)`` of the input images.
    """

    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape

        # Four stride-2 convolutions shrink each spatial dimension by 2**4.
        downscale = 2 ** 4
        patch_h = int(in_height / downscale)
        patch_w = int(in_width / downscale)
        self.output_shape = (1, patch_h, patch_w)

        def discriminator_block(in_filters, out_filters, first_block=False):
            # Stride-1 conv (batch norm skipped for the first stage), then a
            # stride-2 conv with batch norm; LeakyReLU after each.
            stage = [nn.Conv(in_filters, out_filters, 3, stride=1, padding=1)]
            if not first_block:
                stage.append(nn.BatchNorm(out_filters))
            stage += [
                nn.LeakyReLU(scale=0.2),
                nn.Conv(out_filters, out_filters, 3, stride=2, padding=1),
                nn.BatchNorm(out_filters),
                nn.LeakyReLU(scale=0.2),
            ]
            return stage

        stage_widths = [64, 128, 256, 512]
        layers = []
        previous = in_channels
        for index, width in enumerate(stage_widths):
            layers.extend(discriminator_block(previous, width, first_block=(index == 0)))
            previous = width

        # Final convolution collapses the last stage to a single channel.
        layers.append(nn.Conv(previous, 1, 3, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

        for m in self.modules():
            weights_init_normal(m)

    def execute(self, img):
        """Return the discriminator score map for ``img``."""
        return self.model(img)