update
This commit is contained in:
parent
f7d9a8ca98
commit
10cda5e900
|
@ -2,7 +2,6 @@ import argparse
|
|||
import os
|
||||
import numpy as np
|
||||
import time
|
||||
import itertools
|
||||
import cv2
|
||||
|
||||
import jittor as jt
|
||||
|
@ -90,7 +89,7 @@ train_loader = MNIST(train=True, transform=transform).set_attrs(batch_size=opt.b
|
|||
|
||||
# Optimizers
|
||||
optimizer_G = nn.Adam(
|
||||
itertools.chain(encoder.parameters(), decoder.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
|
||||
encoder.parameters() + decoder.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)
|
||||
)
|
||||
optimizer_D = nn.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ import os
|
|||
import numpy as np
|
||||
import math
|
||||
from jittor import nn
|
||||
jt.flags.use_cuda = 1
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
|
||||
|
|
|
@ -2,7 +2,6 @@ import argparse
|
|||
import os
|
||||
import numpy as np
|
||||
import math
|
||||
import itertools
|
||||
import scipy
|
||||
import sys
|
||||
import time
|
||||
|
@ -57,7 +56,7 @@ val_dataloader = ImageDataset("../data/%s" % opt.dataset_name, img_shape, mode="
|
|||
|
||||
# Optimizers
|
||||
optimizer_G = nn.Adam(
|
||||
itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
|
||||
G_AB.parameters() + G_BA.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)
|
||||
)
|
||||
optimizer_D_A = nn.Adam(D_A.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
|
||||
optimizer_D_B = nn.Adam(D_B.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
|
||||
|
|
|
@ -2,7 +2,6 @@ import argparse
|
|||
import os
|
||||
import numpy as np
|
||||
import math
|
||||
import itertools
|
||||
|
||||
os.makedirs("images/static/", exist_ok=True)
|
||||
os.makedirs("images/varying_c1/", exist_ok=True)
|
||||
|
@ -28,6 +27,8 @@ import jittor as jt
|
|||
from jittor import init
|
||||
from jittor import nn
|
||||
|
||||
jt.flags.use_cuda = 1
|
||||
|
||||
def weights_init_normal(m):
|
||||
classname = m.__class__.__name__
|
||||
if (classname.find('Conv') != (- 1)):
|
||||
|
@ -119,9 +120,7 @@ dataloader = MNIST(train=True, transform=transform).set_attrs(batch_size=opt.bat
|
|||
# Optimizers
|
||||
optimizer_G = nn.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
|
||||
optimizer_D = nn.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
|
||||
optimizer_info = nn.Adam(
|
||||
itertools.chain(generator.parameters(), discriminator.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
|
||||
)
|
||||
optimizer_info = nn.Adam(generator.parameters() + discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
|
||||
|
||||
# Static generator inputs for sampling
|
||||
static_z = jt.array(np.zeros((opt.n_classes ** 2, opt.latent_dim))).float32()
|
||||
|
@ -217,7 +216,7 @@ for epoch in range(opt.n_epochs):
|
|||
d_real_loss = adversarial_loss(real_pred, valid)
|
||||
|
||||
# Loss for fake images
|
||||
fake_pred, _, _ = discriminator(gen_imgs.detach())
|
||||
fake_pred, _, _ = discriminator(gen_imgs.stop_grad())
|
||||
d_fake_loss = adversarial_loss(fake_pred, fake)
|
||||
|
||||
# Total discriminator loss
|
||||
|
|
|
@ -2,7 +2,6 @@ import argparse
|
|||
import os
|
||||
import numpy as np
|
||||
import math
|
||||
import itertools
|
||||
import mnistm
|
||||
from jittor.dataset.mnist import MNIST
|
||||
import jittor.transform as transform
|
||||
|
@ -167,7 +166,7 @@ dataloader_B = mnistm.MNISTM(mnist_root = "../../data/mnistm", train=True, trans
|
|||
|
||||
# Optimizers
|
||||
optimizer_G = nn.Adam(
|
||||
itertools.chain(generator.parameters(), classifier.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
|
||||
generator.parameters() + classifier.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)
|
||||
)
|
||||
optimizer_D = nn.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
|
||||
|
||||
|
|
|
@ -86,7 +86,7 @@ def save_image(img, path, nrow=10, padding=5):
|
|||
optimizer_G = nn.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
|
||||
optimizer_D = nn.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
|
||||
|
||||
dataloader = ImageDataset("../data/%s" % opt.dataset_name, hr_shape=hr_shape).set_attrs(batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu)
|
||||
dataloader = ImageDataset("../../../jittor-GAN/data/%s" % opt.dataset_name, hr_shape=hr_shape).set_attrs(batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu)
|
||||
# ----------
|
||||
# Training
|
||||
# ----------
|
||||
|
|
Loading…
Reference in New Issue