import os
from glob import glob

import imageio
import torch
import torchvision
import wandb
from img_processing import custom_to_pil, loop_post_process, preprocess, preprocess_vqgan
from loaders import load_vqgan
from PIL import Image
from torch import nn

from transformers import CLIPModel, CLIPTokenizerFast

from utils import get_device, get_timestamp, show_pil


class ProcessorGradientFlow:
    """
    This wraps the Hugging Face CLIP processor so that gradients can flow through the image preprocessing step.
    The original processor converts images to PIL, which is fast but breaks gradient flow.
    We call the original tokenizer for the text inputs, but run our own tensor-based image preprocessing so the
    images stay torch tensors end to end.
    """

    def __init__(self, device: str = "cpu", clip_model: str = "openai/clip-vit-large-patch14") -> None:
        self.device = device
        self.tokenizer = CLIPTokenizerFast.from_pretrained(clip_model)
        self.image_mean = [0.48145466, 0.4578275, 0.40821073]
        self.image_std = [0.26862954, 0.26130258, 0.27577711]
        self.normalize = torchvision.transforms.Normalize(self.image_mean, self.image_std)
        self.resize = torchvision.transforms.Resize(224)
        self.center_crop = torchvision.transforms.CenterCrop(224)

    def preprocess_img(self, images):
        images = self.resize(images)
        images = self.center_crop(images)
        images = self.normalize(images)
        return images

    def __call__(self, text=None, images=None, **kwargs):
        encoding = self.tokenizer(text=text, **kwargs)
        encoding["pixel_values"] = self.preprocess_img(images)
        encoding = {key: value.to(self.device) for (key, value) in encoding.items()}
        return encoding
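

# Illustrative sketch (not part of the original module): a minimal check that gradients really do flow
# back to the raw image tensor when ProcessorGradientFlow is used instead of the stock CLIP processor.
# The function name, prompt, and tensor shape below are arbitrary examples; the CLIP checkpoint is the
# same default used by VQGAN_CLIP below.
def _demo_processor_gradient_flow():
    processor = ProcessorGradientFlow(device="cpu")
    clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    image = torch.rand(1, 3, 256, 256, requires_grad=True)  # dummy image batch
    inputs = processor(text=["a smiling woman"], images=image, return_tensors="pt", padding=True)
    score = clip(**inputs).logits_per_image.sum()
    score.backward()  # populates image.grad because preprocessing stayed in torch
    return image.grad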


class VQGAN_CLIP(nn.Module):
    def __init__(
        self,
        iterations=10,
        lr=0.01,
        vqgan=None,
        vqgan_config=None,
        vqgan_checkpoint=None,
        clip=None,
        clip_preprocessor=None,
        device=None,
        log=False,
        save_vector=True,
        return_val="image",
        quantize=True,
        save_intermediate=False,
        show_intermediate=False,
        make_grid=False,
    ) -> None:
        """
        Instantiate a VQGAN_CLIP model. If you want to use a custom VQGAN model, pass it as vqgan.
        """
        super().__init__()
        self.latent = None
        self.device = device if device else get_device()
        if vqgan:
            self.vqgan = vqgan
        else:
            self.vqgan = load_vqgan(self.device, conf_path=vqgan_config, ckpt_path=vqgan_checkpoint)
        self.vqgan.eval()
        if clip:
            self.clip = clip
        else:
            self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.clip.to(self.device)
        self.clip_preprocessor = ProcessorGradientFlow(device=self.device)

        self.iterations = iterations
        self.lr = lr
        self.log = log
        self.make_grid = make_grid
        self.return_val = return_val
        self.quantize = quantize
        # shape of the VQGAN latent space; used to sample a random starting latent in generate()
        self.latent_dim = self.vqgan.decoder.z_shape
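
    # Comment-only sketch (illustrative, not executed): typical instantiation with a local VQGAN.
    # The config/checkpoint paths are placeholders for files you supply; passing preloaded models via
    # `vqgan=` or `clip=` skips loading from disk.
    #   model = VQGAN_CLIP(vqgan_config="./vqgan.yaml", vqgan_checkpoint="./vqgan.ckpt", iterations=20)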

    def make_animation(self, input_path=None, output_path=None, total_duration=5, extend_frames=True):
        """
        Make an animation from the intermediate images saved during generation.
        By default, this uses the images from the most recent run of the generate function.
        To use images from a different run, pass the path to the folder containing the images as input_path.
        """
        if output_path is None:
            output_path = "./animation.gif"
        if input_path is None:
            input_path = self.save_path
        paths = sorted(glob(input_path + "/*"))
        if not paths:
            raise ValueError(
                "No images found in save path, aborting (did you pass save_intermediate=True to the generate"
                " function?)"
            )
        if len(paths) == 1:
            print("Only one image found in save path (did you pass save_intermediate=True to the generate function?)")
        frame_duration = total_duration / len(paths)
        durations = [frame_duration] * len(paths)
        if extend_frames:
            # hold the first and last frames longer so the start and end states are easy to see
            durations[0] = 1.5
            durations[-1] = 3
        images = []
        for file_name in paths:
            if file_name.endswith(".png"):
                images.append(imageio.imread(file_name))
        imageio.mimsave(output_path, images, duration=durations)
        print(f"gif saved to {output_path}")
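
    # Example call (illustrative): after model.generate(..., save_intermediate=True) has run,
    #   model.make_animation(output_path="./face_edit.gif", total_duration=5)
    # stitches the saved per-iteration frames into a gif.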

    def _get_latent(self, path=None, img=None):
        if not (path or img):
            raise ValueError("Input either path or tensor")
        if img is not None:
            raise NotImplementedError
        x = preprocess(Image.open(path), target_image_size=256).to(self.device)
        x_processed = preprocess_vqgan(x)
        z, *_ = self.vqgan.encode(x_processed)
        return z

    def _add_vector(self, transform_vector):
        """Add a transform vector to the base latent and return the decoded image."""
        base_latent = self.latent.detach().requires_grad_()
        trans_latent = base_latent + transform_vector
        if self.quantize:
            z_q, *_ = self.vqgan.quantize(trans_latent)
        else:
            z_q = trans_latent
        return self.vqgan.decode(z_q)

    def _get_clip_similarity(self, prompts, image, weights=None):
        clip_inputs = self.clip_preprocessor(text=prompts, images=image, return_tensors="pt", padding=True)
        clip_outputs = self.clip(**clip_inputs)
        similarity_logits = clip_outputs.logits_per_image
        if weights is not None:
            similarity_logits = similarity_logits * weights
        return similarity_logits.sum()

    def _get_clip_loss(self, pos_prompts, neg_prompts, image):
        # Minimizing this loss increases weighted CLIP similarity to the positive prompts and decreases it for
        # the negative prompts; note that the positive weights are applied as their reciprocals (1 / weights).
        pos_logits = self._get_clip_similarity(pos_prompts["prompts"], image, weights=(1 / pos_prompts["weights"]))
        if neg_prompts:
            neg_logits = self._get_clip_similarity(neg_prompts["prompts"], image, weights=neg_prompts["weights"])
        else:
            neg_logits = torch.tensor([1], device=self.device)
        loss = -torch.log(pos_logits) + torch.log(neg_logits)
        return loss

    def _optimize_CLIP(self, original_img, pos_prompts, neg_prompts):
        # Optimize a residual vector added to the (fixed) base latent, yielding one result per iteration.
        vector = torch.randn_like(self.latent, requires_grad=True, device=self.device)
        optim = torch.optim.Adam([vector], lr=self.lr)

        for i in range(self.iterations):
            optim.zero_grad()
            transformed_img = self._add_vector(vector)
            processed_img = loop_post_process(transformed_img)
            clip_loss = self._get_clip_loss(pos_prompts, neg_prompts, processed_img)
            print("CLIP loss", clip_loss)
            if self.log:
                wandb.log({"CLIP Loss": clip_loss})
            clip_loss.backward(retain_graph=True)
            optim.step()
            if self.return_val == "image":
                yield custom_to_pil(transformed_img[0])
            else:
                yield vector

    def _init_logging(self, positive_prompts, negative_prompts, image_path):
        wandb.init(reinit=True, project="face-editor")
        wandb.config.update({"Positive Prompts": positive_prompts})
        wandb.config.update({"Negative Prompts": negative_prompts})
        wandb.config.update({"lr": self.lr, "iterations": self.iterations})
        if image_path:
            image = Image.open(image_path)
            image = image.resize((256, 256))
            wandb.log({"Original Image": wandb.Image(image)})

    def process_prompts(self, prompts):
        if not prompts:
            return []
        processed_prompts = []
        weights = []
        if isinstance(prompts, str):
            prompts = [prompt.strip() for prompt in prompts.split("|")]
        for prompt in prompts:
            if isinstance(prompt, (tuple, list)):
                processed_prompt = prompt[0]
                weight = float(prompt[1])
            elif ":" in prompt:
                processed_prompt, weight = prompt.split(":")
                weight = float(weight)
            else:
                processed_prompt = prompt
                weight = 1.0
            processed_prompts.append(processed_prompt)
            weights.append(weight)
        return {
            "prompts": processed_prompts,
            "weights": torch.tensor(weights, device=self.device),
        }
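
    # Example (illustrative, not executed): weighted prompt strings are split on "|" and ":", e.g.
    #   self.process_prompts("a smiling woman:1 | a woman with brown hair:3")
    #   -> {"prompts": ["a smiling woman", "a woman with brown hair"], "weights": tensor([1., 3.])}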

    def generate(
        self,
        pos_prompts,
        neg_prompts=None,
        image_path=None,
        show_intermediate=True,
        save_intermediate=False,
        show_final=True,
        save_final=True,
        save_path=None,
    ):
        """Generate an image from the given prompts.

        If image_path is provided, that image is used as the starting point for the optimization.
        Otherwise, a random latent vector is used as the starting point.
        You must provide at least one positive prompt; negative prompts are optional.
        Prompts must be formatted in one of the following ways (the default weight is 1):
        - A single prompt as a string, e.g. "A smiling woman"
        - A set of prompts separated by pipes: "A smiling woman | a woman with brown hair"
        - A set of prompts and their weights separated by colons: "A smiling woman:1 | a woman with brown hair:3"
        - A list of prompts, e.g. ["A smiling woman", "a woman with brown hair"]
        - A list of prompts and weights, e.g. [("A smiling woman", 1), ("a woman with brown hair", 3)]
        """
        if image_path:
            self.latent = self._get_latent(image_path)
        else:
            self.latent = torch.randn(self.latent_dim, device=self.device)
        if self.log:
            self._init_logging(pos_prompts, neg_prompts, image_path)

        assert pos_prompts, "You must provide at least one positive prompt."
        pos_prompts = self.process_prompts(pos_prompts)
        neg_prompts = self.process_prompts(neg_prompts)
        if save_final and save_path is None:
            save_path = os.path.join("./outputs/", "_".join(pos_prompts["prompts"]))
            if os.path.exists(save_path):
                # a folder for these prompts already exists, so make the new path unique
                save_path = save_path + "_" + get_timestamp()
            os.makedirs(save_path)
        self.save_path = save_path

        original_img = self.vqgan.decode(self.latent)[0]
        if show_intermediate:
            print("Original Image")
            show_pil(custom_to_pil(original_img))

        original_img = loop_post_process(original_img)
        for iteration, transformed_img in enumerate(self._optimize_CLIP(original_img, pos_prompts, neg_prompts)):
            if show_intermediate:
                show_pil(transformed_img)
            if save_intermediate:
                transformed_img.save(os.path.join(self.save_path, f"iter_{iteration:03d}.png"))
            if self.log:
                wandb.log({"Image": wandb.Image(transformed_img)})
        if show_final:
            show_pil(transformed_img)
        if save_final:
            transformed_img.save(os.path.join(self.save_path, f"iter_{iteration:03d}_final.png"))
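

# Minimal usage sketch (not part of the original module). The VQGAN config/checkpoint paths below are
# placeholders for files you supply; everything else uses the defaults defined above.
if __name__ == "__main__":
    model = VQGAN_CLIP(
        vqgan_config="./vqgan_config.yaml",  # placeholder path
        vqgan_checkpoint="./vqgan_checkpoint.ckpt",  # placeholder path
        iterations=20,
        lr=0.01,
    )
    model.generate(
        pos_prompts="a smiling woman | soft studio lighting:2",
        neg_prompts="blurry",
        show_intermediate=False,
        save_intermediate=True,
    )
    model.make_animation(output_path="./animation.gif")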