add download
This commit is contained in:
parent
8dbe6000ed
commit
530eeceb40
|
@ -0,0 +1,22 @@
|
|||
#!/bin/bash
|
||||
|
||||
FILE=$1
|
||||
|
||||
if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" && $FILE != "ae_photos" ]]; then
|
||||
echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
|
||||
ZIP_FILE=./$FILE.zip
|
||||
TARGET_DIR=./$FILE
|
||||
wget -N $URL -O $ZIP_FILE
|
||||
unzip $ZIP_FILE -d .
|
||||
rm $ZIP_FILE
|
||||
|
||||
# Adapt to project expected directory heriarchy
|
||||
mkdir -p "$TARGET_DIR/train" "$TARGET_DIR/test"
|
||||
mv "$TARGET_DIR/trainA" "$TARGET_DIR/train/A"
|
||||
mv "$TARGET_DIR/trainB" "$TARGET_DIR/train/B"
|
||||
mv "$TARGET_DIR/testA" "$TARGET_DIR/test/A"
|
||||
mv "$TARGET_DIR/testB" "$TARGET_DIR/test/B"
|
|
@ -0,0 +1,8 @@
|
|||
FILE=$1
|
||||
URL=https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/$FILE.tar.gz
|
||||
TAR_FILE=./$FILE.tar.gz
|
||||
TARGET_DIR=./$FILE/
|
||||
wget -N $URL -O $TAR_FILE
|
||||
mkdir $TARGET_DIR
|
||||
tar -zxvf $TAR_FILE -C ./
|
||||
rm $TAR_FILE
|
|
@ -154,7 +154,7 @@ for epoch in range(opt.n_epochs):
|
|||
|
||||
optimizer_D.step(d_loss)
|
||||
|
||||
jt.sync_all(True)
|
||||
jt.sync_all()
|
||||
if i % 50 == 0:
|
||||
print(
|
||||
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [Time: %f]"
|
||||
|
|
|
@ -26,8 +26,8 @@ parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first
|
|||
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
|
||||
parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")
|
||||
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
|
||||
parser.add_argument("--img_height", type=int, default=256, help="size of image height")
|
||||
parser.add_argument("--img_width", type=int, default=256, help="size of image width")
|
||||
parser.add_argument("--img_height", type=int, default=64, help="size of image height")
|
||||
parser.add_argument("--img_width", type=int, default=64, help="size of image width")
|
||||
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
|
||||
parser.add_argument("--sample_interval", type=int, default=100, help="interval between saving generator outputs")
|
||||
parser.add_argument("--checkpoint_interval", type=int, default=-1, help="interval between saving model checkpoints")
|
||||
|
@ -76,7 +76,7 @@ transform_ = [
|
|||
# Training data loader
|
||||
dataloader = ImageDataset("../../../PyTorch-GAN/data/%s" % opt.dataset_name, transform_=transform_, unaligned=True).set_attrs(batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu)
|
||||
|
||||
val_dataloader = ImageDataset("../../../PyTorch-GAN/data/%s" % opt.dataset_name, transform_=transform_, unaligned=True, mode="test").set_attrs(batch_size=1, shuffle=True, num_workers=1)
|
||||
val_dataloader = ImageDataset("../../../PyTorch-GAN/data/%s" % opt.dataset_name, transform_=transform_, unaligned=True, mode="test").set_attrs(batch_size=5, shuffle=True, num_workers=1)
|
||||
|
||||
def sample_images(batches_done):
|
||||
"""Saves a generated sample from the test set"""
|
||||
|
@ -85,9 +85,9 @@ def sample_images(batches_done):
|
|||
G_BA.eval()
|
||||
|
||||
|
||||
real_A = imgs[0]
|
||||
real_A = imgs[0].stop_grad()
|
||||
fake_B = G_AB(real_A)
|
||||
real_B = imgs[1]
|
||||
real_B = imgs[1].stop_grad()
|
||||
fake_A = G_BA(real_B)
|
||||
# Arange images along x-axis
|
||||
real_A = make_grid(torch.Tensor(real_A.numpy()), nrow=5, normalize=True)
|
||||
|
@ -95,7 +95,7 @@ def sample_images(batches_done):
|
|||
fake_A = make_grid(torch.Tensor(fake_A.numpy()), nrow=5, normalize=True)
|
||||
fake_B = make_grid(torch.Tensor(fake_B.numpy()), nrow=5, normalize=True)
|
||||
# Arange images along y-axis
|
||||
image_grid = cat((real_A, fake_B, real_B, fake_A), 1)
|
||||
image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
|
||||
save_image(image_grid, "images/%s/%s.png" % (opt.dataset_name, batches_done), normalize=False)
|
||||
|
||||
|
||||
|
@ -157,7 +157,7 @@ for epoch in range(opt.epoch, opt.n_epochs):
|
|||
loss_real = criterion_GAN(D_A(real_A), valid)
|
||||
# Fake loss (on batch of previously generated samples)
|
||||
fake_A_ = fake_A_buffer.push_and_pop(fake_A)
|
||||
loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
|
||||
loss_fake = criterion_GAN(D_A(fake_A_.stop_grad()), fake)
|
||||
# Total loss
|
||||
loss_D_A = (loss_real + loss_fake) / 2
|
||||
|
||||
|
@ -171,7 +171,7 @@ for epoch in range(opt.epoch, opt.n_epochs):
|
|||
loss_real = criterion_GAN(D_B(real_B), valid)
|
||||
# Fake loss (on batch of previously generated samples)
|
||||
fake_B_ = fake_B_buffer.push_and_pop(fake_B)
|
||||
loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
|
||||
loss_fake = criterion_GAN(D_B(fake_B_.stop_grad()), fake)
|
||||
# Total loss
|
||||
loss_D_B = (loss_real + loss_fake) / 2
|
||||
|
||||
|
|
|
@ -11,6 +11,9 @@ from models import *
|
|||
from datasets import *
|
||||
import jittor as jt
|
||||
|
||||
from torchvision.utils import save_image
|
||||
import torch
|
||||
|
||||
jt.flags.use_cuda = 1
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
@ -23,8 +26,8 @@ parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first
|
|||
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
|
||||
parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")
|
||||
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
|
||||
parser.add_argument("--img_height", type=int, default=256, help="size of image height")
|
||||
parser.add_argument("--img_width", type=int, default=256, help="size of image width")
|
||||
parser.add_argument("--img_height", type=int, default=128, help="size of image height")
|
||||
parser.add_argument("--img_width", type=int, default=128, help="size of image width")
|
||||
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
|
||||
parser.add_argument("--sample_interval", type=int, default=100, help="interval between saving generator samples")
|
||||
parser.add_argument("--checkpoint_interval", type=int, default=-1, help="interval between saving model checkpoints")
|
||||
|
@ -83,32 +86,41 @@ transform_ = [
|
|||
|
||||
dataloader = ImageDataset("../../../PyTorch-GAN/data/%s" % opt.dataset_name, transforms_=transform_, unaligned=True).set_attrs(batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu)
|
||||
|
||||
val_dataloader = ImageDataset("../../../PyTorch-GAN/data/%s" % opt.dataset_name, transforms_=transform_, unaligned=True, mode="test").set_attrs(batch_size=1, shuffle=True, num_workers=1)
|
||||
val_dataloader = ImageDataset("../../../PyTorch-GAN/data/%s" % opt.dataset_name, transforms_=transform_, unaligned=True, mode="test").set_attrs(batch_size=5, shuffle=True, num_workers=1)
|
||||
|
||||
import cv2
|
||||
def save_image(img, path, nrow=10):
|
||||
N,C,W,H = img.shape
|
||||
img2 = img.reshape([-1,W*nrow*nrow,H])
|
||||
img = img2[:,:W*nrow,:]
|
||||
for i in range(1,nrow):
|
||||
img = np.concatenate([img,img2[:,W*nrow*i:W*nrow*(i+1),:]],axis=2)
|
||||
min_ = img.min()
|
||||
max_ = img.max()
|
||||
img = (img - min_) / (max_ - min_) * 255
|
||||
img = img.transpose((1,2,0))
|
||||
cv2.imwrite(path,img)
|
||||
# import cv2
|
||||
from pdb import set_trace as st
|
||||
# def save_image(img, path, nrow=10):
|
||||
# N,C,H,W = img.shape
|
||||
# if N > nrow * nrow:
|
||||
# img = img[:nrow*nrow,:]
|
||||
# elif N < nrow * nrow:
|
||||
# img = np.concatenate([img, np.zeros((nrow*nrow-N,C,H,W))],axis=0)
|
||||
# img2 = img.reshape([-1,W*nrow*nrow,H])
|
||||
# img = img2[:,:W*nrow,:]
|
||||
# for i in range(1,nrow):
|
||||
# img = np.concatenate([img,img2[:,W*nrow*i:W*nrow*(i+1),:]],axis=2)
|
||||
# min_ = img.min()
|
||||
# max_ = img.max()
|
||||
# img = (img - min_) / (max_ - min_) * 255
|
||||
# img = img.transpose((1,2,0))
|
||||
# cv2.imwrite(path,img)
|
||||
|
||||
def sample_images(batches_done):
|
||||
"""Saves a generated sample from the test set"""
|
||||
imgs = next(iter(val_dataloader))
|
||||
X1 = imgs[0]
|
||||
X2 = imgs[1]
|
||||
X1 = imgs[0].stop_grad()
|
||||
X2 = imgs[1].stop_grad()
|
||||
E1.eval()
|
||||
E2.eval()
|
||||
G1.eval()
|
||||
G2.eval()
|
||||
_, Z1 = E1(X1)
|
||||
_, Z2 = E2(X2)
|
||||
fake_X1 = G1(Z2)
|
||||
fake_X2 = G2(Z1)
|
||||
img_sample = jt.contrib.concat((X1, fake_X2, X2, fake_X1), 0)
|
||||
save_image(img_sample.numpy(), "images/%s/%s.png" % (opt.dataset_name, batches_done), nrow=5)
|
||||
save_image(torch.Tensor(img_sample.numpy()), "images/%s/%s.png" % (opt.dataset_name, batches_done), nrow=5, normalize=True)
|
||||
|
||||
|
||||
def compute_kl(mu):
|
||||
|
@ -123,18 +135,22 @@ def compute_kl(mu):
|
|||
prev_time = time.time()
|
||||
for epoch in range(opt.epoch, opt.n_epochs):
|
||||
for i, batch in enumerate(dataloader):
|
||||
jt.sync_all(True)
|
||||
# Set model input
|
||||
X1 = batch[0]
|
||||
X2 = batch[1]
|
||||
X1 = batch[0].stop_grad()
|
||||
X2 = batch[1].stop_grad()
|
||||
|
||||
# Adversarial ground truths
|
||||
valid = jt.array(np.ones((X1.size(0), *D1.output_shape))).float32().stop_grad()
|
||||
fake = jt.array(np.zeros((X1.size(0), *D1.output_shape))).float32().stop_grad()
|
||||
valid = jt.ones((X1.size(0), *D1.output_shape)).stop_grad()
|
||||
fake = jt.zeros((X1.size(0), *D1.output_shape)).stop_grad()
|
||||
# -------------------------------
|
||||
# Train Encoders and Generators
|
||||
# -------------------------------
|
||||
|
||||
E1.train()
|
||||
E2.train()
|
||||
G1.train()
|
||||
G2.train()
|
||||
|
||||
# Get shared latent representation
|
||||
mu1, Z1 = E1(X1)
|
||||
mu2, Z2 = E2(X2)
|
||||
|
@ -218,8 +234,8 @@ for epoch in range(opt.epoch, opt.n_epochs):
|
|||
)
|
||||
|
||||
# If at sample interval save image
|
||||
# if batches_done % opt.sample_interval == 0:
|
||||
# sample_images(batches_done)
|
||||
if batches_done % opt.sample_interval == 0:
|
||||
sample_images(batches_done)
|
||||
|
||||
if epoch >= opt.decay_epoch:
|
||||
optimizer_G.lr = opt.lr * (opt.n_epochs - epoch - 1) / (opt.n_epochs - opt.decay_epoch)
|
||||
|
|
Loading…
Reference in New Issue