Support ControlNet (#153)

* add controlnet tentatively
* add controlnet in python code
* implement swift part
* support 8-bit quantization
* add controlnet unload when reduce memory
* remove irrelevant changes
* add more description about controlnet option in swift
* fix some for pr and update README
* pre-allocate zero shapedArray + make multi-controlnet faster

parent d1a6888d43
commit 7f65e1c84b

README.md | 14
@@ -139,6 +139,10 @@ This generally takes 15-20 minutes on an M1 MacBook Pro. Upon successful executi

- `--check-output-correctness`: Compares the original PyTorch model's outputs to the final Core ML model's outputs. This flag increases RAM consumption significantly, so it is recommended only for debugging purposes.

- `--convert-controlnet`: Converts the ControlNet models specified after this option. Multiple models may be converted at once, e.g. `--convert-controlnet lllyasviel/sd-controlnet-mlsd lllyasviel/sd-controlnet-depth`.

- `--unet-support-controlnet`: Enables the converted UNet model to receive additional inputs from ControlNet. This is required for generating images with ControlNet, and the model is saved under a different name, `*_control-unet.mlpackage`, distinct from the normal UNet. Note that this UNet model cannot work without ControlNet; use the normal UNet for plain txt2img.

</details>

## <a name="image-generation-with-python"></a> Image Generation with Python

@@ -157,6 +161,8 @@ Please refer to the help menu for all available arguments: `python -m python_cor

- `--model-version`: If you overrode the default model version while converting models to Core ML, you will need to specify the same model version here.
- `--compute-unit`: Note that the most performant compute unit for this particular implementation may differ across different hardware. `CPU_AND_GPU` or `CPU_AND_NE` may be faster than `ALL`. Please refer to the [Performance Benchmark](#performance-benchmark) section for further guidance.
- `--scheduler`: If you would like to experiment with different schedulers, you may specify one here. For available options, please see the help menu. You may also specify a custom number of inference steps with `--num-inference-steps`, which defaults to 50.
- `--controlnet`: ControlNet models specified with this option are used in image generation. Use this option in the format `--controlnet lllyasviel/sd-controlnet-mlsd lllyasviel/sd-controlnet-depth` and make sure to use `--controlnet-inputs` in conjunction.
- `--controlnet-inputs`: Image inputs corresponding to each ControlNet model. Please provide image paths in the same order as the models in `--controlnet`, for example: `--controlnet-inputs image_mlsd image_depth`.

</details>

@@ -228,6 +234,14 @@ Optionally, it may also include the safety checker model that some versions of S

- `SafetyChecker.mlmodelc`

Optionally, for ControlNet:

- `ControlledUnet.mlmodelc` or `ControlledUnetChunk1.mlmodelc` & `ControlledUnetChunk2.mlmodelc` (enabled to receive ControlNet values)
- `controlnet/` (directory containing ControlNet models)
  - `LllyasvielSdControlnetMlsd.mlmodelc` (for example, from lllyasviel/sd-controlnet-mlsd)
  - `LllyasvielSdControlnetDepth.mlmodelc` (for example, from lllyasviel/sd-controlnet-depth)
  - Other models you converted

Note that the chunked version of Unet is checked for first. Only if it is not present will the full `Unet.mlmodelc` be loaded. Chunking is required for iOS and iPadOS but not necessary for macOS.

</details>
@@ -0,0 +1,244 @@
#
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
#

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers import ModelMixin

import torch
import torch.nn as nn
import torch.nn.functional as F

from .unet import Timesteps, TimestepEmbedding, get_down_block, UNetMidBlock2DCrossAttn, linear_to_conv2d_map


class ControlNetConditioningEmbedding(nn.Module):

    def __init__(
        self,
        conditioning_embedding_channels,
        conditioning_channels=3,
        block_out_channels=(16, 32, 96, 256),
    ):
        super().__init__()

        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

        self.blocks = nn.ModuleList([])

        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        self.conv_out = nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)

    def forward(self, conditioning):
        embedding = self.conv_in(conditioning)
        embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding


class ControlNetModel(ModelMixin, ConfigMixin):

    @register_to_config
    def __init__(
        self,
        in_channels=4,
        flip_sin_to_cos=True,
        freq_shift=0,
        down_block_types=(
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        only_cross_attention=False,
        block_out_channels=(320, 640, 1280, 1280),
        layers_per_block=2,
        downsample_padding=1,
        mid_block_scale_factor=1,
        act_fn="silu",
        norm_num_groups=32,
        norm_eps=1e-5,
        cross_attention_dim=1280,
        attention_head_dim=8,
        use_linear_projection=False,
        upcast_attention=False,
        resnet_time_scale_shift="default",
        conditioning_embedding_out_channels=(16, 32, 96, 256),
        **kwargs,
    ):
        super().__init__()

        # Check inputs
        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

        self._register_load_state_dict_pre_hook(linear_to_conv2d_map)

        # input
        conv_in_kernel = 3
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        time_embed_dim = block_out_channels[0] * 4

        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
        )

        # control net conditioning embedding
        self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
            conditioning_embedding_channels=block_out_channels[0],
            block_out_channels=conditioning_embedding_out_channels,
        )

        self.down_blocks = nn.ModuleList([])
        self.controlnet_down_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]

        controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
        self.controlnet_down_blocks.append(controlnet_block)

        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[i],
                downsample_padding=downsample_padding,
            )
            self.down_blocks.append(down_block)

            for _ in range(layers_per_block):
                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                self.controlnet_down_blocks.append(controlnet_block)

            if not is_final_block:
                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
                self.controlnet_down_blocks.append(controlnet_block)

        # mid
        mid_block_channel = block_out_channels[-1]

        controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
        self.controlnet_mid_block = controlnet_block

        self.mid_block = UNetMidBlock2DCrossAttn(
            in_channels=mid_block_channel,
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attention_head_dim[-1],
            resnet_groups=norm_num_groups,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
        )

    def get_num_residuals(self):
        num_res = 2  # initial sample + mid block
        for down_block in self.down_blocks:
            num_res += len(down_block.resnets)
            if hasattr(down_block, "downsamplers") and down_block.downsamplers is not None:
                num_res += len(down_block.downsamplers)
        return num_res

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states,
        controlnet_cond,
    ):
        # 1. time
        t_emb = self.time_proj(timestep)
        emb = self.time_embedding(t_emb)

        # 2. pre-process
        sample = self.conv_in(sample)

        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)

        sample += controlnet_cond

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
            )

        # 5. Control net blocks
        controlnet_down_block_res_samples = ()

        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
            down_block_res_sample = controlnet_block(down_block_res_sample)
            controlnet_down_block_res_samples += (down_block_res_sample,)

        down_block_res_samples = controlnet_down_block_res_samples

        mid_block_res_sample = self.controlnet_mid_block(sample)

        return down_block_res_samples, mid_block_res_sample
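For orientation, a minimal, hypothetical smoke test of the module above (not part of the commit). It assumes the `python_coreml_stable_diffusion` package is importable; shapes follow the default config except `cross_attention_dim`, set here to the SD v1.5 text-encoder width, and hidden states use the `(B, dim, 1, seq)` layout this repo's attention blocks expect.

```python
# Hypothetical smoke test for the ControlNet port above; not part of the commit.
import torch
from python_coreml_stable_diffusion.controlnet import ControlNetModel

model = ControlNetModel(cross_attention_dim=768).eval()

with torch.no_grad():
    down_res, mid_res = model(
        sample=torch.rand(2, 4, 64, 64),                  # (B, C, H, W) latents
        timestep=torch.tensor([981.0, 981.0]),            # batched timestep
        encoder_hidden_states=torch.rand(2, 768, 1, 77),  # (B, dim, 1, seq)
        controlnet_cond=torch.rand(2, 3, 512, 512),       # conditioning image
    )

# With the default block layout: 1 conv_in + 8 resnet + 3 downsampler
# residuals from the down path, plus 1 mid-block residual = 13 total.
assert model.get_num_residuals() == 13
assert len(down_res) + 1 == model.get_num_residuals()
```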
@@ -98,5 +98,22 @@ def _load_mlpackage(submodule_name, mlpackages_dir, model_version,

    return CoreMLModel(mlpackage_path, compute_unit)


def _load_mlpackage_controlnet(mlpackages_dir, model_version, compute_unit):
    """ Load Core ML (mlpackage) ControlNet models from disk (as exported by torch2coreml.py)
    """
    model_name = model_version.replace("/", "_")

    logger.info(f"Loading controlnet_{model_name} mlpackage")

    fname = f"ControlNet_{model_name}.mlpackage"

    mlpackage_path = os.path.join(mlpackages_dir, fname)

    if not os.path.exists(mlpackage_path):
        raise FileNotFoundError(
            f"controlnet_{model_name} CoreML model doesn't exist at {mlpackage_path}")

    return CoreMLModel(mlpackage_path, compute_unit)


def get_available_compute_units():
    return tuple(cu for cu in ct.ComputeUnit._member_names_)
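As a quick illustration of the naming convention `_load_mlpackage_controlnet` relies on (the model version below is just an example):

```python
# The "/" in a Hugging Face model version becomes "_" on disk, prefixed
# with "ControlNet_"; this mirrors the fname construction above.
model_version = "lllyasviel/sd-controlnet-mlsd"
fname = f"ControlNet_{model_version.replace('/', '_')}.mlpackage"
assert fname == "ControlNet_lllyasviel_sd-controlnet-mlsd.mlpackage"
```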
@@ -32,6 +32,7 @@ import os

from python_coreml_stable_diffusion.coreml_model import (
    CoreMLModel,
    _load_mlpackage,
    _load_mlpackage_controlnet,
    get_available_compute_units,
)

@@ -39,6 +40,7 @@ import time

import torch  # Only used for `torch.from_numpy` in `pipe.scheduler.step()`
from transformers import CLIPFeatureExtractor, CLIPTokenizer
from typing import List, Optional, Union
from PIL import Image


class CoreMLStableDiffusionPipeline(DiffusionPipeline):

@@ -60,6 +62,7 @@ class CoreMLStableDiffusionPipeline(DiffusionPipeline):

                 LMSDiscreteScheduler,
                 PNDMScheduler],
                 tokenizer: CLIPTokenizer,
                 controlnet: Optional[List[CoreMLModel]],
                 ):
        super().__init__()

@@ -88,6 +91,8 @@ class CoreMLStableDiffusionPipeline(DiffusionPipeline):

        self.unet = unet
        self.unet.in_channels = self.unet.expected_inputs["sample"]["shape"][1]

        self.controlnet = controlnet

        self.vae_decoder = vae_decoder

        VAE_DECODER_UPSAMPLE_FACTOR = 8

@@ -168,6 +173,33 @@ class CoreMLStableDiffusionPipeline(DiffusionPipeline):

        return text_embeddings

    def run_controlnet(self,
                       sample,
                       timestep,
                       encoder_hidden_states,
                       controlnet_cond,
                       output_dtype=np.float16):
        if not self.controlnet:
            raise ValueError(
                "Conditions for controlnet are given but the pipeline has no controlnet modules")

        for i, (module, cond) in enumerate(zip(self.controlnet, controlnet_cond)):
            module_outputs = module(
                sample=sample.astype(np.float16),
                timestep=timestep.astype(np.float16),
                encoder_hidden_states=encoder_hidden_states.astype(np.float16),
                controlnet_cond=cond.astype(np.float16),
            )
            # Sum the residuals across every ControlNet model (multi-ControlNet)
            if i == 0:
                outputs = module_outputs
            else:
                for key in outputs.keys():
                    outputs[key] += module_outputs[key]

        outputs = {k: v.astype(output_dtype) for k, v in outputs.items()}

        return outputs

    def run_safety_checker(self, image):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(

@@ -222,6 +254,19 @@ class CoreMLStableDiffusionPipeline(DiffusionPipeline):

        return latents

    def prepare_control_cond(self,
                             controlnet_cond,
                             do_classifier_free_guidance,
                             batch_size,
                             num_images_per_prompt):
        processed_cond_list = []
        for cond in controlnet_cond:
            cond = np.stack([cond] * batch_size * num_images_per_prompt)
            if do_classifier_free_guidance:
                cond = np.concatenate([cond] * 2)
            processed_cond_list.append(cond)
        return processed_cond_list

    def check_inputs(self, prompt, height, width, callback_steps):
        if height != self.height or width != self.width:
            logger.warning(

@@ -276,6 +321,7 @@ class CoreMLStableDiffusionPipeline(DiffusionPipeline):

                 return_dict=True,
                 callback=None,
                 callback_steps=1,
                 controlnet_cond=None,
                 **kwargs,
                 ):
        # 1. Check inputs. Raise error if not correct

@@ -305,7 +351,7 @@ class CoreMLStableDiffusionPipeline(DiffusionPipeline):

        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps

-       # 5. Prepare latent variables
+       # 5. Prepare latent variables and controlnet cond
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,

@@ -315,6 +361,14 @@ class CoreMLStableDiffusionPipeline(DiffusionPipeline):

            latents,
        )

        if controlnet_cond:
            controlnet_cond = self.prepare_control_cond(
                controlnet_cond,
                do_classifier_free_guidance,
                batch_size,
                num_images_per_prompt,
            )

        # 6. Prepare extra step kwargs
        extra_step_kwargs = self.prepare_extra_step_kwargs(eta)

@@ -326,11 +380,23 @@ class CoreMLStableDiffusionPipeline(DiffusionPipeline):

            latent_model_input = self.scheduler.scale_model_input(
                latent_model_input, t)

            # controlnet
            if controlnet_cond:
                additional_residuals = self.run_controlnet(
                    sample=latent_model_input,
                    timestep=np.array([t, t]),
                    encoder_hidden_states=text_embeddings,
                    controlnet_cond=controlnet_cond,
                )
            else:
                additional_residuals = {}

            # predict the noise residual
            noise_pred = self.unet(
                sample=latent_model_input.astype(np.float16),
                timestep=np.array([t, t], np.float16),
                encoder_hidden_states=text_embeddings.astype(np.float16),
                **additional_residuals,
            )["noise_pred"]

            # perform guidance

@@ -385,7 +451,8 @@ def get_coreml_pipe(pytorch_pipe,

                    model_version,
                    compute_unit,
                    delete_original_pipe=True,
-                   scheduler_override=None):
+                   scheduler_override=None,
+                   controlnet_models=None):
    """ Initializes and returns a `CoreMLStableDiffusionPipeline` from an original
    diffusers PyTorch pipeline
    """

@@ -417,6 +484,22 @@ def get_coreml_pipe(pytorch_pipe,

        gc.collect()
        logger.info("Removed PyTorch pipe to reduce peak memory consumption")

    if controlnet_models:
        model_names_to_load.remove("unet")
        coreml_pipe_kwargs["unet"] = _load_mlpackage(
            "control-unet",
            mlpackages_dir,
            model_version,
            compute_unit,
        )
        coreml_pipe_kwargs["controlnet"] = [_load_mlpackage_controlnet(
            mlpackages_dir,
            model_version,
            compute_unit,
        ) for model_version in controlnet_models]
    else:
        coreml_pipe_kwargs["controlnet"] = None

    # Load Core ML models
    logger.info(f"Loading Core ML models in memory from {mlpackages_dir}")
    coreml_pipe_kwargs.update({

@@ -453,6 +536,11 @@ def get_image_path(args, **override_kwargs):

    return os.path.join(out_folder, out_fname + ".png")


def prepare_controlnet_cond(image_path, height, width):
    image = Image.open(image_path).convert("RGB")
    image = image.resize((height, width), resample=Image.LANCZOS)
    image = np.array(image).transpose(2, 0, 1) / 255.0
    return image


def main(args):
    logger.info(f"Setting random seed to {args.seed}")

@@ -472,7 +560,17 @@ def main(args):

        mlpackages_dir=args.i,
        model_version=args.model_version,
        compute_unit=args.compute_unit,
-       scheduler_override=user_specified_scheduler)
+       scheduler_override=user_specified_scheduler,
+       controlnet_models=args.controlnet)

    if args.controlnet:
        controlnet_cond = []
        for i, _ in enumerate(args.controlnet):
            image_path = args.controlnet_inputs[i]
            image = prepare_controlnet_cond(image_path, coreml_pipe.height, coreml_pipe.width)
            controlnet_cond.append(image)
    else:
        controlnet_cond = None

    logger.info("Beginning image generation.")
    image = coreml_pipe(

@@ -480,7 +578,8 @@ def main(args):

        height=coreml_pipe.height,
        width=coreml_pipe.width,
        num_inference_steps=args.num_inference_steps,
-       guidance_scale=args.guidance_scale
+       guidance_scale=args.guidance_scale,
+       controlnet_cond=controlnet_cond,
    )

    out_path = get_image_path(args)

@@ -535,6 +634,18 @@ if __name__ == "__main__":

        default=7.5,
        type=float,
        help="Controls the influence of the text prompt on sampling process (0=random images)")
    parser.add_argument(
        "--controlnet",
        nargs="*",
        type=str,
        help=("Enables ControlNet and uses control-unet instead of unet for additional inputs. "
              "For multi-ControlNet, provide the model names separated by spaces."))
    parser.add_argument(
        "--controlnet-inputs",
        nargs="*",
        type=str,
        help=("Image paths for ControlNet inputs. "
              "Please enter images corresponding to each controlnet provided at the --controlnet option, in the same order."))

    args = parser.parse_args()
    main(args)
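Putting the new pieces together, a minimal sketch of the flow `main()` implements might look as follows. The paths, prompt, and ControlNet version are placeholders; it assumes the mlpackages were produced by `torch2coreml` with `--convert-controlnet` and `--unet-support-controlnet`, and that the pipeline's `__call__` accepts `prompt=` and returns a dict-like output with an `images` list, as the CLI's save path suggests.

```python
# Hypothetical usage sketch; mirrors main() above with placeholder paths.
from diffusers import StableDiffusionPipeline
from python_coreml_stable_diffusion.pipeline import (
    get_coreml_pipe, prepare_controlnet_cond)

pytorch_pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5")
coreml_pipe = get_coreml_pipe(
    pytorch_pipe=pytorch_pipe,
    mlpackages_dir="./mlpackages",                        # placeholder
    model_version="runwayml/stable-diffusion-v1-5",
    compute_unit="CPU_AND_GPU",
    controlnet_models=["lllyasviel/sd-controlnet-mlsd"],
)

# One conditioning image per ControlNet model, in matching order
cond = [prepare_controlnet_cond("image_mlsd.png",         # placeholder
                                coreml_pipe.height, coreml_pipe.width)]

image = coreml_pipe(
    prompt="a photo of a house, best quality",
    height=coreml_pipe.height,
    width=coreml_pipe.width,
    num_inference_steps=50,
    guidance_scale=7.5,
    controlnet_cond=cond,
)
image["images"][0].save("out.png")
```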
@@ -3,13 +3,13 @@

# Copyright (C) 2022 Apple Inc. All Rights Reserved.
#

-from python_coreml_stable_diffusion import unet
+from python_coreml_stable_diffusion import unet, controlnet

import argparse
from collections import OrderedDict, defaultdict
from copy import deepcopy
import coremltools as ct
-from diffusers import StableDiffusionPipeline
+from diffusers import StableDiffusionPipeline, ControlNetModel
import gc

import logging

@@ -24,6 +24,8 @@ from python_coreml_stable_diffusion import chunk_mlprogram

import requests
import shutil
import time
import re
import pathlib

import torch
import torch.nn as nn

@@ -84,7 +86,6 @@ def report_correctness(original_outputs, final_outputs, log_prefix):

    )
    return final_psnr


def _get_out_path(args, submodule_name):
    fname = f"Stable_Diffusion_version_{args.model_version}_{submodule_name}.mlpackage"
    fname = fname.replace("/", "_")

@@ -100,8 +101,10 @@ def _save_mlpackage(model, output_path):


def _convert_to_coreml(submodule_name, torchscript_module, sample_inputs,
-                      output_names, args):
-   out_path = _get_out_path(args, submodule_name)
+                      output_names, args, out_path=None):
+
+   if out_path is None:
+       out_path = _get_out_path(args, submodule_name)

    if os.path.exists(out_path):
        logger.info(f"Skipping export because {out_path} already exists")

@@ -140,21 +143,32 @@

def quantize_weights_to_8bits(args):
    for model_name in [
-           "text_encoder", "vae_decoder", "vae_encoder", "unet", "unet_chunk1",
-           "unet_chunk2", "safety_checker"
+           "text_encoder", "vae_decoder", "vae_encoder", "unet", "unet_chunk1", "unet_chunk2",
+           "control-unet", "control-unet_chunk1", "control-unet_chunk2", "safety_checker"
    ]:
        out_path = _get_out_path(args, model_name)
-       if os.path.exists(out_path):
-           logger.info(f"Quantizing {model_name}")
-           mlmodel = ct.models.MLModel(out_path,
-                                       compute_units=ct.ComputeUnit.CPU_ONLY)
-           mlmodel = ct.compression_utils.affine_quantize_weights(
-               mlmodel, mode="linear")
-           mlmodel.save(out_path)
-           logger.info("Done")
-       else:
-           logger.info(
-               f"Skipped quantizing {model_name} (Not found at {out_path})")
+       _quantize_and_save_8bits_model(out_path, model_name)

+   if args.convert_controlnet:
+       for controlnet_model_version in args.convert_controlnet:
+           controlnet_model_name = controlnet_model_version.replace("/", "_")
+           fname = f"ControlNet_{controlnet_model_name}.mlpackage"
+           out_path = os.path.join(args.o, fname)
+           _quantize_and_save_8bits_model(out_path, controlnet_model_name)


+def _quantize_and_save_8bits_model(out_path, model_name):
+   if os.path.exists(out_path):
+       logger.info(f"Quantizing {model_name}")
+       mlmodel = ct.models.MLModel(out_path,
+                                   compute_units=ct.ComputeUnit.CPU_ONLY)
+       mlmodel = ct.compression_utils.affine_quantize_weights(
+           mlmodel, mode="linear")
+       mlmodel.save(out_path)
+       logger.info("Done")
+   else:
+       logger.info(
+           f"Skipped quantizing {model_name} (Not found at {out_path})")


def _compile_coreml_model(source_model_path, output_dir, final_name):

@@ -194,6 +208,9 @@ def bundle_resources_for_swift_cli(args):

            ("unet", "Unet"),
            ("unet_chunk1", "UnetChunk1"),
            ("unet_chunk2", "UnetChunk2"),
            ("control-unet", "ControlledUnet"),
            ("control-unet_chunk1", "ControlledUnetChunk1"),
            ("control-unet_chunk2", "ControlledUnetChunk2"),
            ("safety_checker", "SafetyChecker")]:
        source_path = _get_out_path(args, source_name)
        if os.path.exists(source_path):

@@ -204,6 +221,23 @@ def bundle_resources_for_swift_cli(args):

            logger.warning(
                f"{source_path} not found, skipping compilation to {target_name}.mlmodelc"
            )

    if args.convert_controlnet:
        for controlnet_model_version in args.convert_controlnet:
            controlnet_model_name = controlnet_model_version.replace("/", "_")
            fname = f"ControlNet_{controlnet_model_name}.mlpackage"
            source_path = os.path.join(args.o, fname)
            controlnet_dir = os.path.join(resources_dir, "controlnet")
            # e.g. "lllyasviel_sd-controlnet-mlsd" -> "LllyasvielSdControlnetMlsd"
            target_name = "".join([word.title() for word in re.split('_|-', controlnet_model_name)])

            if os.path.exists(source_path):
                target_path = _compile_coreml_model(source_path, controlnet_dir,
                                                    target_name)
                logger.info(f"Compiled {source_path} to {target_path}")
            else:
                logger.warning(
                    f"{source_path} not found, skipping compilation to {target_name}.mlmodelc"
                )

    # Fetch and save vocabulary JSON file for text tokenizer
    logger.info("Downloading and saving tokenizer vocab.json")

@@ -543,7 +577,12 @@ def convert_vae_encoder(pipe, args):

def convert_unet(pipe, args):
    """ Converts the UNet component of Stable Diffusion
    """
-   out_path = _get_out_path(args, "unet")
+   if args.unet_support_controlnet:
+       unet_name = "control-unet"
+   else:
+       unet_name = "unet"
+
+   out_path = _get_out_path(args, unet_name)

    # Check if Unet was previously exported and then chunked
    unet_chunks_exist = all(

@@ -559,13 +598,6 @@ def convert_unet(pipe, args):

    # If original Unet does not exist, export it from PyTorch+diffusers
    elif not os.path.exists(out_path):
-       # Register the selected attention implementation globally
-       unet.ATTENTION_IMPLEMENTATION_IN_EFFECT = unet.AttentionImplementations[
-           args.attention_implementation]
-       logger.info(
-           f"Attention implementation in effect: {unet.ATTENTION_IMPLEMENTATION_IN_EFFECT}"
-       )
-
        # Prepare sample input shapes and values
        batch_size = 2  # for classifier-free guidance
        sample_shape = (

@@ -598,16 +630,6 @@ def convert_unet(pipe, args):

                (batch_size)).to(torch.float32)),
            ("encoder_hidden_states", torch.rand(*encoder_hidden_states_shape))
        ])
-       sample_unet_inputs_spec = {
-           k: (v.shape, v.dtype)
-           for k, v in sample_unet_inputs.items()
-       }
-       logger.info(f"Sample inputs spec: {sample_unet_inputs_spec}")
-
-       # Initialize reference unet
-       reference_unet = unet.UNet2DConditionModel(**pipe.unet.config).eval()
-       load_state_dict_summary = reference_unet.load_state_dict(
-           pipe.unet.state_dict())

        # Prepare inputs
        baseline_sample_unet_inputs = deepcopy(sample_unet_inputs)

@@ -615,6 +637,52 @@ def convert_unet(pipe, args):

            "encoder_hidden_states"] = baseline_sample_unet_inputs[
                "encoder_hidden_states"].squeeze(2).transpose(1, 2)

        # Initialize reference unet
        reference_unet = unet.UNet2DConditionModel(**pipe.unet.config).eval()
        load_state_dict_summary = reference_unet.load_state_dict(
            pipe.unet.state_dict())

        if args.unet_support_controlnet:
            from .unet import calculate_conv2d_output_shape
            additional_residuals_shapes = []
            in_size = pipe.unet.config.sample_size

            # conv_in
            out_h, out_w = calculate_conv2d_output_shape(in_size, in_size, reference_unet.conv_in)
            additional_residuals_shapes.append(
                (batch_size, reference_unet.conv_in.out_channels, out_h, out_w))

            # down_blocks
            for down_block in reference_unet.down_blocks:
                additional_residuals_shapes += [
                    (batch_size, resnet.out_channels, out_h, out_w) for resnet in down_block.resnets
                ]
                if hasattr(down_block, "downsamplers") and down_block.downsamplers is not None:
                    for downsampler in down_block.downsamplers:
                        out_h, out_w = calculate_conv2d_output_shape(out_h, out_w, downsampler.conv)
                    additional_residuals_shapes.append(
                        (batch_size, down_block.downsamplers[-1].conv.out_channels, out_h, out_w))

            # mid_block
            additional_residuals_shapes.append(
                (batch_size, reference_unet.mid_block.resnets[-1].out_channels, out_h, out_w)
            )

            baseline_sample_unet_inputs["down_block_additional_residuals"] = ()
            for i, shape in enumerate(additional_residuals_shapes):
                sample_residual_input = torch.rand(*shape)
                sample_unet_inputs[f"additional_residual_{i}"] = sample_residual_input
                if i == len(additional_residuals_shapes) - 1:
                    baseline_sample_unet_inputs["mid_block_additional_residual"] = sample_residual_input
                else:
                    baseline_sample_unet_inputs["down_block_additional_residuals"] += (sample_residual_input, )

        sample_unet_inputs_spec = {
            k: (v.shape, v.dtype)
            for k, v in sample_unet_inputs.items()
        }
        logger.info(f"Sample UNet inputs spec: {sample_unet_inputs_spec}")

        # JIT trace
        logger.info("JIT tracing..")
        reference_unet = torch.jit.trace(reference_unet,

@@ -624,7 +692,7 @@ def convert_unet(pipe, args):

        if args.check_output_correctness:
            baseline_out = pipe.unet(**baseline_sample_unet_inputs,
                                     return_dict=False)[0].numpy()
-           reference_out = reference_unet(**sample_unet_inputs)[0].numpy()
+           reference_out = reference_unet(*sample_unet_inputs.values())[0].numpy()
            report_correctness(baseline_out, reference_out,
                               "unet baseline to reference PyTorch")

@@ -636,7 +704,7 @@ def convert_unet(pipe, args):

            for k, v in sample_unet_inputs.items()
        }

-       coreml_unet, out_path = _convert_to_coreml("unet", reference_unet,
+       coreml_unet, out_path = _convert_to_coreml(unet_name, reference_unet,
                                                   coreml_sample_unet_inputs,
                                                   ["noise_pred"], args)
        del reference_unet

@@ -872,6 +940,174 @@ def convert_safety_checker(pipe, args):

    del traced_safety_checker, coreml_safety_checker, pipe.safety_checker
    gc.collect()


def convert_controlnet(pipe, args):
    """ Converts each ControlNet for Stable Diffusion
    """
    if not hasattr(pipe, "unet"):
        raise RuntimeError(
            "convert_unet() deletes pipe.unet to save RAM. "
            "Please use convert_controlnet() before convert_unet()")

    if not hasattr(pipe, "text_encoder"):
        raise RuntimeError(
            "convert_text_encoder() deletes pipe.text_encoder to save RAM. "
            "Please use convert_controlnet() before convert_text_encoder()")

    if args.model_version != "runwayml/stable-diffusion-v1-5":
        logger.warning(
            "The original ControlNet models were trained using Stable Diffusion v1.5. "
            "It is possible that the converted model may not be compatible with controlnet.")

    for i, controlnet_model_version in enumerate(args.convert_controlnet):
        controlnet_model_name = controlnet_model_version.replace("/", "_")
        fname = f"ControlNet_{controlnet_model_name}.mlpackage"
        out_path = os.path.join(args.o, fname)

        if os.path.exists(out_path):
            logger.info(
                f"`controlnet_{controlnet_model_name}` already exists at {out_path}, skipping conversion."
            )
            continue

        if i == 0:
            batch_size = 2  # for classifier-free guidance
            sample_shape = (
                batch_size,                    # B
                pipe.unet.config.in_channels,  # C
                pipe.unet.config.sample_size,  # H
                pipe.unet.config.sample_size,  # W
            )

            encoder_hidden_states_shape = (
                batch_size,
                pipe.text_encoder.config.hidden_size,
                1,
                pipe.text_encoder.config.max_position_embeddings,
            )

            controlnet_cond_shape = (
                batch_size,  # B
                3,           # C
                (args.latent_h or pipe.unet.config.sample_size) * 8,  # H
                (args.latent_w or pipe.unet.config.sample_size) * 8,  # W
            )

            # Create the scheduled timesteps for downstream use
            DEFAULT_NUM_INFERENCE_STEPS = 50
            pipe.scheduler.set_timesteps(DEFAULT_NUM_INFERENCE_STEPS)

            # Prepare inputs
            sample_controlnet_inputs = OrderedDict([
                ("sample", torch.rand(*sample_shape)),
                ("timestep",
                 torch.tensor([pipe.scheduler.timesteps[0].item()] *
                              (batch_size)).to(torch.float32)),
                ("encoder_hidden_states", torch.rand(*encoder_hidden_states_shape)),
                ("controlnet_cond", torch.rand(*controlnet_cond_shape)),
            ])
            sample_controlnet_inputs_spec = {
                k: (v.shape, v.dtype)
                for k, v in sample_controlnet_inputs.items()
            }
            logger.info(
                f"Sample ControlNet inputs spec: {sample_controlnet_inputs_spec}")

            baseline_sample_controlnet_inputs = deepcopy(sample_controlnet_inputs)
            baseline_sample_controlnet_inputs[
                "encoder_hidden_states"] = baseline_sample_controlnet_inputs[
                    "encoder_hidden_states"].squeeze(2).transpose(1, 2)

        # Import controlnet model and initialize reference controlnet
        original_controlnet = ControlNetModel.from_pretrained(
            controlnet_model_version,
            use_auth_token=True
        )
        reference_controlnet = controlnet.ControlNetModel(**original_controlnet.config).eval()
        load_state_dict_summary = reference_controlnet.load_state_dict(
            original_controlnet.state_dict())

        num_residuals = reference_controlnet.get_num_residuals()
        output_keys = [f"additional_residual_{i}" for i in range(num_residuals)]

        # JIT trace
        logger.info("JIT tracing..")
        reference_controlnet = torch.jit.trace(reference_controlnet,
                                               list(sample_controlnet_inputs.values()))
        logger.info("Done.")

        if args.check_output_correctness:
            baseline_out = original_controlnet(**baseline_sample_controlnet_inputs,
                                               return_dict=False)
            reference_out = reference_controlnet(*sample_controlnet_inputs.values())

            baseline_down_residuals, baseline_mid_residuals = baseline_out
            baseline_out = baseline_down_residuals + (baseline_mid_residuals,)
            reference_down_residuals, reference_mid_residuals = reference_out
            reference_out = reference_down_residuals + (reference_mid_residuals,)

            for key, b_out, r_out in zip(output_keys, baseline_out, reference_out):
                b_out = b_out.numpy()
                r_out = r_out.numpy()
                logger.info(f"Check {key} correctness")
                report_correctness(b_out, r_out,
                                   f"controlnet({controlnet_model_name}) baseline to reference PyTorch")

        del original_controlnet
        gc.collect()

        coreml_sample_controlnet_inputs = {
            k: v.numpy().astype(np.float16)
            for k, v in sample_controlnet_inputs.items()
        }

        coreml_controlnet, out_path = _convert_to_coreml(f"controlnet_{controlnet_model_name}", reference_controlnet,
                                                         coreml_sample_controlnet_inputs,
                                                         output_keys, args,
                                                         out_path=out_path)

        del reference_controlnet
        gc.collect()

        coreml_controlnet.author = f"Please refer to the Model Card available at huggingface.co/{controlnet_model_version}"
        coreml_controlnet.license = "OpenRAIL (https://huggingface.co/spaces/CompVis/stable-diffusion-license)"
        coreml_controlnet.version = controlnet_model_version
        coreml_controlnet.short_description = \
            "ControlNet is a neural network structure to control diffusion models by adding extra conditions. " \
            "Please refer to https://arxiv.org/abs/2302.05543 for details."

        # Set the input descriptions
        coreml_controlnet.input_description["sample"] = \
            "The low resolution latent feature maps being denoised through reverse diffusion"
        coreml_controlnet.input_description["timestep"] = \
            "A value emitted by the associated scheduler object to condition the model on a given noise schedule"
        coreml_controlnet.input_description["encoder_hidden_states"] = \
            "Output embeddings from the associated text_encoder model to condition the generated image on text. " \
            "A maximum of 77 tokens (~40 words) are allowed. Longer text is truncated. " \
            "Shorter text does not reduce computation."
        coreml_controlnet.input_description["controlnet_cond"] = \
            "An additional input image for ControlNet to condition the generated images."

        # Set the output descriptions
        for i in range(num_residuals):
            coreml_controlnet.output_description[f"additional_residual_{i}"] = \
                "One of the outputs of each downsampling block in ControlNet. " \
                "The value is added to the corresponding resnet output in UNet."

        _save_mlpackage(coreml_controlnet, out_path)
        logger.info(f"Saved controlnet into {out_path}")

        # Parity check PyTorch vs CoreML
        if args.check_output_correctness:
            coreml_out = coreml_controlnet.predict(coreml_sample_controlnet_inputs)
            for key, b_out in zip(output_keys, baseline_out):
                b_out = b_out.numpy()
                logger.info(f"Check {key} correctness")
                report_correctness(b_out, coreml_out[key],
                                   "controlnet baseline PyTorch to reference CoreML")

        del coreml_controlnet
        gc.collect()


def main(args):
    os.makedirs(args.o, exist_ok=True)

@@ -883,6 +1119,13 @@ def main(args):

        use_auth_token=True)
    logger.info("Done.")

    # Register the selected attention implementation globally
    unet.ATTENTION_IMPLEMENTATION_IN_EFFECT = unet.AttentionImplementations[
        args.attention_implementation]
    logger.info(
        f"Attention implementation in effect: {unet.ATTENTION_IMPLEMENTATION_IN_EFFECT}"
    )

    # Convert models
    if args.convert_vae_decoder:
        logger.info("Converting vae_decoder")

@@ -894,6 +1137,11 @@ def main(args):

        convert_vae_encoder(pipe, args)
        logger.info("Converted vae_encoder")

    if args.convert_controlnet:
        logger.info("Converting controlnet")
        convert_controlnet(pipe, args)
        logger.info("Converted controlnet")

    if args.convert_unet:
        logger.info("Converting unet")
        convert_unet(pipe, args)

@@ -930,6 +1178,14 @@ def parser_spec():

    parser.add_argument("--convert-vae-encoder", action="store_true")
    parser.add_argument("--convert-unet", action="store_true")
    parser.add_argument("--convert-safety-checker", action="store_true")
    parser.add_argument(
        "--convert-controlnet",
        nargs="*",
        type=str,
        help=
        "Converts a ControlNet model hosted on Hugging Face to Core ML format. "
        "To convert multiple models, provide their names separated by spaces.",
    )
    parser.add_argument(
        "--model-version",
        default="CompVis/stable-diffusion-v1-4",

@@ -989,6 +1245,13 @@ def parser_spec():

        ("If specified, quantize 16-bit weights to 8-bit weights in-place for all models. "
         "Not recommended, as the generated image quality degrades significantly after 8-bit weight quantization"
         ))
    parser.add_argument(
        "--unet-support-controlnet",
        action="store_true",
        help=
        ("If specified, enables unet to receive additional inputs from controlnet. "
         "Each input is added to the corresponding resnet output."
         ))

    # Swift CLI Resource Bundling
    parser.add_argument(
@@ -937,6 +937,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):

        sample,
        timestep,
        encoder_hidden_states,
        *additional_residuals,
    ):
        # 0. Project (or look-up) time embeddings
        t_emb = self.time_proj(timestep)

@@ -965,10 +966,20 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):

            down_block_res_samples += res_samples

        if additional_residuals:
            new_down_block_res_samples = ()
            for i, down_block_res_sample in enumerate(down_block_res_samples):
                down_block_res_sample = down_block_res_sample + additional_residuals[i]
                new_down_block_res_samples += (down_block_res_sample,)
            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        sample = self.mid_block(sample,
                                emb,
                                encoder_hidden_states=encoder_hidden_states)

        if additional_residuals:
            sample = sample + additional_residuals[-1]

        # 5. up
        for upsample_block in self.up_blocks:

@@ -1081,3 +1092,13 @@ def get_up_block(

        attn_num_head_channels=attn_num_head_channels,
    )
    raise ValueError(f"{up_block_type} does not exist.")


def calculate_conv2d_output_shape(in_h, in_w, conv2d_layer):
    k_h, k_w = conv2d_layer.kernel_size
    pad_h, pad_w = conv2d_layer.padding
    stride_h, stride_w = conv2d_layer.stride

    out_h = math.floor((in_h + 2 * pad_h - k_h) / stride_h + 1)
    out_w = math.floor((in_w + 2 * pad_w - k_w) / stride_w + 1)

    return out_h, out_w
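`calculate_conv2d_output_shape` is the standard convolution output-size formula, out = floor((in + 2*pad - kernel) / stride + 1), specialized to dilation 1 (it assumes `math` is already imported in this module). A quick, hypothetical check against PyTorch:

```python
# Hypothetical sanity check for calculate_conv2d_output_shape; a 3x3,
# stride-2, padding-1 conv (the downsampler configuration used above)
# halves a 64x64 feature map to 32x32.
import torch
import torch.nn as nn
from python_coreml_stable_diffusion.unet import calculate_conv2d_output_shape

conv = nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)
out_h, out_w = calculate_conv2d_output_shape(64, 64, conv)
assert (out_h, out_w) == (32, 32)  # floor((64 + 2 - 3) / 2 + 1) = 32

# Agrees with an actual forward pass
y = conv(torch.rand(1, 320, 64, 64))
assert tuple(y.shape[-2:]) == (out_h, out_w)
```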
@@ -63,8 +63,8 @@ extension CGImage {

        return cgImage
    }

-   public var plannerRGBShapedArray: MLShapedArray<Float32> {
-       get throws {
+   public func plannerRGBShapedArray(minValue: Float, maxValue: Float)
+       throws -> MLShapedArray<Float32> {
            guard
                var sourceFormat = vImage_CGImageFormat(cgImage: self),
                var mediumFormat = vImage_CGImageFormat(

@@ -100,8 +100,8 @@ extension CGImage {

            var destinationG = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout<Float>.size))
            var destinationB = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout<Float>.size))

-           var minFloat: [Float] = [-1.0, -1.0, -1.0, -1.0]
-           var maxFloat: [Float] = [1.0, 1.0, 1.0, 1.0]
+           var minFloat: [Float] = Array(repeating: minValue, count: 4)
+           var maxFloat: [Float] = Array(repeating: maxValue, count: 4)

            vImageConvert_ARGB8888toPlanarF(&mediumDesination, &destinationA, &destinationR, &destinationG, &destinationB, &maxFloat, &minFloat, .zero)

@@ -114,7 +114,6 @@ extension CGImage {

            let shapedArray = MLShapedArray<Float32>(data: imageData, shape: [1, 3, self.width, self.height])

            return shapedArray
-       }
    }
}
@@ -0,0 +1,127 @@
// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.

import Foundation
import CoreML

@available(iOS 16.2, macOS 13.1, *)
public struct ControlNet: ResourceManaging {

    var models: [ManagedMLModel]

    public init(modelAt urls: [URL],
                configuration: MLModelConfiguration) {
        self.models = urls.map { ManagedMLModel(modelAt: $0, configuration: configuration) }
    }

    /// Load resources.
    public func loadResources() throws {
        for model in models {
            try model.loadResources()
        }
    }

    /// Unload the underlying model to free up memory
    public func unloadResources() {
        for model in models {
            model.unloadResources()
        }
    }

    /// Pre-warm resources
    public func prewarmResources() throws {
        // Override default to pre-warm each model
        for model in models {
            try model.loadResources()
            model.unloadResources()
        }
    }

    var inputImageDescriptions: [MLFeatureDescription] {
        models.map { model in
            try! model.perform {
                $0.modelDescription.inputDescriptionsByName["controlnet_cond"]!
            }
        }
    }

    /// The expected shape of each model's image input
    public var inputImageShapes: [[Int]] {
        inputImageDescriptions.map { desc in
            desc.multiArrayConstraint!.shape.map { $0.intValue }
        }
    }

    /// Calculate additional inputs for Unet to generate the intended image following the provided images
    ///
    /// - Parameters:
    ///   - latents: Batch of latent samples in an array
    ///   - timeStep: Current diffusion timestep
    ///   - hiddenStates: Hidden state to condition on
    ///   - images: Images for each ControlNet
    /// - Returns: Additional residuals for Unet, one dictionary per latent sample
    func execute(
        latents: [MLShapedArray<Float32>],
        timeStep: Int,
        hiddenStates: MLShapedArray<Float32>,
        images: [MLShapedArray<Float32>]
    ) throws -> [[String: MLShapedArray<Float32>]] {
        // Match time step batch dimension to the model / latent samples
        let t = MLShapedArray(scalars: [Float(timeStep), Float(timeStep)], shape: [2])

        var outputs: [[String: MLShapedArray<Float32>]] = []

        for (modelIndex, model) in models.enumerated() {
            let inputs = try latents.map { latent in
                let dict: [String: Any] = [
                    "sample": MLMultiArray(latent),
                    "timestep": MLMultiArray(t),
                    "encoder_hidden_states": MLMultiArray(hiddenStates),
                    "controlnet_cond": MLMultiArray(images[modelIndex])
                ]
                return try MLDictionaryFeatureProvider(dictionary: dict)
            }

            let batch = MLArrayBatchProvider(array: inputs)

            let results = try model.perform {
                try $0.predictions(fromBatch: batch)
            }

            // Pre-allocate zero-filled MLShapedArrays with the output shapes
            if outputs.isEmpty {
                outputs = initOutputs(
                    batch: latents.count,
                    shapes: results.features(at: 0).featureValueDictionary
                )
            }

            for n in 0..<results.count {
                let result = results.features(at: n)
                for k in result.featureNames {
                    let newValue = result.featureValue(for: k)!.shapedArrayValue(of: Float32.self)!
                    if modelIndex == 0 {
                        outputs[n][k] = newValue
                    } else {
                        // Multi-ControlNet: accumulate residuals across models
                        outputs[n][k]!.withUnsafeMutableShapedBufferPointer { pt, _, _ in
                            for (i, v) in newValue.scalars.enumerated() { pt[i] += v }
                        }
                    }
                }
            }
        }

        return outputs
    }

    private func initOutputs(batch: Int, shapes: [String: MLFeatureValue]) -> [[String: MLShapedArray<Float32>]] {
        var output: [String: MLShapedArray<Float32>] = [:]
        for (outputName, featureValue) in shapes {
            output[outputName] = MLShapedArray<Float32>(
                repeating: 0.0,
                shape: featureValue.multiArrayValue!.shape.map { $0.intValue }
            )
        }
        return Array(repeating: output, count: batch)
    }
}
@@ -50,7 +50,7 @@ public struct Encoder: ResourceManaging {

        scaleFactor: Float32,
        random: inout RandomSource
    ) throws -> MLShapedArray<Float32> {
-       let imageData = try image.plannerRGBShapedArray
+       let imageData = try image.plannerRGBShapedArray(minValue: -1.0, maxValue: 1.0)
        guard imageData.shape == inputShape else {
            // TODO: Consider auto resizing and cropping similar to how Vision or CoreML auto-generated Swift code can accomplish with `MLFeatureValue`
            throw Error.sampleInputShapeNotCorrect
@@ -18,6 +18,10 @@ public extension StableDiffusionPipeline {

        public let safetyCheckerURL: URL
        public let vocabURL: URL
        public let mergesURL: URL
        public let controlNetDirURL: URL
        public let controlledUnetURL: URL
        public let controlledUnetChunk1URL: URL
        public let controlledUnetChunk2URL: URL

        public init(resourcesAt baseURL: URL) {
            textEncoderURL = baseURL.appending(path: "TextEncoder.mlmodelc")

@@ -29,6 +33,10 @@ public extension StableDiffusionPipeline {

            safetyCheckerURL = baseURL.appending(path: "SafetyChecker.mlmodelc")
            vocabURL = baseURL.appending(path: "vocab.json")
            mergesURL = baseURL.appending(path: "merges.txt")
            controlNetDirURL = baseURL.appending(path: "controlnet")
            controlledUnetURL = baseURL.appending(path: "ControlledUnet.mlmodelc")
            controlledUnetChunk1URL = baseURL.appending(path: "ControlledUnetChunk1.mlmodelc")
            controlledUnetChunk2URL = baseURL.appending(path: "ControlledUnetChunk2.mlmodelc")
        }
    }

@@ -38,12 +46,14 @@ public extension StableDiffusionPipeline {

    /// - Parameters:
    ///   - baseURL: URL pointing to directory holding all model
    ///     and tokenization resources
    ///   - controlNetModelNames: Specify ControlNet models to use in generation
    ///   - configuration: The configuration to load model resources with
    ///   - disableSafety: Load time disable of safety to save memory
    ///   - reduceMemory: Setup pipeline in reduced memory mode
    /// - Returns:
    ///   Pipeline ready for image generation if all necessary resources loaded
    init(resourcesAt baseURL: URL,
         controlNet controlNetModelNames: [String],
         configuration config: MLModelConfiguration = .init(),
         disableSafety: Bool = false,
         reduceMemory: Bool = false) throws {

@@ -56,15 +66,38 @@ public extension StableDiffusionPipeline {

        let textEncoder = TextEncoder(tokenizer: tokenizer,
                                      modelAt: urls.textEncoderURL,
                                      configuration: config)

        // ControlNet model
        var controlNet: ControlNet? = nil
        let controlNetURLs = controlNetModelNames.map { model in
            let fileName = model + ".mlmodelc"
            return urls.controlNetDirURL.appending(path: fileName)
        }
        if !controlNetURLs.isEmpty {
            controlNet = ControlNet(modelAt: controlNetURLs, configuration: config)
        }

        // Unet model
        let unet: Unet
-       if FileManager.default.fileExists(atPath: urls.unetChunk1URL.path) &&
-           FileManager.default.fileExists(atPath: urls.unetChunk2URL.path) {
-           unet = Unet(chunksAt: [urls.unetChunk1URL, urls.unetChunk2URL],
+       let unetURL: URL, unetChunk1URL: URL, unetChunk2URL: URL
+
+       // If ControlNet is available, use the Unet that supports its additional inputs
+       if controlNet == nil {
+           unetURL = urls.unetURL
+           unetChunk1URL = urls.unetChunk1URL
+           unetChunk2URL = urls.unetChunk2URL
+       } else {
+           unetURL = urls.controlledUnetURL
+           unetChunk1URL = urls.controlledUnetChunk1URL
+           unetChunk2URL = urls.controlledUnetChunk2URL
+       }
+
+       if FileManager.default.fileExists(atPath: unetChunk1URL.path) &&
+           FileManager.default.fileExists(atPath: unetChunk2URL.path) {
+           unet = Unet(chunksAt: [unetChunk1URL, unetChunk2URL],
                        configuration: config)
        } else {
-           unet = Unet(modelAt: urls.unetURL, configuration: config)
+           unet = Unet(modelAt: unetURL, configuration: config)
        }

        // Image Decoder

@@ -90,6 +123,7 @@ public extension StableDiffusionPipeline {

                  unet: unet,
                  decoder: decoder,
                  encoder: encoder,
                  controlNet: controlNet,
                  safetyChecker: safetyChecker,
                  reduceMemory: reduceMemory)
    }
@@ -33,6 +33,8 @@ extension StableDiffusionPipeline {

        public var seed: UInt32 = 0
        /// Controls the influence of the text prompt on sampling process (0=random images)
        public var guidanceScale: Float = 7.5
        /// List of images, one per available ControlNet model
        public var controlNetInputs: [CGImage] = []
        /// Safety checks are only performed if `self.canSafetyCheck && !disableSafety`
        public var disableSafety: Bool = false
        /// The type of Scheduler to use.
@@ -47,6 +47,9 @@ public struct StableDiffusionPipeline: ResourceManaging {

    /// Optional model for checking safety of generated image
    var safetyChecker: SafetyChecker? = nil

    /// Optional model used before Unet to control generated images by additional inputs
    var controlNet: ControlNet? = nil

    /// Reports whether this pipeline can perform safety checks
    public var canSafetyCheck: Bool {

@@ -67,6 +70,7 @@ public struct StableDiffusionPipeline: ResourceManaging {

    ///   - textEncoder: Model for encoding tokenized text
    ///   - unet: Model for noise prediction on latent samples
    ///   - decoder: Model for decoding latent sample to image
    ///   - controlNet: Optional model to control generated images by additional inputs
    ///   - safetyChecker: Optional model for checking safety of generated images
    ///   - reduceMemory: Option to enable reduced memory mode
    /// - Returns: Pipeline ready for image generation

@@ -74,12 +78,14 @@ public struct StableDiffusionPipeline: ResourceManaging {

                unet: Unet,
                decoder: Decoder,
                encoder: Encoder?,
                controlNet: ControlNet? = nil,
                safetyChecker: SafetyChecker? = nil,
                reduceMemory: Bool = false) {
        self.textEncoder = textEncoder
        self.unet = unet
        self.decoder = decoder
        self.encoder = encoder
        self.controlNet = controlNet
        self.safetyChecker = safetyChecker
        self.reduceMemory = reduceMemory
    }

@@ -95,6 +101,7 @@ public struct StableDiffusionPipeline: ResourceManaging {

        try textEncoder.loadResources()
        try unet.loadResources()
        try decoder.loadResources()
        try controlNet?.loadResources()
        try safetyChecker?.loadResources()
    }
}

@@ -104,6 +111,7 @@ public struct StableDiffusionPipeline: ResourceManaging {

        textEncoder.unloadResources()
        unet.unloadResources()
        decoder.unloadResources()
        controlNet?.unloadResources()
        safetyChecker?.unloadResources()
    }

@@ -112,6 +120,7 @@ public struct StableDiffusionPipeline: ResourceManaging {

        try textEncoder.prewarmResources()
        try unet.prewarmResources()
        try decoder.prewarmResources()
        try controlNet?.prewarmResources()
        try safetyChecker?.prewarmResources()
    }

@@ -154,6 +163,15 @@ public struct StableDiffusionPipeline: ResourceManaging {

        // Generate random latent samples from specified seed
        var latents: [MLShapedArray<Float32>] = try generateLatentSamples(configuration: config, scheduler: scheduler[0])
        let timestepStrength: Float? = config.mode == .imageToImage ? config.strength : nil

        // Convert each ControlNet input cgImage into an MLShapedArray
        let controlNetConds = try config.controlNetInputs.map { cgImage in
            let shapedArray = try cgImage.plannerRGBShapedArray(minValue: 0.0, maxValue: 1.0)
            return MLShapedArray(
                concatenating: [shapedArray, shapedArray],
                alongAxis: 0
            )
        }

        // De-noising loop
        let timeSteps: [Int] = scheduler[0].calculateTimesteps(strength: timestepStrength)

@@ -164,13 +182,22 @@ public struct StableDiffusionPipeline: ResourceManaging {

            let latentUnetInput = latents.map {
                MLShapedArray<Float32>(concatenating: [$0, $0], alongAxis: 0)
            }

            // Before Unet, execute ControlNet and add its outputs to the Unet inputs
            let additionalResiduals = try controlNet?.execute(
                latents: latentUnetInput,
                timeStep: t,
                hiddenStates: hiddenStates,
                images: controlNetConds
            )

            // Predict noise residuals from latent samples
            // and current time step conditioned on hidden states
            var noise = try unet.predictNoise(
                latents: latentUnetInput,
                timeStep: t,
-               hiddenStates: hiddenStates
+               hiddenStates: hiddenStates,
+               additionalResiduals: additionalResiduals
            )

            noise = performGuidance(noise, config.guidanceScale)

@@ -201,6 +228,7 @@ public struct StableDiffusionPipeline: ResourceManaging {

        }

        if reduceMemory {
            controlNet?.unloadResources()
            unet.unloadResources()
        }
@@ -79,19 +79,25 @@ public struct Unet: ResourceManaging {

    func predictNoise(
        latents: [MLShapedArray<Float32>],
        timeStep: Int,
-       hiddenStates: MLShapedArray<Float32>
+       hiddenStates: MLShapedArray<Float32>,
+       additionalResiduals: [[String: MLShapedArray<Float32>]]? = nil
    ) throws -> [MLShapedArray<Float32>] {

        // Match time step batch dimension to the model / latent samples
        let t = MLShapedArray<Float32>(scalars: [Float(timeStep), Float(timeStep)], shape: [2])

        // Form batch input to model
-       let inputs = try latents.map {
-           let dict: [String: Any] = [
-               "sample" : MLMultiArray($0),
+       let inputs = try latents.enumerated().map {
+           var dict: [String: Any] = [
+               "sample" : MLMultiArray($0.element),
                "timestep" : MLMultiArray(t),
                "encoder_hidden_states": MLMultiArray(hiddenStates)
            ]
+           // Merge any ControlNet residuals for this latent into the inputs
+           if let residuals = additionalResiduals?[$0.offset] {
+               for (k, v) in residuals {
+                   dict[k] = MLMultiArray(v)
+               }
+           }
            return try MLDictionaryFeatureProvider(dictionary: dict)
        }
        let batch = MLArrayBatchProvider(array: inputs)
@@ -8,6 +8,7 @@ import Foundation

import StableDiffusion
import UniformTypeIdentifiers
import Cocoa
import CoreImage

@available(iOS 16.2, macOS 13.1, *)
struct StableDiffusionSample: ParsableCommand {

@@ -71,6 +72,18 @@ struct StableDiffusionSample: ParsableCommand {

    @Option(help: "Random number generator to use, one of {numpy, torch}")
    var rng: RNGOption = .numpy

    @Option(
        parsing: .upToNextOption,
        help: "ControlNet models used in image generation (enter file names in Resources/controlnet without extension)"
    )
    var controlnet: [String] = []

    @Option(
        parsing: .upToNextOption,
        help: "Image for each ControlNet model (in the same order as --controlnet)"
    )
    var controlnetInputs: [String] = []

    @Flag(help: "Disable safety checking")
    var disableSafety: Bool = false

@@ -90,6 +103,7 @@ struct StableDiffusionSample: ParsableCommand {

        log("Loading resources and creating pipeline\n")
        log("(Note: This can take a while the first time using these resources)\n")
        let pipeline = try StableDiffusionPipeline(resourcesAt: resourceURL,
                                                   controlNet: controlnet,
                                                   configuration: config,
                                                   disableSafety: disableSafety,
                                                   reduceMemory: reduceMemory)

@@ -99,14 +113,7 @@ struct StableDiffusionSample: ParsableCommand {

        if let image {
            let imageURL = URL(filePath: image)
            do {
-               let imageData = try Data(contentsOf: imageURL)
-               guard
-                   let nsImage = NSImage(data: imageData),
-                   let loadedImage = nsImage.cgImage(forProposedRect: nil, context: nil, hints: nil)
-               else {
-                   throw RunError.resources("Starting Image not available \(resourcePath)")
-               }
-               startingImage = loadedImage
+               startingImage = try convertImageToCGImage(imageURL: imageURL)
            } catch let error {
                throw RunError.resources("Starting image not found \(imageURL), error: \(error)")
            }

@@ -114,6 +121,21 @@ struct StableDiffusionSample: ParsableCommand {

        } else {
            startingImage = nil
        }

        // Convert each ControlNet input image into a CGImage when ControlNet is available
        let controlNetInputs: [CGImage]
        if !controlnet.isEmpty {
            controlNetInputs = try controlnetInputs.map { imagePath in
                let imageURL = URL(filePath: imagePath)
                do {
                    return try convertImageToCGImage(imageURL: imageURL)
                } catch let error {
                    throw RunError.resources("Image for ControlNet not found \(imageURL), error: \(error)")
                }
            }
        } else {
            controlNetInputs = []
        }

        log("Sampling ...\n")
        let sampleTimer = SampleTimer()

@@ -127,6 +149,7 @@ struct StableDiffusionSample: ParsableCommand {

        pipelineConfig.imageCount = imageCount
        pipelineConfig.stepCount = stepCount
        pipelineConfig.seed = seed
        pipelineConfig.controlNetInputs = controlNetInputs
        pipelineConfig.guidanceScale = guidanceScale
        pipelineConfig.schedulerType = scheduler.stableDiffusionScheduler
        pipelineConfig.rngType = rng.stableDiffusionRNG

@@ -144,6 +167,17 @@ struct StableDiffusionSample: ParsableCommand {

        _ = try saveImages(images, logNames: true)
    }

    func convertImageToCGImage(imageURL: URL) throws -> CGImage {
        let imageData = try Data(contentsOf: imageURL)
        guard
            let nsImage = NSImage(data: imageData),
            let loadedImage = nsImage.cgImage(forProposedRect: nil, context: nil, hints: nil)
        else {
            throw RunError.resources("Image not available \(resourcePath)")
        }
        return loadedImage
    }

    func handleProgress(
        _ progress: StableDiffusionPipeline.Progress,