# JGAN/models/bicyclegan/bicyclegan.py
#
# 202 lines
# 8.2 KiB
# Python

import argparse
import os
import numpy as np
import datetime
import time
import sys
from models import *
from datasets import *
import cv2
# Command-line configuration for the BicycleGAN training script.
parser = argparse.ArgumentParser()

# Training schedule and data selection.
parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--dataset_name", type=str, default="edges2shoes", help="name of the dataset")
parser.add_argument("--batch_size", type=int, default=8, help="size of the batches")

# Adam hyper-parameters (b1 decays the first moment, b2 the second moment).
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")

# Data loading and model geometry.
parser.add_argument("--n_cpu", type=int, default=0, help="number of cpu threads to use during batch generation")
parser.add_argument("--img_height", type=int, default=128, help="size of image height")
parser.add_argument("--img_width", type=int, default=128, help="size of image width")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--latent_dim", type=int, default=8, help="number of latent codes")

# Sampling / checkpoint cadence (a checkpoint interval of -1 disables checkpointing).
parser.add_argument("--sample_interval", type=int, default=500, help="interval between saving generator samples")
parser.add_argument("--checkpoint_interval", type=int, default=-1, help="interval between model checkpoints")

# Weights of the individual loss terms.
parser.add_argument("--lambda_pixel", type=float, default=10, help="pixelwise loss weight")
parser.add_argument("--lambda_latent", type=float, default=0.5, help="latent loss weight")
parser.add_argument("--lambda_kl", type=float, default=0.01, help="kullback-leibler loss weight")

opt = parser.parse_args()
print(opt)
# Per-dataset output folders for sample grids and model checkpoints.
os.makedirs("images/%s" % opt.dataset_name, exist_ok=True)
os.makedirs("saved_models/%s" % opt.dataset_name, exist_ok=True)

# (channels, height, width) handed to every network and to the datasets.
input_shape = (opt.channels, opt.img_height, opt.img_width)

# L1 criterion; the training loop uses it for the pixel and latent terms.
mae_loss = nn.L1Loss()

# Networks: generator, encoder and one discriminator per BicycleGAN branch.
generator = Generator(opt.latent_dim, input_shape)
encoder = Encoder(opt.latent_dim, input_shape)
D_VAE = MultiDiscriminator(input_shape)
D_LR = MultiDiscriminator(input_shape)

# All three optimizers share the same Adam settings.
# optimizer_G covers the encoder AND the generator parameters.
_adam_betas = (opt.b1, opt.b2)
optimizer_G = nn.Adam(
    encoder.parameters() + generator.parameters(),
    lr=opt.lr,
    betas=_adam_betas,
)
optimizer_D_VAE = nn.Adam(D_VAE.parameters(), lr=opt.lr, betas=_adam_betas)
optimizer_D_LR = nn.Adam(D_LR.parameters(), lr=opt.lr, betas=_adam_betas)

# Training and validation loaders read from the same dataset directory.
# NOTE(review): the training loader is created with shuffle=False — confirm
# that iterating the data in a fixed order every epoch is intended.
_data_root = "../../data/%s" % opt.dataset_name
dataloader = ImageDataset(_data_root, input_shape).set_attrs(
    batch_size=opt.batch_size,
    shuffle=False,
    num_workers=opt.n_cpu,
)
valdataloader = ImageDataset(_data_root, input_shape, mode="val").set_attrs(
    batch_size=8,
    shuffle=False,
    num_workers=1,
)
def reparameterization(mu, logvar):
    """Sample a latent code as ``eps * exp(logvar / 2) + mu``.

    ``eps`` is standard-normal noise drawn with NumPy, shaped
    ``(mu.shape[0], opt.latent_dim)`` and cast to float32.
    """
    noise_shape = (mu.shape[0], opt.latent_dim)
    eps = jt.array(np.random.normal(0, 1, noise_shape)).float32()
    sigma = jt.exp(logvar / 2)
    return eps * sigma + mu
def sample_images(batches_done):
    """Saves a grid of generated samples from the validation set.

    For up to 8 validation images, the source image is translated with
    ``opt.latent_dim`` independently sampled latent codes. Each grid row is
    the source image followed by its translations; rows are stacked
    vertically, min-max scaled to [0, 255] and written to
    ``images/<dataset_name>/<batches_done>.png``.

    Args:
        batches_done: global batch counter, used only as the output file name.
    """
    generator.eval()

    # Use the validation loader (fixed batch of 8), as the docstring promises.
    # Drawing from `dataloader` here would sample training data and would
    # start a second iteration over the loader the training loop is still
    # enumerating.
    imgs = next(iter(valdataloader))

    # Never index past the end of the batch if it holds fewer than 8 images.
    n_rows = min(8, imgs[0].shape[0])

    rows = []
    for idx in range(n_rows):
        img_A = jt.array(imgs[0][idx])
        # Repeat the source image latent_dim times along a new leading axis:
        # out[i0, i1, i2, i3] = img_A[i1, i2, i3].
        real_A = img_A.reindex([opt.latent_dim, *img_A.shape], ["i1", "i2", "i3"])
        # Sample latent representations (one code per repeated image)
        sampled_z = jt.array(np.random.normal(0, 1, (opt.latent_dim, opt.latent_dim))).float32()
        # Generate samples; fetch the result to NumPy once, not per sample
        fake_B = generator(real_A, sampled_z).numpy()
        # Concatenate samples horizontally, source image first
        fake_row = np.concatenate(list(fake_B), -1)
        rows.append(np.concatenate((img_A.numpy(), fake_row), -1))

    # Stack the rows vertically (along the height axis)
    img_samples = np.concatenate(rows, -2)

    # Min-max normalise to [0, 255]; a constant image would give a zero span.
    min_, max_ = img_samples.min(), img_samples.max()
    span = max_ - min_
    if span == 0:
        span = 1.0
    img_samples = (img_samples - min_) / span * 255.
    # Channel-first -> channel-last for cv2.imwrite
    img_samples = img_samples.transpose((1, 2, 0))
    cv2.imwrite("images/%s/%s.png" % (opt.dataset_name, batches_done), img_samples)

    generator.train()
# ----------
#  Training
# ----------

# Targets passed to the discriminators' compute_loss: 1 = real, 0 = fake.
valid = 1
fake = 0

prev_time = time.time()
for epoch in range(opt.epoch, opt.n_epochs):
    for i, batch in enumerate(dataloader):
        # Set model input (no gradients flow into the data)
        real_A = batch[0].stop_grad()
        real_B = batch[1].stop_grad()

        # -------------------------------
        #  Train Generator and Encoder
        # -------------------------------

        # ----------
        # cVAE-GAN
        # ----------

        # Produce output using encoding of B (cVAE-GAN)
        mu, logvar = encoder(real_B)
        encoded_z = reparameterization(mu, logvar)
        fake_B = generator(real_A, encoded_z)

        # Pixelwise loss of translated image by VAE
        loss_pixel = mae_loss(fake_B, real_B)
        # Kullback-Leibler divergence of encoded B
        loss_kl = 0.5 * jt.sum(jt.exp(logvar) + mu.sqr() - logvar - 1)
        # Adversarial loss
        loss_VAE_GAN = D_VAE.compute_loss(fake_B, valid)

        # ---------
        # cLR-GAN
        # ---------

        # Produce output using sampled z (cLR-GAN)
        sampled_z = jt.array(np.random.normal(0, 1, (real_A.shape[0], opt.latent_dim))).float32()
        _fake_B = generator(real_A, sampled_z)
        # cLR Loss: Adversarial loss
        loss_LR_GAN = D_LR.compute_loss(_fake_B, valid)

        # ----------------------------------
        # Total Loss (Generator + Encoder)
        # ----------------------------------

        loss_GE = loss_VAE_GAN + loss_LR_GAN + opt.lambda_pixel * loss_pixel + opt.lambda_kl * loss_kl

        # ---------------------
        # Latent L1 loss
        # ---------------------

        # Recover the latent code from the cLR-GAN output and compare it with
        # the code that produced it.
        _mu, _ = encoder(_fake_B)
        loss_latent = opt.lambda_latent * mae_loss(_mu, sampled_z)

        # A single optimizer step on the summed objective.
        # NOTE(review): optimizer_G is built over both encoder and generator
        # parameters, so the latent term is applied through that one optimizer
        # rather than a generator-only one — confirm this is intended.
        loss_total = loss_latent + loss_GE
        loss_total.sync()
        optimizer_G.step(loss_total)

        # ----------------------------------
        #  Train Discriminator (cVAE-GAN)
        # ----------------------------------

        loss_D_VAE = D_VAE.compute_loss(real_B, valid) + D_VAE.compute_loss(fake_B.stop_grad(), fake)
        loss_D_VAE.sync()
        optimizer_D_VAE.step(loss_D_VAE)

        # ---------------------------------
        #  Train Discriminator (cLR-GAN)
        # ---------------------------------

        loss_D_LR = D_LR.compute_loss(real_B, valid) + D_LR.compute_loss(_fake_B.stop_grad(), fake)
        loss_D_LR.sync()
        optimizer_D_LR.step(loss_D_LR)

        # --------------
        #  Log Progress
        # --------------

        # Determine approximate time left from the duration of the last batch
        n_batches = len(dataloader)
        batches_done = epoch * n_batches + i
        batches_left = opt.n_epochs * n_batches - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        if i % 10 == 0:
            # "latent" reports the weighted latent term only; the summed
            # objective that was optimised is loss_GE + loss_latent.
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [D VAE_loss: %f, LR_loss: %f] [G loss: %f, pixel: %f, kl: %f, latent: %f] ETA: %s"
                % (
                    epoch,
                    opt.n_epochs,
                    i,
                    n_batches,
                    loss_D_VAE.data[0],
                    loss_D_LR.data[0],
                    loss_GE.data[0],
                    loss_pixel.data[0],
                    loss_kl.data[0],
                    loss_latent.data[0],
                    time_left,
                )
            )

        if batches_done % opt.sample_interval == 0:
            sample_images(batches_done)

    # Save model checkpoints once per qualifying epoch. Non-positive intervals
    # disable checkpointing (-1 is the documented "off" value; 0 would
    # otherwise raise ZeroDivisionError in the modulo).
    if opt.checkpoint_interval > 0 and epoch % opt.checkpoint_interval == 0:
        ckpt_dir = os.path.join("saved_models", opt.dataset_name)
        generator.save(os.path.join(ckpt_dir, "generator_last.pkl"))
        encoder.save(os.path.join(ckpt_dir, "encoder_last.pkl"))
        D_VAE.save(os.path.join(ckpt_dir, "D_VAE_last.pkl"))
        D_LR.save(os.path.join(ckpt_dir, "D_LR_last.pkl"))