Support ControlNet (#153)

* add controlnet tentatively

* add controlnet in python code

* implement swift part

* support 8-bit quantization

* add controlnet unload when reducing memory

* remove irrelevant changes

* add more description about controlnet option in swift

* fix some issues for PR and update README

* pre-allocate zero shapedArray + make multi-controlnet faster
Chimme 2023-04-19 08:17:59 +09:00 committed by GitHub
parent d1a6888d43
commit 7f65e1c84b
14 changed files with 966 additions and 66 deletions

View File

@ -139,6 +139,10 @@ This generally takes 15-20 minutes on an M1 MacBook Pro. Upon successful executi
- `--check-output-correctness`: Compares original PyTorch model's outputs to final Core ML model's outputs. This flag increases RAM consumption significantly so it is recommended only for debugging purposes.
- `--convert-controlnet`: Converts ControlNet models specified after this option. Multiple models can be converted at once by listing them, e.g. `--convert-controlnet lllyasviel/sd-controlnet-mlsd lllyasviel/sd-controlnet-depth`.
- `--unet-support-controlnet`: Enables the converted UNet model to receive additional inputs from ControlNet. This is required for generating images with ControlNet, and the model is saved under a different name, `*_control-unet.mlpackage`, distinct from the normal UNet. Note that this UNet model cannot work without ControlNet, so please use the normal UNet for plain txt2img.
</details>
## <a name="image-generation-with-python"></a> Image Generation with Python
@ -157,6 +161,8 @@ Please refer to the help menu for all available arguments: `python -m python_cor
- `--model-version`: If you overrode the default model version while converting models to Core ML, you will need to specify the same model version here.
- `--compute-unit`: Note that the most performant compute unit for this particular implementation may differ across different hardware. `CPU_AND_GPU` or `CPU_AND_NE` may be faster than `ALL`. Please refer to the [Performance Benchmark](#performance-benchmark) section for further guidance.
- `--scheduler`: If you would like to experiment with different schedulers, you may specify it here. For available options, please see the help menu. You may also specify a custom number of inference steps by `--num-inference-steps` which defaults to 50.
- `--controlnet`: ControlNet models specified with this option are used in image generation. Use this option in the format `--controlnet lllyasviel/sd-controlnet-mlsd lllyasviel/sd-controlnet-depth` and make sure to use `--controlnet-inputs` in conjunction with it.
- `--controlnet-inputs`: Image inputs corresponding to each ControlNet model. Please provide image paths in the same order as the models in `--controlnet`, for example: `--controlnet-inputs image_mlsd image_depth`.
</details>
@ -228,6 +234,14 @@ Optionally, it may also include the safety checker model that some versions of S
- `SafetyChecker.mlmodelc`
Optionally, for ControlNet:
- `ControlledUnet.mlmodelc` or `ControlledUnetChunk1.mlmodelc` & `ControlledUnetChunk2.mlmodelc` (UNet variants that accept the additional ControlNet residual inputs)
- `controlnet/` (directory containing ControlNet models)
- `LllyasvielSdControlnetMlsd.mlmodelc` (for example, from lllyasviel/sd-controlnet-mlsd)
- `LllyasvielSdControlnetDepth.mlmodelc` (for example, from lllyasviel/sd-controlnet-depth)
- Other models you converted
Note that the chunked version of Unet is checked for first. Only if it is not present will the full `Unet.mlmodelc` be loaded. Chunking is required for iOS and iPadOS and not necessary for macOS.
</details>
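
The compiled ControlNet file names above are derived from the Hugging Face model identifier by the bundling step in `torch2coreml.py` (shown later in this diff). A minimal sketch of that mapping, for illustration only (the helper name below is hypothetical):

```python
import re

def compiled_controlnet_name(model_version: str) -> str:
    """Mirror the target-name logic in bundle_resources_for_swift_cli()."""
    model_name = model_version.replace("/", "_")
    return "".join(word.title() for word in re.split("_|-", model_name))

# "lllyasviel/sd-controlnet-mlsd" -> "LllyasvielSdControlnetMlsd" (saved as .mlmodelc)
print(compiled_controlnet_name("lllyasviel/sd-controlnet-mlsd"))
```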

View File

@ -0,0 +1,244 @@
#
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
#
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers import ModelMixin
import torch
import torch.nn as nn
import torch.nn.functional as F
from .unet import Timesteps, TimestepEmbedding, get_down_block, UNetMidBlock2DCrossAttn, linear_to_conv2d_map
class ControlNetConditioningEmbedding(nn.Module):
def __init__(
self,
conditioning_embedding_channels,
conditioning_channels=3,
block_out_channels=(16, 32, 96, 256),
):
super().__init__()
self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
self.blocks = nn.ModuleList([])
for i in range(len(block_out_channels) - 1):
channel_in = block_out_channels[i]
channel_out = block_out_channels[i + 1]
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
self.conv_out = nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
def forward(self, conditioning):
embedding = self.conv_in(conditioning)
embedding = F.silu(embedding)
for block in self.blocks:
embedding = block(embedding)
embedding = F.silu(embedding)
embedding = self.conv_out(embedding)
return embedding
class ControlNetModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
in_channels=4,
flip_sin_to_cos=True,
freq_shift=0,
down_block_types=(
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
only_cross_attention=False,
block_out_channels=(320, 640, 1280, 1280),
layers_per_block=2,
downsample_padding=1,
mid_block_scale_factor=1,
act_fn="silu",
norm_num_groups=32,
norm_eps=1e-5,
cross_attention_dim=1280,
attention_head_dim=8,
use_linear_projection=False,
upcast_attention=False,
resnet_time_scale_shift="default",
conditioning_embedding_out_channels=(16, 32, 96, 256),
**kwargs,
):
super().__init__()
# Check inputs
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
)
self._register_load_state_dict_pre_hook(linear_to_conv2d_map)
# input
conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
)
# time
time_embed_dim = block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
)
# control net conditioning embedding
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
conditioning_embedding_channels=block_out_channels[0],
block_out_channels=conditioning_embedding_out_channels,
)
self.down_blocks = nn.ModuleList([])
self.controlnet_down_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
self.controlnet_down_blocks.append(controlnet_block)
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[i],
downsample_padding=downsample_padding,
)
self.down_blocks.append(down_block)
for _ in range(layers_per_block):
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
self.controlnet_down_blocks.append(controlnet_block)
if not is_final_block:
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
self.controlnet_down_blocks.append(controlnet_block)
# mid
mid_block_channel = block_out_channels[-1]
controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
self.controlnet_mid_block = controlnet_block
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=mid_block_channel,
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
def get_num_residuals(self):
num_res = 2 # initial sample + mid block
for down_block in self.down_blocks:
num_res += len(down_block.resnets)
if hasattr(down_block, "downsamplers") and down_block.downsamplers is not None:
num_res += len(down_block.downsamplers)
return num_res
def forward(
self,
sample,
timestep,
encoder_hidden_states,
controlnet_cond,
):
# 1. time
t_emb = self.time_proj(timestep)
emb = self.time_embedding(t_emb)
# 2. pre-process
sample = self.conv_in(sample)
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
sample += controlnet_cond
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
if self.mid_block is not None:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
)
# 5. Control net blocks
controlnet_down_block_res_samples = ()
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
down_block_res_sample = controlnet_block(down_block_res_sample)
controlnet_down_block_res_samples += (down_block_res_sample,)
down_block_res_samples = controlnet_down_block_res_samples
mid_block_res_sample = self.controlnet_mid_block(sample)
return down_block_res_samples, mid_block_res_sample
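
For a quick shape sanity check, the following is a minimal sketch (not part of this PR) that instantiates the reference model and runs one forward pass. The dimensions are assumptions chosen to match a Stable Diffusion v1.5 setup: 768-dimensional text embeddings in this repo's (B, C, 1, S) layout, 64x64 latents, and a 512x512 conditioning image.

```python
import torch
from python_coreml_stable_diffusion.controlnet import ControlNetModel

# Assumed SD v1.5-style dimensions, for illustration only
model = ControlNetModel(cross_attention_dim=768, attention_head_dim=8).eval()

sample = torch.rand(2, 4, 64, 64)                  # latents (batch of 2 for classifier-free guidance)
timestep = torch.tensor([500.0, 500.0])            # one timestep per batch element
encoder_hidden_states = torch.rand(2, 768, 1, 77)  # (B, C, 1, S) layout used by this repo's attention
controlnet_cond = torch.rand(2, 3, 512, 512)       # conditioning image at 8x the latent resolution

with torch.no_grad():
    down_residuals, mid_residual = model(sample, timestep, encoder_hidden_states, controlnet_cond)

print(model.get_num_residuals())                   # expected: 13 for this default block layout
print(len(down_residuals), tuple(mid_residual.shape))
```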

View File

@ -98,5 +98,22 @@ def _load_mlpackage(submodule_name, mlpackages_dir, model_version,
return CoreMLModel(mlpackage_path, compute_unit)
def _load_mlpackage_controlnet(mlpackages_dir, model_version, compute_unit):
""" Load Core ML (mlpackage) models from disk (As exported by torch2coreml.py)
"""
model_name = model_version.replace("/", "_")
logger.info(f"Loading controlnet_{model_name} mlpackage")
fname = f"ControlNet_{model_name}.mlpackage"
mlpackage_path = os.path.join(mlpackages_dir, fname)
if not os.path.exists(mlpackage_path):
raise FileNotFoundError(
f"controlnet_{model_name} CoreML model doesn't exist at {mlpackage_path}")
return CoreMLModel(mlpackage_path, compute_unit)
def get_available_compute_units():
return tuple(cu for cu in ct.ComputeUnit._member_names_)
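
As a usage note, the loader resolves the package name by replacing `/` in the Hugging Face identifier with `_`. A hedged sketch, assuming the packages were already produced by `torch2coreml.py` into an `mlpackages/` directory (directory name and compute unit are illustrative):

```python
from python_coreml_stable_diffusion.coreml_model import _load_mlpackage_controlnet

# "lllyasviel/sd-controlnet-mlsd" is looked up as
# mlpackages/ControlNet_lllyasviel_sd-controlnet-mlsd.mlpackage
controlnet = _load_mlpackage_controlnet(
    mlpackages_dir="mlpackages",                    # assumed conversion output directory
    model_version="lllyasviel/sd-controlnet-mlsd",
    compute_unit="CPU_AND_GPU",                     # any name from get_available_compute_units()
)
```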

View File

@ -32,6 +32,7 @@ import os
from python_coreml_stable_diffusion.coreml_model import (
CoreMLModel,
_load_mlpackage,
_load_mlpackage_controlnet,
get_available_compute_units,
)
@ -39,6 +40,7 @@ import time
import torch # Only used for `torch.from_numpy` in `pipe.scheduler.step()`
from transformers import CLIPFeatureExtractor, CLIPTokenizer
from typing import List, Optional, Union
from PIL import Image
class CoreMLStableDiffusionPipeline(DiffusionPipeline):
@ -60,6 +62,7 @@ class CoreMLStableDiffusionPipeline(DiffusionPipeline):
LMSDiscreteScheduler,
PNDMScheduler],
tokenizer: CLIPTokenizer,
controlnet: Optional[List[CoreMLModel]],
):
super().__init__()
@ -88,6 +91,8 @@ class CoreMLStableDiffusionPipeline(DiffusionPipeline):
self.unet = unet
self.unet.in_channels = self.unet.expected_inputs["sample"]["shape"][1]
self.controlnet = controlnet
self.vae_decoder = vae_decoder
VAE_DECODER_UPSAMPLE_FACTOR = 8
@ -168,6 +173,33 @@ class CoreMLStableDiffusionPipeline(DiffusionPipeline):
return text_embeddings
def run_controlnet(self,
sample,
timestep,
encoder_hidden_states,
controlnet_cond,
output_dtype=np.float16):
if not self.controlnet:
raise ValueError(
"Conditions for controlnet are given but the pipeline has no controlnet modules")
for i, (module, cond) in enumerate(zip(self.controlnet, controlnet_cond)):
module_outputs = module(
sample=sample.astype(np.float16),
timestep=timestep.astype(np.float16),
encoder_hidden_states=encoder_hidden_states.astype(np.float16),
controlnet_cond=cond.astype(np.float16),
)
if i == 0:
outputs = module_outputs
else:
for key in outputs.keys():
outputs[key] += module_outputs[key]
outputs = {k: v.astype(output_dtype) for k, v in outputs.items()}
return outputs
def run_safety_checker(self, image):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(
@ -222,6 +254,19 @@ class CoreMLStableDiffusionPipeline(DiffusionPipeline):
return latents
def prepare_control_cond(self,
controlnet_cond,
do_classifier_free_guidance,
batch_size,
num_images_per_prompt):
processed_cond_list = []
for cond in controlnet_cond:
cond = np.stack([cond] * batch_size * num_images_per_prompt)
if do_classifier_free_guidance:
cond = np.concatenate([cond] * 2)
processed_cond_list.append(cond)
return processed_cond_list
def check_inputs(self, prompt, height, width, callback_steps):
if height != self.height or width != self.width:
logger.warning(
@ -276,6 +321,7 @@ class CoreMLStableDiffusionPipeline(DiffusionPipeline):
return_dict=True,
callback=None,
callback_steps=1,
controlnet_cond=None,
**kwargs,
):
# 1. Check inputs. Raise error if not correct
@ -305,7 +351,7 @@ class CoreMLStableDiffusionPipeline(DiffusionPipeline):
self.scheduler.set_timesteps(num_inference_steps)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
# 5. Prepare latent variables and controlnet cond
num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
@ -315,6 +361,14 @@ class CoreMLStableDiffusionPipeline(DiffusionPipeline):
latents,
)
if controlnet_cond:
controlnet_cond = self.prepare_control_cond(
controlnet_cond,
do_classifier_free_guidance,
batch_size,
num_images_per_prompt,
)
# 6. Prepare extra step kwargs
extra_step_kwargs = self.prepare_extra_step_kwargs(eta)
@ -326,11 +380,23 @@ class CoreMLStableDiffusionPipeline(DiffusionPipeline):
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t)
# controlnet
if controlnet_cond:
additional_residuals = self.run_controlnet(
sample=latent_model_input,
timestep=np.array([t, t]),
encoder_hidden_states=text_embeddings,
controlnet_cond=controlnet_cond,
)
else:
additional_residuals = {}
# predict the noise residual
noise_pred = self.unet(
sample=latent_model_input.astype(np.float16),
timestep=np.array([t, t], np.float16),
encoder_hidden_states=text_embeddings.astype(np.float16),
**additional_residuals,
)["noise_pred"]
# perform guidance
@ -385,7 +451,8 @@ def get_coreml_pipe(pytorch_pipe,
model_version,
compute_unit,
delete_original_pipe=True,
scheduler_override=None):
scheduler_override=None,
controlnet_models=None):
""" Initializes and returns a `CoreMLStableDiffusionPipeline` from an original
diffusers PyTorch pipeline
"""
@ -417,6 +484,22 @@ def get_coreml_pipe(pytorch_pipe,
gc.collect()
logger.info("Removed PyTorch pipe to reduce peak memory consumption")
if controlnet_models:
model_names_to_load.remove("unet")
coreml_pipe_kwargs["unet"] = _load_mlpackage(
"control-unet",
mlpackages_dir,
model_version,
compute_unit,
)
coreml_pipe_kwargs["controlnet"] = [_load_mlpackage_controlnet(
mlpackages_dir,
model_version,
compute_unit,
) for model_version in controlnet_models]
else:
coreml_pipe_kwargs["controlnet"] = None
# Load Core ML models
logger.info(f"Loading Core ML models in memory from {mlpackages_dir}")
coreml_pipe_kwargs.update({
@ -453,6 +536,11 @@ def get_image_path(args, **override_kwargs):
return os.path.join(out_folder, out_fname + ".png")
def prepare_controlnet_cond(image_path, height, width):
image = Image.open(image_path).convert("RGB")
image = image.resize((height, width), resample=Image.LANCZOS)
image = np.array(image).transpose(2, 0, 1) / 255.0
return image
def main(args):
logger.info(f"Setting random seed to {args.seed}")
@ -472,7 +560,17 @@ def main(args):
mlpackages_dir=args.i,
model_version=args.model_version,
compute_unit=args.compute_unit,
scheduler_override=user_specified_scheduler)
scheduler_override=user_specified_scheduler,
controlnet_models=args.controlnet)
if args.controlnet:
controlnet_cond = []
for i, _ in enumerate(args.controlnet):
image_path = args.controlnet_inputs[i]
image = prepare_controlnet_cond(image_path, coreml_pipe.height, coreml_pipe.width)
controlnet_cond.append(image)
else:
controlnet_cond = None
logger.info("Beginning image generation.")
image = coreml_pipe(
@ -480,7 +578,8 @@ def main(args):
height=coreml_pipe.height,
width=coreml_pipe.width,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale
guidance_scale=args.guidance_scale,
controlnet_cond=controlnet_cond,
)
out_path = get_image_path(args)
@ -535,6 +634,18 @@ if __name__ == "__main__":
default=7.5,
type=float,
help="Controls the influence of the text prompt on sampling process (0=random images)")
parser.add_argument(
"--controlnet",
nargs="*",
type=str,
help=("Enables ControlNet and use control-unet instead of unet for additional inputs. "
"For Multi-Controlnet, provide the model names separated by spaces."))
parser.add_argument(
"--controlnet-inputs",
nargs="*",
type=str,
help=("Image paths for ControlNet inputs. "
"Please enter images corresponding to each controlnet provided at --controlnet option in same order."))
args = parser.parse_args()
main(args)
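
Beyond the CLI entry point, the same flow can be driven programmatically. A hedged sketch mirroring `main()` above; the model names, directory, image path, and prompt are illustrative, and it assumes the models were already converted with `--convert-controlnet` and `--unet-support-controlnet` into `mlpackages/`:

```python
from diffusers import StableDiffusionPipeline
from python_coreml_stable_diffusion.pipeline import get_coreml_pipe, prepare_controlnet_cond

pytorch_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
coreml_pipe = get_coreml_pipe(
    pytorch_pipe=pytorch_pipe,
    mlpackages_dir="mlpackages",                       # assumed torch2coreml.py output directory
    model_version="runwayml/stable-diffusion-v1-5",
    compute_unit="CPU_AND_GPU",
    controlnet_models=["lllyasviel/sd-controlnet-mlsd"],
)

# One conditioning image per ControlNet model, resized to the pipeline resolution
cond = prepare_controlnet_cond("image_mlsd.png", coreml_pipe.height, coreml_pipe.width)

result = coreml_pipe(
    prompt="a photo of a living room, best quality",
    height=coreml_pipe.height,
    width=coreml_pipe.width,
    num_inference_steps=50,
    guidance_scale=7.5,
    controlnet_cond=[cond],
)
result["images"][0].save("output.png")
```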

View File

@ -3,13 +3,13 @@
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
#
from python_coreml_stable_diffusion import unet
from python_coreml_stable_diffusion import unet, controlnet
import argparse
from collections import OrderedDict, defaultdict
from copy import deepcopy
import coremltools as ct
from diffusers import StableDiffusionPipeline
from diffusers import StableDiffusionPipeline, ControlNetModel
import gc
import logging
@ -24,6 +24,8 @@ from python_coreml_stable_diffusion import chunk_mlprogram
import requests
import shutil
import time
import re
import pathlib
import torch
import torch.nn as nn
@ -84,7 +86,6 @@ def report_correctness(original_outputs, final_outputs, log_prefix):
)
return final_psnr
def _get_out_path(args, submodule_name):
fname = f"Stable_Diffusion_version_{args.model_version}_{submodule_name}.mlpackage"
fname = fname.replace("/", "_")
@ -100,8 +101,10 @@ def _save_mlpackage(model, output_path):
def _convert_to_coreml(submodule_name, torchscript_module, sample_inputs,
output_names, args):
out_path = _get_out_path(args, submodule_name)
output_names, args, out_path=None):
if out_path is None:
out_path = _get_out_path(args, submodule_name)
if os.path.exists(out_path):
logger.info(f"Skipping export because {out_path} already exists")
@ -140,21 +143,32 @@ def _convert_to_coreml(submodule_name, torchscript_module, sample_inputs,
def quantize_weights_to_8bits(args):
for model_name in [
"text_encoder", "vae_decoder", "vae_encoder", "unet", "unet_chunk1",
"unet_chunk2", "safety_checker"
"text_encoder", "vae_decoder", "vae_encoder", "unet", "unet_chunk1", "unet_chunk2",
"control-unet", "control-unet_chunk1", "control-unet_chunk2", "safety_checker"
]:
out_path = _get_out_path(args, model_name)
if os.path.exists(out_path):
logger.info(f"Quantizing {model_name}")
mlmodel = ct.models.MLModel(out_path,
compute_units=ct.ComputeUnit.CPU_ONLY)
mlmodel = ct.compression_utils.affine_quantize_weights(
mlmodel, mode="linear")
mlmodel.save(out_path)
logger.info("Done")
else:
logger.info(
f"Skipped quantizing {model_name} (Not found at {out_path})")
_quantize_and_save_8bits_model(out_path, model_name)
if args.convert_controlnet:
for controlnet_model_version in args.convert_controlnet:
controlnet_model_name = controlnet_model_version.replace("/", "_")
fname = f"ControlNet_{controlnet_model_name}.mlpackage"
out_path = os.path.join(args.o, fname)
_quantize_and_save_8bits_model(out_path, controlnet_model_name)
def _quantize_and_save_8bits_model(out_path, model_name):
if os.path.exists(out_path):
logger.info(f"Quantizing {model_name}")
mlmodel = ct.models.MLModel(out_path,
compute_units=ct.ComputeUnit.CPU_ONLY)
mlmodel = ct.compression_utils.affine_quantize_weights(
mlmodel, mode="linear")
mlmodel.save(out_path)
logger.info("Done")
else:
logger.info(
f"Skipped quantizing {model_name} (Not found at {out_path})")
def _compile_coreml_model(source_model_path, output_dir, final_name):
@ -194,6 +208,9 @@ def bundle_resources_for_swift_cli(args):
("unet", "Unet"),
("unet_chunk1", "UnetChunk1"),
("unet_chunk2", "UnetChunk2"),
("control-unet", "ControlledUnet"),
("control-unet_chunk1", "ControlledUnetChunk1"),
("control-unet_chunk2", "ControlledUnetChunk2"),
("safety_checker", "SafetyChecker")]:
source_path = _get_out_path(args, source_name)
if os.path.exists(source_path):
@ -204,6 +221,23 @@ def bundle_resources_for_swift_cli(args):
logger.warning(
f"{source_path} not found, skipping compilation to {target_name}.mlmodelc"
)
if args.convert_controlnet:
for controlnet_model_version in args.convert_controlnet:
controlnet_model_name = controlnet_model_version.replace("/", "_")
fname = f"ControlNet_{controlnet_model_name}.mlpackage"
source_path = os.path.join(args.o, fname)
controlnet_dir = os.path.join(resources_dir, "controlnet")
target_name = "".join([word.title() for word in re.split('_|-', controlnet_model_name)])
if os.path.exists(source_path):
target_path = _compile_coreml_model(source_path, controlnet_dir,
target_name)
logger.info(f"Compiled {source_path} to {target_path}")
else:
logger.warning(
f"{source_path} not found, skipping compilation to {target_name}.mlmodelc"
)
# Fetch and save vocabulary JSON file for text tokenizer
logger.info("Downloading and saving tokenizer vocab.json")
@ -543,7 +577,12 @@ def convert_vae_encoder(pipe, args):
def convert_unet(pipe, args):
""" Converts the UNet component of Stable Diffusion
"""
out_path = _get_out_path(args, "unet")
if args.unet_support_controlnet:
unet_name = "control-unet"
else:
unet_name = "unet"
out_path = _get_out_path(args, unet_name)
# Check if Unet was previously exported and then chunked
unet_chunks_exist = all(
@ -559,13 +598,6 @@ def convert_unet(pipe, args):
# If original Unet does not exist, export it from PyTorch+diffusers
elif not os.path.exists(out_path):
# Register the selected attention implementation globally
unet.ATTENTION_IMPLEMENTATION_IN_EFFECT = unet.AttentionImplementations[
args.attention_implementation]
logger.info(
f"Attention implementation in effect: {unet.ATTENTION_IMPLEMENTATION_IN_EFFECT}"
)
# Prepare sample input shapes and values
batch_size = 2 # for classifier-free guidance
sample_shape = (
@ -598,16 +630,6 @@ def convert_unet(pipe, args):
(batch_size)).to(torch.float32)),
("encoder_hidden_states", torch.rand(*encoder_hidden_states_shape))
])
sample_unet_inputs_spec = {
k: (v.shape, v.dtype)
for k, v in sample_unet_inputs.items()
}
logger.info(f"Sample inputs spec: {sample_unet_inputs_spec}")
# Initialize reference unet
reference_unet = unet.UNet2DConditionModel(**pipe.unet.config).eval()
load_state_dict_summary = reference_unet.load_state_dict(
pipe.unet.state_dict())
# Prepare inputs
baseline_sample_unet_inputs = deepcopy(sample_unet_inputs)
@ -615,6 +637,52 @@ def convert_unet(pipe, args):
"encoder_hidden_states"] = baseline_sample_unet_inputs[
"encoder_hidden_states"].squeeze(2).transpose(1, 2)
# Initialize reference unet
reference_unet = unet.UNet2DConditionModel(**pipe.unet.config).eval()
load_state_dict_summary = reference_unet.load_state_dict(
pipe.unet.state_dict())
if args.unet_support_controlnet:
from .unet import calculate_conv2d_output_shape
additional_residuals_shapes = []
in_size = pipe.unet.config.sample_size
# conv_in
out_h, out_w = calculate_conv2d_output_shape(in_size, in_size, reference_unet.conv_in)
additional_residuals_shapes.append(
(batch_size, reference_unet.conv_in.out_channels, out_h, out_w))
# down_blocks
for down_block in reference_unet.down_blocks:
additional_residuals_shapes += [
(batch_size, resnet.out_channels, out_h, out_w) for resnet in down_block.resnets
]
if hasattr(down_block, "downsamplers") and down_block.downsamplers is not None:
for downsampler in down_block.downsamplers:
out_h, out_w = calculate_conv2d_output_shape(out_h, out_w, downsampler.conv)
additional_residuals_shapes.append(
(batch_size, down_block.downsamplers[-1].conv.out_channels, out_h, out_w))
# mid_block
additional_residuals_shapes.append(
(batch_size, reference_unet.mid_block.resnets[-1].out_channels, out_h, out_w)
)
baseline_sample_unet_inputs["down_block_additional_residuals"] = ()
for i, shape in enumerate(additional_residuals_shapes):
sample_residual_input = torch.rand(*shape)
sample_unet_inputs[f"additional_residual_{i}"] = sample_residual_input
if i == len(additional_residuals_shapes) - 1:
baseline_sample_unet_inputs["mid_block_additional_residual"] = sample_residual_input
else:
baseline_sample_unet_inputs["down_block_additional_residuals"] += (sample_residual_input, )
sample_unet_inputs_spec = {
k: (v.shape, v.dtype)
for k, v in sample_unet_inputs.items()
}
logger.info(f"Sample UNet inputs spec: {sample_unet_inputs_spec}")
# JIT trace
logger.info("JIT tracing..")
reference_unet = torch.jit.trace(reference_unet,
@ -624,7 +692,7 @@ def convert_unet(pipe, args):
if args.check_output_correctness:
baseline_out = pipe.unet(**baseline_sample_unet_inputs,
return_dict=False)[0].numpy()
reference_out = reference_unet(**sample_unet_inputs)[0].numpy()
reference_out = reference_unet(*sample_unet_inputs.values())[0].numpy()
report_correctness(baseline_out, reference_out,
"unet baseline to reference PyTorch")
@ -636,7 +704,7 @@ def convert_unet(pipe, args):
for k, v in sample_unet_inputs.items()
}
coreml_unet, out_path = _convert_to_coreml("unet", reference_unet,
coreml_unet, out_path = _convert_to_coreml(unet_name, reference_unet,
coreml_sample_unet_inputs,
["noise_pred"], args)
del reference_unet
@ -872,6 +940,174 @@ def convert_safety_checker(pipe, args):
del traced_safety_checker, coreml_safety_checker, pipe.safety_checker
gc.collect()
def convert_controlnet(pipe, args):
""" Converts each ControlNet for Stable Diffusion
"""
if not hasattr(pipe, "unet"):
raise RuntimeError(
"convert_unet() deletes pipe.unet to save RAM. "
"Please use convert_vae_encoder() before convert_unet()")
if not hasattr(pipe, "text_encoder"):
raise RuntimeError(
"convert_text_encoder() deletes pipe.text_encoder to save RAM. "
"Please use convert_unet() before convert_text_encoder()")
if args.model_version != "runwayml/stable-diffusion-v1-5":
logger.warning(
"The original ControlNet models were trained using Stable Diffusion v1.5. "
"It is possible that the converted model may not be compatible with controlnet.")
for i, controlnet_model_version in enumerate(args.convert_controlnet):
controlnet_model_name = controlnet_model_version.replace("/", "_")
fname = f"ControlNet_{controlnet_model_name}.mlpackage"
out_path = os.path.join(args.o, fname)
if os.path.exists(out_path):
logger.info(
f"`controlnet_{controlnet_model_name}` already exists at {out_path}, skipping conversion."
)
continue
if i == 0:
batch_size = 2 # for classifier-free guidance
sample_shape = (
batch_size, # B
pipe.unet.config.in_channels, # C
pipe.unet.config.sample_size, # H
pipe.unet.config.sample_size, # W
)
encoder_hidden_states_shape = (
batch_size,
pipe.text_encoder.config.hidden_size,
1,
pipe.text_encoder.config.max_position_embeddings,
)
controlnet_cond_shape = (
batch_size, # B
3, # C
(args.latent_h or pipe.unet.config.sample_size) * 8, # H
(args.latent_w or pipe.unet.config.sample_size) * 8, # W
)
# Create the scheduled timesteps for downstream use
DEFAULT_NUM_INFERENCE_STEPS = 50
pipe.scheduler.set_timesteps(DEFAULT_NUM_INFERENCE_STEPS)
# Prepare inputs
sample_controlnet_inputs = OrderedDict([
("sample", torch.rand(*sample_shape)),
("timestep",
torch.tensor([pipe.scheduler.timesteps[0].item()] *
(batch_size)).to(torch.float32)),
("encoder_hidden_states", torch.rand(*encoder_hidden_states_shape)),
("controlnet_cond", torch.rand(*controlnet_cond_shape)),
])
sample_controlnet_inputs_spec = {
k: (v.shape, v.dtype)
for k, v in sample_controlnet_inputs.items()
}
logger.info(
f"Sample ControlNet inputs spec: {sample_controlnet_inputs_spec}")
baseline_sample_controlnet_inputs = deepcopy(sample_controlnet_inputs)
baseline_sample_controlnet_inputs[
"encoder_hidden_states"] = baseline_sample_controlnet_inputs[
"encoder_hidden_states"].squeeze(2).transpose(1, 2)
# Import controlnet model and initialize reference controlnet
original_controlnet = ControlNetModel.from_pretrained(
controlnet_model_version,
use_auth_token=True
)
reference_controlnet = controlnet.ControlNetModel(**original_controlnet.config).eval()
load_state_dict_summary = reference_controlnet.load_state_dict(
original_controlnet.state_dict())
num_residuals = reference_controlnet.get_num_residuals()
output_keys = [f"additional_residual_{i}" for i in range(num_residuals)]
# JIT trace
logger.info("JIT tracing..")
reference_controlnet = torch.jit.trace(reference_controlnet,
list(sample_controlnet_inputs.values()))
logger.info("Done.")
if args.check_output_correctness:
baseline_out = original_controlnet(**baseline_sample_controlnet_inputs,
return_dict=False)
reference_out = reference_controlnet(*sample_controlnet_inputs.values())
baseline_down_residuals, baseline_mid_residuals = baseline_out
baseline_out = baseline_down_residuals + (baseline_mid_residuals,)
reference_down_residuals, reference_mid_residuals = reference_out
reference_out = reference_down_residuals + (reference_mid_residuals,)
for key, b_out, r_out in zip(output_keys, baseline_out, reference_out):
b_out = b_out.numpy()
r_out = r_out.numpy()
logger.info(f"Check {key} correctness")
report_correctness(b_out, r_out,
f"controlnet({controlnet_model_name}) baseline to reference PyTorch")
del original_controlnet
gc.collect()
coreml_sample_controlnet_inputs = {
k: v.numpy().astype(np.float16)
for k, v in sample_controlnet_inputs.items()
}
coreml_controlnet, out_path = _convert_to_coreml(f"controlnet_{controlnet_model_name}", reference_controlnet,
coreml_sample_controlnet_inputs,
output_keys, args,
out_path=out_path)
del reference_controlnet
gc.collect()
coreml_controlnet.author = f"Please refer to the Model Card available at huggingface.co/{controlnet_model_version}"
coreml_controlnet.license = "OpenRAIL (https://huggingface.co/spaces/CompVis/stable-diffusion-license)"
coreml_controlnet.version = controlnet_model_version
coreml_controlnet.short_description = \
"ControlNet is a neural network structure to control diffusion models by adding extra conditions. " \
"Please refer to https://arxiv.org/abs/2302.05543 for details."
# Set the input descriptions
coreml_controlnet.input_description["sample"] = \
"The low resolution latent feature maps being denoised through reverse diffusion"
coreml_controlnet.input_description["timestep"] = \
"A value emitted by the associated scheduler object to condition the model on a given noise schedule"
coreml_controlnet.input_description["encoder_hidden_states"] = \
"Output embeddings from the associated text_encoder model to condition to generated image on text. " \
"A maximum of 77 tokens (~40 words) are allowed. Longer text is truncated. " \
"Shorter text does not reduce computation."
coreml_controlnet.input_description["controlnet_cond"] = \
"An additional input image for ControlNet to condition the generated images."
# Set the output descriptions
for i in range(num_residuals):
coreml_controlnet.output_description[f"additional_residual_{i}"] = \
"One of the outputs of each downsampling block in ControlNet. " \
"The value added to the corresponding resnet output in UNet."
_save_mlpackage(coreml_controlnet, out_path)
logger.info(f"Saved controlnet into {out_path}")
# Parity check PyTorch vs CoreML
if args.check_output_correctness:
coreml_out = coreml_controlnet.predict(coreml_sample_controlnet_inputs)
for key, b_out in zip(output_keys, baseline_out):
b_out = b_out.numpy()
logger.info(f"Check {key} correctness")
report_correctness(b_out, coreml_out[key],
"controlnet baseline PyTorch to reference CoreML")
del coreml_controlnet
gc.collect()
def main(args):
os.makedirs(args.o, exist_ok=True)
@ -883,6 +1119,13 @@ def main(args):
use_auth_token=True)
logger.info("Done.")
# Register the selected attention implementation globally
unet.ATTENTION_IMPLEMENTATION_IN_EFFECT = unet.AttentionImplementations[
args.attention_implementation]
logger.info(
f"Attention implementation in effect: {unet.ATTENTION_IMPLEMENTATION_IN_EFFECT}"
)
# Convert models
if args.convert_vae_decoder:
logger.info("Converting vae_decoder")
@ -894,6 +1137,11 @@ def main(args):
convert_vae_encoder(pipe, args)
logger.info("Converted vae_encoder")
if args.convert_controlnet:
logger.info("Converting controlnet")
convert_controlnet(pipe, args)
logger.info("Converted controlnet")
if args.convert_unet:
logger.info("Converting unet")
convert_unet(pipe, args)
@ -930,6 +1178,14 @@ def parser_spec():
parser.add_argument("--convert-vae-encoder", action="store_true")
parser.add_argument("--convert-unet", action="store_true")
parser.add_argument("--convert-safety-checker", action="store_true")
parser.add_argument(
"--convert-controlnet",
nargs="*",
type=str,
help=
"Converts a ControlNet model hosted on HuggingFace to coreML format. " \
"To convert multiple models, provide their names separated by spaces.",
)
parser.add_argument(
"--model-version",
default="CompVis/stable-diffusion-v1-4",
@ -989,6 +1245,13 @@ def parser_spec():
("If specified, quantize 16-bits weights to 8-bits weights in-place for all models. "
"Not recommended as the generated image quality degraded significantly after 8-bit weight quantization"
))
parser.add_argument(
"--unet-support-controlnet",
action="store_true",
help=
("If specified, enable unet to receive additional inputs from controlnet. "
"Each input added to corresponding resnet output."
))
# Swift CLI Resource Bundling
parser.add_argument(

View File

@ -937,6 +937,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
sample,
timestep,
encoder_hidden_states,
*additional_residuals,
):
# 0. Project (or look-up) time embeddings
t_emb = self.time_proj(timestep)
@ -965,10 +966,20 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
down_block_res_samples += res_samples
if additional_residuals:
new_down_block_res_samples = ()
for i, down_block_res_sample in enumerate(down_block_res_samples):
down_block_res_sample = down_block_res_sample + additional_residuals[i]
new_down_block_res_samples += (down_block_res_sample,)
down_block_res_samples = new_down_block_res_samples
# 4. mid
sample = self.mid_block(sample,
emb,
encoder_hidden_states=encoder_hidden_states)
if additional_residuals:
sample = sample + additional_residuals[-1]
# 5. up
for upsample_block in self.up_blocks:
@ -1081,3 +1092,13 @@ def get_up_block(
attn_num_head_channels=attn_num_head_channels,
)
raise ValueError(f"{up_block_type} does not exist.")
def calculate_conv2d_output_shape(in_h, in_w, conv2d_layer):
k_h, k_w = conv2d_layer.kernel_size
pad_h, pad_w = conv2d_layer.padding
stride_h, stride_w = conv2d_layer.stride
out_h = math.floor((in_h + 2 * pad_h - k_h) / stride_h + 1)
out_w = math.floor((in_w + 2 * pad_w - k_w) / stride_w + 1)
return out_h, out_w
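
A small illustrative check of this helper, which is how `convert_unet` derives the spatial shapes of the ControlNet residual inputs; the layer dimensions below are assumptions matching the UNet's 3x3 convolutions (`conv_in` preserves the latent resolution, a stride-2 downsampler halves it):

```python
import torch.nn as nn
from python_coreml_stable_diffusion.unet import calculate_conv2d_output_shape

conv_in = nn.Conv2d(4, 320, kernel_size=3, padding=1)
print(calculate_conv2d_output_shape(64, 64, conv_in))       # (64, 64): resolution preserved

downsampler = nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)
print(calculate_conv2d_output_shape(64, 64, downsampler))   # (32, 32): halved by stride 2
```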

View File

@ -63,8 +63,8 @@ extension CGImage {
return cgImage
}
public var plannerRGBShapedArray: MLShapedArray<Float32> {
get throws {
public func plannerRGBShapedArray(minValue: Float, maxValue: Float)
throws -> MLShapedArray<Float32> {
guard
var sourceFormat = vImage_CGImageFormat(cgImage: self),
var mediumFormat = vImage_CGImageFormat(
@ -100,8 +100,8 @@ extension CGImage {
var destinationG = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout<Float>.size))
var destinationB = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout<Float>.size))
var minFloat: [Float] = [-1.0, -1.0, -1.0, -1.0]
var maxFloat: [Float] = [1.0, 1.0, 1.0, 1.0]
var minFloat: [Float] = Array(repeating: minValue, count: 4)
var maxFloat: [Float] = Array(repeating: maxValue, count: 4)
vImageConvert_ARGB8888toPlanarF(&mediumDesination, &destinationA, &destinationR, &destinationG, &destinationB, &maxFloat, &minFloat, .zero)
@ -114,7 +114,6 @@ extension CGImage {
let shapedArray = MLShapedArray<Float32>(data: imageData, shape: [1, 3, self.width, self.height])
return shapedArray
}
}
}

View File

@ -0,0 +1,127 @@
// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
import Foundation
import CoreML
@available(iOS 16.2, macOS 13.1, *)
public struct ControlNet: ResourceManaging {
var models: [ManagedMLModel]
public init(modelAt urls: [URL],
configuration: MLModelConfiguration) {
self.models = urls.map { ManagedMLModel(modelAt: $0, configuration: configuration) }
}
/// Load resources.
public func loadResources() throws {
for model in models {
try model.loadResources()
}
}
/// Unload the underlying model to free up memory
public func unloadResources() {
for model in models {
model.unloadResources()
}
}
/// Pre-warm resources
public func prewarmResources() throws {
// Override default to pre-warm each model
for model in models {
try model.loadResources()
model.unloadResources()
}
}
var inputImageDescriptions: [MLFeatureDescription] {
models.map { model in
try! model.perform {
$0.modelDescription.inputDescriptionsByName["controlnet_cond"]!
}
}
}
/// The expected shape of the models image input
public var inputImageShapes: [[Int]] {
inputImageDescriptions.map { desc in
desc.multiArrayConstraint!.shape.map { $0.intValue }
}
}
/// Calculate the additional residual inputs for the Unet so the generated image follows the provided conditioning images
///
/// - Parameters:
/// - latents: Batch of latent samples in an array
/// - timeStep: Current diffusion timestep
/// - hiddenStates: Hidden state to condition on
/// - images: Images for each ControlNet
/// - Returns: Array of additional residual dictionaries, one per latent sample
func execute(
latents: [MLShapedArray<Float32>],
timeStep: Int,
hiddenStates: MLShapedArray<Float32>,
images: [MLShapedArray<Float32>]
) throws -> [[String: MLShapedArray<Float32>]] {
// Match time step batch dimension to the model / latent samples
let t = MLShapedArray(scalars: [Float(timeStep), Float(timeStep)], shape: [2])
var outputs: [[String: MLShapedArray<Float32>]] = []
for (modelIndex, model) in models.enumerated() {
let inputs = try latents.map { latent in
let dict: [String: Any] = [
"sample": MLMultiArray(latent),
"timestep": MLMultiArray(t),
"encoder_hidden_states": MLMultiArray(hiddenStates),
"controlnet_cond": MLMultiArray(images[modelIndex])
]
return try MLDictionaryFeatureProvider(dictionary: dict)
}
let batch = MLArrayBatchProvider(array: inputs)
let results = try model.perform {
try $0.predictions(fromBatch: batch)
}
// pre-allocate MLShapedArray with a specific shape in outputs
if outputs.isEmpty {
outputs = initOutputs(
batch: latents.count,
shapes: results.features(at: 0).featureValueDictionary
)
}
for n in 0..<results.count {
let result = results.features(at: n)
for k in result.featureNames {
let newValue = result.featureValue(for: k)!.shapedArrayValue(of: Float32.self)!
if modelIndex == 0 {
outputs[n][k] = newValue
} else {
outputs[n][k]!.withUnsafeMutableShapedBufferPointer { pt, _, _ in
for (i, v) in newValue.scalars.enumerated() { pt[i] += v }
}
}
}
}
}
return outputs
}
private func initOutputs(batch: Int, shapes: [String: MLFeatureValue]) -> [[String: MLShapedArray<Float32>]] {
var output: [String: MLShapedArray<Float32>] = [:]
for (outputName, featureValue) in shapes {
output[outputName] = MLShapedArray<Float32>(
repeating: 0.0,
shape: featureValue.multiArrayValue!.shape.map { $0.intValue }
)
}
return Array(repeating: output, count: batch)
}
}

View File

@ -50,7 +50,7 @@ public struct Encoder: ResourceManaging {
scaleFactor: Float32,
random: inout RandomSource
) throws -> MLShapedArray<Float32> {
let imageData = try image.plannerRGBShapedArray
let imageData = try image.plannerRGBShapedArray(minValue: -1.0, maxValue: 1.0)
guard imageData.shape == inputShape else {
// TODO: Consider auto resizing and cropping similar to how Vision or CoreML auto-generated Swift code can accomplish with `MLFeatureValue`
throw Error.sampleInputShapeNotCorrect

View File

@ -18,6 +18,10 @@ public extension StableDiffusionPipeline {
public let safetyCheckerURL: URL
public let vocabURL: URL
public let mergesURL: URL
public let controlNetDirURL: URL
public let controlledUnetURL: URL
public let controlledUnetChunk1URL: URL
public let controlledUnetChunk2URL: URL
public init(resourcesAt baseURL: URL) {
textEncoderURL = baseURL.appending(path: "TextEncoder.mlmodelc")
@ -29,6 +33,10 @@ public extension StableDiffusionPipeline {
safetyCheckerURL = baseURL.appending(path: "SafetyChecker.mlmodelc")
vocabURL = baseURL.appending(path: "vocab.json")
mergesURL = baseURL.appending(path: "merges.txt")
controlNetDirURL = baseURL.appending(path: "controlnet")
controlledUnetURL = baseURL.appending(path: "ControlledUnet.mlmodelc")
controlledUnetChunk1URL = baseURL.appending(path: "ControlledUnetChunk1.mlmodelc")
controlledUnetChunk2URL = baseURL.appending(path: "ControlledUnetChunk2.mlmodelc")
}
}
@ -38,12 +46,14 @@ public extension StableDiffusionPipeline {
/// - Parameters:
/// - baseURL: URL pointing to directory holding all model
/// and tokenization resources
/// - controlNetModelNames: Specify ControlNet models to use in generation
/// - configuration: The configuration to load model resources with
/// - disableSafety: Load time disable of safety to save memory
/// - reduceMemory: Setup pipeline in reduced memory mode
/// - Returns:
/// Pipeline ready for image generation if all necessary resources loaded
init(resourcesAt baseURL: URL,
controlNet controlNetModelNames: [String],
configuration config: MLModelConfiguration = .init(),
disableSafety: Bool = false,
reduceMemory: Bool = false) throws {
@ -56,15 +66,38 @@ public extension StableDiffusionPipeline {
let textEncoder = TextEncoder(tokenizer: tokenizer,
modelAt: urls.textEncoderURL,
configuration: config)
// ControlNet model
var controlNet: ControlNet? = nil
let controlNetURLs = controlNetModelNames.map { model in
let fileName = model + ".mlmodelc"
return urls.controlNetDirURL.appending(path: fileName)
}
if !controlNetURLs.isEmpty {
controlNet = ControlNet(modelAt: controlNetURLs, configuration: config)
}
// Unet model
let unet: Unet
if FileManager.default.fileExists(atPath: urls.unetChunk1URL.path) &&
FileManager.default.fileExists(atPath: urls.unetChunk2URL.path) {
unet = Unet(chunksAt: [urls.unetChunk1URL, urls.unetChunk2URL],
let unetURL: URL, unetChunk1URL: URL, unetChunk2URL: URL
// If ControlNet is available, use the Unet variant that supports additional inputs from ControlNet
if controlNet == nil {
unetURL = urls.unetURL
unetChunk1URL = urls.unetChunk1URL
unetChunk2URL = urls.unetChunk2URL
} else {
unetURL = urls.controlledUnetURL
unetChunk1URL = urls.controlledUnetChunk1URL
unetChunk2URL = urls.controlledUnetChunk2URL
}
if FileManager.default.fileExists(atPath: unetChunk1URL.path) &&
FileManager.default.fileExists(atPath: unetChunk2URL.path) {
unet = Unet(chunksAt: [unetChunk1URL, unetChunk2URL],
configuration: config)
} else {
unet = Unet(modelAt: urls.unetURL, configuration: config)
unet = Unet(modelAt: unetURL, configuration: config)
}
// Image Decoder
@ -90,6 +123,7 @@ public extension StableDiffusionPipeline {
unet: unet,
decoder: decoder,
encoder: encoder,
controlNet: controlNet,
safetyChecker: safetyChecker,
reduceMemory: reduceMemory)
}

View File

@ -33,6 +33,8 @@ extension StableDiffusionPipeline {
public var seed: UInt32 = 0
/// Controls the influence of the text prompt on sampling process (0=random images)
public var guidanceScale: Float = 7.5
/// List of conditioning images, one per ControlNet model
public var controlNetInputs: [CGImage] = []
/// Safety checks are only performed if `self.canSafetyCheck && !disableSafety`
public var disableSafety: Bool = false
/// The type of Scheduler to use.

View File

@ -47,6 +47,9 @@ public struct StableDiffusionPipeline: ResourceManaging {
/// Optional model for checking safety of generated image
var safetyChecker: SafetyChecker? = nil
/// Optional model used before the Unet to control generated images by additional inputs
var controlNet: ControlNet? = nil
/// Reports whether this pipeline can perform safety checks
public var canSafetyCheck: Bool {
@ -67,6 +70,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
/// - textEncoder: Model for encoding tokenized text
/// - unet: Model for noise prediction on latent samples
/// - decoder: Model for decoding latent sample to image
/// - controlNet: Optional model to control generated images by additional inputs
/// - safetyChecker: Optional model for checking safety of generated images
/// - reduceMemory: Option to enable reduced memory mode
/// - Returns: Pipeline ready for image generation
@ -74,12 +78,14 @@ public struct StableDiffusionPipeline: ResourceManaging {
unet: Unet,
decoder: Decoder,
encoder: Encoder?,
controlNet: ControlNet? = nil,
safetyChecker: SafetyChecker? = nil,
reduceMemory: Bool = false) {
self.textEncoder = textEncoder
self.unet = unet
self.decoder = decoder
self.encoder = encoder
self.controlNet = controlNet
self.safetyChecker = safetyChecker
self.reduceMemory = reduceMemory
}
@ -95,6 +101,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
try textEncoder.loadResources()
try unet.loadResources()
try decoder.loadResources()
try controlNet?.loadResources()
try safetyChecker?.loadResources()
}
}
@ -104,6 +111,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
textEncoder.unloadResources()
unet.unloadResources()
decoder.unloadResources()
controlNet?.unloadResources()
safetyChecker?.unloadResources()
}
@ -112,6 +120,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
try textEncoder.prewarmResources()
try unet.prewarmResources()
try decoder.prewarmResources()
try controlNet?.prewarmResources()
try safetyChecker?.prewarmResources()
}
@ -154,6 +163,15 @@ public struct StableDiffusionPipeline: ResourceManaging {
// Generate random latent samples from specified seed
var latents: [MLShapedArray<Float32>] = try generateLatentSamples(configuration: config, scheduler: scheduler[0])
let timestepStrength: Float? = config.mode == .imageToImage ? config.strength : nil
// Convert cgImage for ControlNet into MLShapedArray
let controlNetConds = try config.controlNetInputs.map { cgImage in
let shapedArray = try cgImage.plannerRGBShapedArray(minValue: 0.0, maxValue: 1.0)
return MLShapedArray(
concatenating: [shapedArray, shapedArray],
alongAxis: 0
)
}
// De-noising loop
let timeSteps: [Int] = scheduler[0].calculateTimesteps(strength: timestepStrength)
@ -164,13 +182,22 @@ public struct StableDiffusionPipeline: ResourceManaging {
let latentUnetInput = latents.map {
MLShapedArray<Float32>(concatenating: [$0, $0], alongAxis: 0)
}
// Before the Unet, execute ControlNet and add its outputs to the Unet inputs
let additionalResiduals = try controlNet?.execute(
latents: latentUnetInput,
timeStep: t,
hiddenStates: hiddenStates,
images: controlNetConds
)
// Predict noise residuals from latent samples
// and current time step conditioned on hidden states
var noise = try unet.predictNoise(
latents: latentUnetInput,
timeStep: t,
hiddenStates: hiddenStates
hiddenStates: hiddenStates,
additionalResiduals: additionalResiduals
)
noise = performGuidance(noise, config.guidanceScale)
@ -201,6 +228,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
}
if reduceMemory {
controlNet?.unloadResources()
unet.unloadResources()
}

View File

@ -79,19 +79,25 @@ public struct Unet: ResourceManaging {
func predictNoise(
latents: [MLShapedArray<Float32>],
timeStep: Int,
hiddenStates: MLShapedArray<Float32>
hiddenStates: MLShapedArray<Float32>,
additionalResiduals: [[String: MLShapedArray<Float32>]]? = nil
) throws -> [MLShapedArray<Float32>] {
// Match time step batch dimension to the model / latent samples
let t = MLShapedArray<Float32>(scalars:[Float(timeStep), Float(timeStep)],shape:[2])
// Form batch input to model
let inputs = try latents.map {
let dict: [String: Any] = [
"sample" : MLMultiArray($0),
let inputs = try latents.enumerated().map {
var dict: [String: Any] = [
"sample" : MLMultiArray($0.element),
"timestep" : MLMultiArray(t),
"encoder_hidden_states": MLMultiArray(hiddenStates)
]
if let residuals = additionalResiduals?[$0.offset] {
for (k, v) in residuals {
dict[k] = MLMultiArray(v)
}
}
return try MLDictionaryFeatureProvider(dictionary: dict)
}
let batch = MLArrayBatchProvider(array: inputs)

View File

@ -8,6 +8,7 @@ import Foundation
import StableDiffusion
import UniformTypeIdentifiers
import Cocoa
import CoreImage
@available(iOS 16.2, macOS 13.1, *)
struct StableDiffusionSample: ParsableCommand {
@ -71,6 +72,18 @@ struct StableDiffusionSample: ParsableCommand {
@Option(help: "Random number generator to use, one of {numpy, torch}")
var rng: RNGOption = .numpy
@Option(
parsing: .upToNextOption,
help: "ControlNet models used in image generation (enter file names in Resources/controlnet without extension)"
)
var controlnet: [String] = []
@Option(
parsing: .upToNextOption,
help: "image for each controlNet model (corresponding to the same order as --controlnet)"
)
var controlnetInputs: [String] = []
@Flag(help: "Disable safety checking")
var disableSafety: Bool = false
@ -90,6 +103,7 @@ struct StableDiffusionSample: ParsableCommand {
log("Loading resources and creating pipeline\n")
log("(Note: This can take a while the first time using these resources)\n")
let pipeline = try StableDiffusionPipeline(resourcesAt: resourceURL,
controlNet: controlnet,
configuration: config,
disableSafety: disableSafety,
reduceMemory: reduceMemory)
@ -99,14 +113,7 @@ struct StableDiffusionSample: ParsableCommand {
if let image {
let imageURL = URL(filePath: image)
do {
let imageData = try Data(contentsOf: imageURL)
guard
let nsImage = NSImage(data: imageData),
let loadedImage = nsImage.cgImage(forProposedRect: nil, context: nil, hints: nil)
else {
throw RunError.resources("Starting Image not available \(resourcePath)")
}
startingImage = loadedImage
startingImage = try convertImageToCGImage(imageURL: imageURL)
} catch let error {
throw RunError.resources("Starting image not found \(imageURL), error: \(error)")
}
@ -114,6 +121,21 @@ struct StableDiffusionSample: ParsableCommand {
} else {
startingImage = nil
}
// Convert the ControlNet input images into CGImages when ControlNet is in use
let controlNetInputs: [CGImage]
if !controlnet.isEmpty {
controlNetInputs = try controlnetInputs.map { imagePath in
let imageURL = URL(filePath: imagePath)
do {
return try convertImageToCGImage(imageURL: imageURL)
} catch let error {
throw RunError.resources("Image for ControlNet not found \(imageURL), error: \(error)")
}
}
} else {
controlNetInputs = []
}
log("Sampling ...\n")
let sampleTimer = SampleTimer()
@ -127,6 +149,7 @@ struct StableDiffusionSample: ParsableCommand {
pipelineConfig.imageCount = imageCount
pipelineConfig.stepCount = stepCount
pipelineConfig.seed = seed
pipelineConfig.controlNetInputs = controlNetInputs
pipelineConfig.guidanceScale = guidanceScale
pipelineConfig.schedulerType = scheduler.stableDiffusionScheduler
pipelineConfig.rngType = rng.stableDiffusionRNG
@ -144,6 +167,17 @@ struct StableDiffusionSample: ParsableCommand {
_ = try saveImages(images, logNames: true)
}
func convertImageToCGImage(imageURL: URL) throws -> CGImage {
let imageData = try Data(contentsOf: imageURL)
guard
let nsImage = NSImage(data: imageData),
let loadedImage = nsImage.cgImage(forProposedRect: nil, context: nil, hints: nil)
else {
throw RunError.resources("Image not available \(resourcePath)")
}
return loadedImage
}
func handleProgress(
_ progress: StableDiffusionPipeline.Progress,