add download

This commit is contained in:
zwy 2020-05-13 11:53:48 +08:00
parent 8dbe6000ed
commit 530eeceb40
5 changed files with 80 additions and 34 deletions

View File

@ -0,0 +1,22 @@
#!/bin/bash
# Download a CycleGAN dataset from the Berkeley mirror and rearrange it
# into the train/{A,B} and test/{A,B} layout this project expects.
#
# Usage: ./download_cyclegan_dataset.sh <dataset_name>
#
# Abort on any failed command, unset variable, or failed pipe stage so a
# broken download never falls through to unzip/rm/mv.
set -euo pipefail

FILE=${1:-}

# Datasets known to exist on the mirror (each listed exactly once).
VALID_DATASETS="apple2orange summer2winter_yosemite horse2zebra monet2photo cezanne2photo ukiyoe2photo vangogh2photo maps cityscapes facades iphone2dslr_flower ae_photos"

valid=0
for name in $VALID_DATASETS; do
    if [[ "$FILE" == "$name" ]]; then
        valid=1
        break
    fi
done

if [[ $valid -eq 0 ]]; then
    echo "Available datasets are: ${VALID_DATASETS// /, }" >&2
    exit 1
fi

URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
ZIP_FILE=./$FILE.zip
TARGET_DIR=./$FILE

# -N: only re-download if the remote copy is newer than a local one.
wget -N "$URL" -O "$ZIP_FILE"
unzip "$ZIP_FILE" -d .
rm "$ZIP_FILE"

# Adapt to the project's expected directory hierarchy:
# trainA/trainB/testA/testB  ->  train/A, train/B, test/A, test/B
mkdir -p "$TARGET_DIR/train" "$TARGET_DIR/test"
mv "$TARGET_DIR/trainA" "$TARGET_DIR/train/A"
mv "$TARGET_DIR/trainB" "$TARGET_DIR/train/B"
mv "$TARGET_DIR/testA" "$TARGET_DIR/test/A"
mv "$TARGET_DIR/testB" "$TARGET_DIR/test/B"

View File

@ -0,0 +1,8 @@
#!/bin/bash
# Download a pix2pix dataset from the Berkeley mirror and unpack it into
# ./<dataset_name>/ in the current directory.
#
# Usage: ./download_pix2pix_dataset.sh <dataset_name>
#
# Abort on any failed command, unset variable, or failed pipe stage so a
# broken download never falls through to tar/rm.
set -euo pipefail

FILE=${1:-}

# Require a dataset name; otherwise we would fetch a bogus ".tar.gz" URL.
if [[ -z "$FILE" ]]; then
    echo "Usage: $0 <dataset_name> (e.g. facades, maps, cityscapes)" >&2
    exit 1
fi

URL=https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/$FILE.tar.gz
TAR_FILE=./$FILE.tar.gz
TARGET_DIR=./$FILE/

# -N: only re-download if the remote copy is newer than a local one.
wget -N "$URL" -O "$TAR_FILE"
# -p: succeed even if the directory already exists (script is re-runnable).
mkdir -p "$TARGET_DIR"
tar -zxvf "$TAR_FILE" -C ./
rm "$TAR_FILE"

View File

@ -154,7 +154,7 @@ for epoch in range(opt.n_epochs):
optimizer_D.step(d_loss) optimizer_D.step(d_loss)
jt.sync_all(True) jt.sync_all()
if i % 50 == 0: if i % 50 == 0:
print( print(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [Time: %f]" "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [Time: %f]"

View File

@ -26,8 +26,8 @@ parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient") parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay") parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation") parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--img_height", type=int, default=256, help="size of image height") parser.add_argument("--img_height", type=int, default=64, help="size of image height")
parser.add_argument("--img_width", type=int, default=256, help="size of image width") parser.add_argument("--img_width", type=int, default=64, help="size of image width")
parser.add_argument("--channels", type=int, default=3, help="number of image channels") parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=100, help="interval between saving generator outputs") parser.add_argument("--sample_interval", type=int, default=100, help="interval between saving generator outputs")
parser.add_argument("--checkpoint_interval", type=int, default=-1, help="interval between saving model checkpoints") parser.add_argument("--checkpoint_interval", type=int, default=-1, help="interval between saving model checkpoints")
@ -76,7 +76,7 @@ transform_ = [
# Training data loader # Training data loader
dataloader = ImageDataset("../../../PyTorch-GAN/data/%s" % opt.dataset_name, transform_=transform_, unaligned=True).set_attrs(batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu) dataloader = ImageDataset("../../../PyTorch-GAN/data/%s" % opt.dataset_name, transform_=transform_, unaligned=True).set_attrs(batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu)
val_dataloader = ImageDataset("../../../PyTorch-GAN/data/%s" % opt.dataset_name, transform_=transform_, unaligned=True, mode="test").set_attrs(batch_size=1, shuffle=True, num_workers=1) val_dataloader = ImageDataset("../../../PyTorch-GAN/data/%s" % opt.dataset_name, transform_=transform_, unaligned=True, mode="test").set_attrs(batch_size=5, shuffle=True, num_workers=1)
def sample_images(batches_done): def sample_images(batches_done):
"""Saves a generated sample from the test set""" """Saves a generated sample from the test set"""
@ -85,9 +85,9 @@ def sample_images(batches_done):
G_BA.eval() G_BA.eval()
real_A = imgs[0] real_A = imgs[0].stop_grad()
fake_B = G_AB(real_A) fake_B = G_AB(real_A)
real_B = imgs[1] real_B = imgs[1].stop_grad()
fake_A = G_BA(real_B) fake_A = G_BA(real_B)
# Arange images along x-axis # Arange images along x-axis
real_A = make_grid(torch.Tensor(real_A.numpy()), nrow=5, normalize=True) real_A = make_grid(torch.Tensor(real_A.numpy()), nrow=5, normalize=True)
@ -95,7 +95,7 @@ def sample_images(batches_done):
fake_A = make_grid(torch.Tensor(fake_A.numpy()), nrow=5, normalize=True) fake_A = make_grid(torch.Tensor(fake_A.numpy()), nrow=5, normalize=True)
fake_B = make_grid(torch.Tensor(fake_B.numpy()), nrow=5, normalize=True) fake_B = make_grid(torch.Tensor(fake_B.numpy()), nrow=5, normalize=True)
# Arange images along y-axis # Arange images along y-axis
image_grid = cat((real_A, fake_B, real_B, fake_A), 1) image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
save_image(image_grid, "images/%s/%s.png" % (opt.dataset_name, batches_done), normalize=False) save_image(image_grid, "images/%s/%s.png" % (opt.dataset_name, batches_done), normalize=False)
@ -157,7 +157,7 @@ for epoch in range(opt.epoch, opt.n_epochs):
loss_real = criterion_GAN(D_A(real_A), valid) loss_real = criterion_GAN(D_A(real_A), valid)
# Fake loss (on batch of previously generated samples) # Fake loss (on batch of previously generated samples)
fake_A_ = fake_A_buffer.push_and_pop(fake_A) fake_A_ = fake_A_buffer.push_and_pop(fake_A)
loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake) loss_fake = criterion_GAN(D_A(fake_A_.stop_grad()), fake)
# Total loss # Total loss
loss_D_A = (loss_real + loss_fake) / 2 loss_D_A = (loss_real + loss_fake) / 2
@ -171,7 +171,7 @@ for epoch in range(opt.epoch, opt.n_epochs):
loss_real = criterion_GAN(D_B(real_B), valid) loss_real = criterion_GAN(D_B(real_B), valid)
# Fake loss (on batch of previously generated samples) # Fake loss (on batch of previously generated samples)
fake_B_ = fake_B_buffer.push_and_pop(fake_B) fake_B_ = fake_B_buffer.push_and_pop(fake_B)
loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake) loss_fake = criterion_GAN(D_B(fake_B_.stop_grad()), fake)
# Total loss # Total loss
loss_D_B = (loss_real + loss_fake) / 2 loss_D_B = (loss_real + loss_fake) / 2

View File

@ -11,6 +11,9 @@ from models import *
from datasets import * from datasets import *
import jittor as jt import jittor as jt
from torchvision.utils import save_image
import torch
jt.flags.use_cuda = 1 jt.flags.use_cuda = 1
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -23,8 +26,8 @@ parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient") parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay") parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation") parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--img_height", type=int, default=256, help="size of image height") parser.add_argument("--img_height", type=int, default=128, help="size of image height")
parser.add_argument("--img_width", type=int, default=256, help="size of image width") parser.add_argument("--img_width", type=int, default=128, help="size of image width")
parser.add_argument("--channels", type=int, default=3, help="number of image channels") parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=100, help="interval between saving generator samples") parser.add_argument("--sample_interval", type=int, default=100, help="interval between saving generator samples")
parser.add_argument("--checkpoint_interval", type=int, default=-1, help="interval between saving model checkpoints") parser.add_argument("--checkpoint_interval", type=int, default=-1, help="interval between saving model checkpoints")
@ -83,32 +86,41 @@ transform_ = [
dataloader = ImageDataset("../../../PyTorch-GAN/data/%s" % opt.dataset_name, transforms_=transform_, unaligned=True).set_attrs(batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu) dataloader = ImageDataset("../../../PyTorch-GAN/data/%s" % opt.dataset_name, transforms_=transform_, unaligned=True).set_attrs(batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu)
val_dataloader = ImageDataset("../../../PyTorch-GAN/data/%s" % opt.dataset_name, transforms_=transform_, unaligned=True, mode="test").set_attrs(batch_size=1, shuffle=True, num_workers=1) val_dataloader = ImageDataset("../../../PyTorch-GAN/data/%s" % opt.dataset_name, transforms_=transform_, unaligned=True, mode="test").set_attrs(batch_size=5, shuffle=True, num_workers=1)
import cv2 # import cv2
def save_image(img, path, nrow=10): from pdb import set_trace as st
N,C,W,H = img.shape # def save_image(img, path, nrow=10):
img2 = img.reshape([-1,W*nrow*nrow,H]) # N,C,H,W = img.shape
img = img2[:,:W*nrow,:] # if N > nrow * nrow:
for i in range(1,nrow): # img = img[:nrow*nrow,:]
img = np.concatenate([img,img2[:,W*nrow*i:W*nrow*(i+1),:]],axis=2) # elif N < nrow * nrow:
min_ = img.min() # img = np.concatenate([img, np.zeros((nrow*nrow-N,C,H,W))],axis=0)
max_ = img.max() # img2 = img.reshape([-1,W*nrow*nrow,H])
img = (img - min_) / (max_ - min_) * 255 # img = img2[:,:W*nrow,:]
img = img.transpose((1,2,0)) # for i in range(1,nrow):
cv2.imwrite(path,img) # img = np.concatenate([img,img2[:,W*nrow*i:W*nrow*(i+1),:]],axis=2)
# min_ = img.min()
# max_ = img.max()
# img = (img - min_) / (max_ - min_) * 255
# img = img.transpose((1,2,0))
# cv2.imwrite(path,img)
def sample_images(batches_done): def sample_images(batches_done):
"""Saves a generated sample from the test set""" """Saves a generated sample from the test set"""
imgs = next(iter(val_dataloader)) imgs = next(iter(val_dataloader))
X1 = imgs[0] X1 = imgs[0].stop_grad()
X2 = imgs[1] X2 = imgs[1].stop_grad()
E1.eval()
E2.eval()
G1.eval()
G2.eval()
_, Z1 = E1(X1) _, Z1 = E1(X1)
_, Z2 = E2(X2) _, Z2 = E2(X2)
fake_X1 = G1(Z2) fake_X1 = G1(Z2)
fake_X2 = G2(Z1) fake_X2 = G2(Z1)
img_sample = jt.contrib.concat((X1, fake_X2, X2, fake_X1), 0) img_sample = jt.contrib.concat((X1, fake_X2, X2, fake_X1), 0)
save_image(img_sample.numpy(), "images/%s/%s.png" % (opt.dataset_name, batches_done), nrow=5) save_image(torch.Tensor(img_sample.numpy()), "images/%s/%s.png" % (opt.dataset_name, batches_done), nrow=5, normalize=True)
def compute_kl(mu): def compute_kl(mu):
@ -123,18 +135,22 @@ def compute_kl(mu):
prev_time = time.time() prev_time = time.time()
for epoch in range(opt.epoch, opt.n_epochs): for epoch in range(opt.epoch, opt.n_epochs):
for i, batch in enumerate(dataloader): for i, batch in enumerate(dataloader):
jt.sync_all(True)
# Set model input # Set model input
X1 = batch[0] X1 = batch[0].stop_grad()
X2 = batch[1] X2 = batch[1].stop_grad()
# Adversarial ground truths # Adversarial ground truths
valid = jt.array(np.ones((X1.size(0), *D1.output_shape))).float32().stop_grad() valid = jt.ones((X1.size(0), *D1.output_shape)).stop_grad()
fake = jt.array(np.zeros((X1.size(0), *D1.output_shape))).float32().stop_grad() fake = jt.zeros((X1.size(0), *D1.output_shape)).stop_grad()
# ------------------------------- # -------------------------------
# Train Encoders and Generators # Train Encoders and Generators
# ------------------------------- # -------------------------------
E1.train()
E2.train()
G1.train()
G2.train()
# Get shared latent representation # Get shared latent representation
mu1, Z1 = E1(X1) mu1, Z1 = E1(X1)
mu2, Z2 = E2(X2) mu2, Z2 = E2(X2)
@ -218,8 +234,8 @@ for epoch in range(opt.epoch, opt.n_epochs):
) )
# If at sample interval save image # If at sample interval save image
# if batches_done % opt.sample_interval == 0: if batches_done % opt.sample_interval == 0:
# sample_images(batches_done) sample_images(batches_done)
if epoch >= opt.decay_epoch: if epoch >= opt.decay_epoch:
optimizer_G.lr = opt.lr * (opt.n_epochs - epoch - 1) / (opt.n_epochs - opt.decay_epoch) optimizer_G.lr = opt.lr * (opt.n_epochs - epoch - 1) / (opt.n_epochs - opt.decay_epoch)