ml-stable-diffusion/swift/StableDiffusion/pipeline/ControlNet.swift

128 lines
4.4 KiB
Swift

// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
import Foundation
import CoreML
@available(iOS 16.2, macOS 13.1, *)
public struct ControlNet: ResourceManaging {

    /// One managed Core ML model per URL passed to `init(modelAt:configuration:)`, in the same order.
    var models: [ManagedMLModel]

    /// Creates a ControlNet wrapper around the models at `urls`.
    ///
    /// - Parameters:
    ///   - urls: Locations of the ControlNet models; one `ManagedMLModel` is created per URL.
    ///   - configuration: Core ML configuration shared by every model.
    public init(modelAt urls: [URL],
                configuration: MLModelConfiguration) {
        self.models = urls.map { ManagedMLModel(modelAt: $0, configuration: configuration) }
    }

    /// Load the resources of every model, in order; stops at the first error.
    public func loadResources() throws {
        for model in models {
            try model.loadResources()
        }
    }

    /// Unload every underlying model to free up memory.
    public func unloadResources() {
        for model in models {
            model.unloadResources()
        }
    }

    /// Pre-warm resources by loading and then immediately unloading each model in turn.
    public func prewarmResources() throws {
        // Override default to pre-warm each model
        for model in models {
            try model.loadResources()
            model.unloadResources()
        }
    }

    /// Descriptions of each model's `controlnet_cond` (image) input, in model order.
    ///
    /// Traps if a model's `perform` call throws or if a model has no `controlnet_cond` input.
    var inputImageDescriptions: [MLFeatureDescription] {
        models.map { model in
            try! model.perform {
                $0.modelDescription.inputDescriptionsByName["controlnet_cond"]!
            }
        }
    }

    /// The expected shape of each model's image input, in model order.
    ///
    /// Traps under the same conditions as `inputImageDescriptions`, or if an image input
    /// has no multi-array constraint.
    public var inputImageShapes: [[Int]] {
        inputImageDescriptions.map { desc in
            desc.multiArrayConstraint!.shape.map { $0.intValue }
        }
    }

    /// Calculate additional inputs for Unet to generate intended image following provided images.
    ///
    /// Every model is run over the whole batch of `latents`. For each latent, the outputs of
    /// all models are summed element-wise, per output name.
    ///
    /// - Parameters:
    ///   - latents: Batch of latent samples in an array
    ///   - timeStep: Current diffusion timestep
    ///   - hiddenStates: Hidden state to condition on
    ///   - images: Images for each ControlNet; `images[i]` is fed to `models[i]`
    /// - Returns: One dictionary per latent, mapping each model output name to the element-wise
    ///   sum of that output across all models. Empty when `latents` or `models` is empty.
    /// - Throws: Any error raised while running a model's `perform` / batch prediction, or while
    ///   building an input feature provider.
    /// - Precondition: `images.count >= models.count`.
    func execute(
        latents: [MLShapedArray<Float32>],
        timeStep: Int,
        hiddenStates: MLShapedArray<Float32>,
        images: [MLShapedArray<Float32>]
    ) throws -> [[String: MLShapedArray<Float32>]] {
        // Nothing to condition. This also avoids reading `features(at: 0)` from an empty batch result.
        guard !latents.isEmpty, !models.isEmpty else { return [] }
        precondition(
            images.count >= models.count,
            "ControlNet.execute: expected one image per model (\(models.count)), got \(images.count)"
        )

        // Match time step batch dimension to the model / latent samples.
        // NOTE(review): the fixed batch of 2 presumably mirrors the per-latent classifier-free
        // guidance pair (unconditional + conditional) — confirm against the model's `timestep` input.
        let t = MLShapedArray(scalars: [Float(timeStep), Float(timeStep)], shape: [2])

        // These inputs are the same for every latent and every model, so convert them once.
        let timestepArray = MLMultiArray(t)
        let hiddenStatesArray = MLMultiArray(hiddenStates)

        // outputs[n] collects the results for latents[n]. The first model seeds each entry and
        // later models add into it.
        var outputs = [[String: MLShapedArray<Float32>]](repeating: [:], count: latents.count)

        for (modelIndex, model) in models.enumerated() {
            // The control image is the same for every latent fed to this model.
            let imageArray = MLMultiArray(images[modelIndex])
            let inputs = try latents.map { latent in
                let dict: [String: Any] = [
                    "sample": MLMultiArray(latent),
                    "timestep": timestepArray,
                    "encoder_hidden_states": hiddenStatesArray,
                    "controlnet_cond": imageArray
                ]
                return try MLDictionaryFeatureProvider(dictionary: dict)
            }
            let batch = MLArrayBatchProvider(array: inputs)
            let results = try model.perform {
                try $0.predictions(fromBatch: batch)
            }

            for n in 0..<results.count {
                let result = results.features(at: n)
                for name in result.featureNames {
                    let newValue = result.featureValue(for: name)!.shapedArrayValue(of: Float32.self)!
                    if modelIndex == 0 {
                        outputs[n][name] = newValue
                    } else {
                        // Traps if this model emits an output name the first model did not.
                        // The in-place add assumes both arrays hold the same number of scalars in
                        // the same order.
                        outputs[n][name]!.withUnsafeMutableShapedBufferPointer { pt, _, _ in
                            for (i, v) in newValue.scalars.enumerated() { pt[i] += v }
                        }
                    }
                }
            }
        }
        return outputs
    }
}