This commit is contained in:
zwy 2020-05-14 21:21:36 +08:00
parent f7d9a8ca98
commit 10cda5e900
6 changed files with 10 additions and 13 deletions

View File

@ -2,7 +2,6 @@ import argparse
import os
import numpy as np
import time
import itertools
import cv2
import jittor as jt
@ -90,7 +89,7 @@ train_loader = MNIST(train=True, transform=transform).set_attrs(batch_size=opt.b
# Optimizers
optimizer_G = nn.Adam(
itertools.chain(encoder.parameters(), decoder.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
encoder.parameters() + decoder.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D = nn.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

View File

@ -6,6 +6,7 @@ import os
import numpy as np
import math
from jittor import nn
jt.flags.use_cuda = 1
parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')

View File

@ -2,7 +2,6 @@ import argparse
import os
import numpy as np
import math
import itertools
import scipy
import sys
import time
@ -57,7 +56,7 @@ val_dataloader = ImageDataset("../data/%s" % opt.dataset_name, img_shape, mode="
# Optimizers
optimizer_G = nn.Adam(
itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
G_AB.parameters() + G_BA.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D_A = nn.Adam(D_A.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D_B = nn.Adam(D_B.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

View File

@ -2,7 +2,6 @@ import argparse
import os
import numpy as np
import math
import itertools
os.makedirs("images/static/", exist_ok=True)
os.makedirs("images/varying_c1/", exist_ok=True)
@ -28,6 +27,8 @@ import jittor as jt
from jittor import init
from jittor import nn
jt.flags.use_cuda = 1
def weights_init_normal(m):
classname = m.__class__.__name__
if (classname.find('Conv') != (- 1)):
@ -119,9 +120,7 @@ dataloader = MNIST(train=True, transform=transform).set_attrs(batch_size=opt.bat
# Optimizers
optimizer_G = nn.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = nn.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_info = nn.Adam(
itertools.chain(generator.parameters(), discriminator.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_info = nn.Adam(generator.parameters() + discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
# Static generator inputs for sampling
static_z = jt.array(np.zeros((opt.n_classes ** 2, opt.latent_dim))).float32()
@ -217,8 +216,8 @@ for epoch in range(opt.n_epochs):
d_real_loss = adversarial_loss(real_pred, valid)
# Loss for fake images
fake_pred, _, _ = discriminator(gen_imgs.detach())
d_fake_loss = adversarial_loss(fake_pred, fake)
fake_pred, _, _ = discriminator(gen_imgs.stop_grad())
d_fake_loss = adversarial_loss(fake_pred, fake)
# Total discriminator loss
d_loss = (d_real_loss + d_fake_loss) / 2

View File

@ -2,7 +2,6 @@ import argparse
import os
import numpy as np
import math
import itertools
import mnistm
from jittor.dataset.mnist import MNIST
import jittor.transform as transform
@ -167,7 +166,7 @@ dataloader_B = mnistm.MNISTM(mnist_root = "../../data/mnistm", train=True, trans
# Optimizers
optimizer_G = nn.Adam(
itertools.chain(generator.parameters(), classifier.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
generator.parameters() + classifier.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D = nn.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

View File

@ -86,7 +86,7 @@ def save_image(img, path, nrow=10, padding=5):
optimizer_G = nn.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = nn.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
dataloader = ImageDataset("../data/%s" % opt.dataset_name, hr_shape=hr_shape).set_attrs(batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu)
dataloader = ImageDataset("../../../jittor-GAN/data/%s" % opt.dataset_name, hr_shape=hr_shape).set_attrs(batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu)
# ----------
# Training
# ----------