ml-stable-diffusion/swift/StableDiffusion/pipeline/DPMSolverMultistepScheduler...

189 lines
8.3 KiB
Swift
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. and The HuggingFace Team. All Rights Reserved.
import Accelerate
import CoreML
/// A scheduler used to compute a de-noised image.
///
/// This implementation matches:
/// [Hugging Face Diffusers DPMSolverMultistepScheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py)
///
/// It uses the DPM-Solver++ algorithm: [code](https://github.com/LuChengTHU/dpm-solver) [paper](https://arxiv.org/abs/2211.01095).
///
/// Limitations:
/// - Only implemented for DPM-Solver++ algorithm (not DPM-Solver).
/// - Second order only.
/// - Assumes the model predicts epsilon.
/// - No dynamic thresholding.
/// - `midpoint` solver algorithm.
///
/// The scheduler is stateful: every call to `step` records the converted model output
/// and advances `lowerOrderStepped`. Neither is ever reset by this class.
@available(iOS 16.2, macOS 13.1, *)
public final class DPMSolverMultistepScheduler: Scheduler {
    /// Number of diffusion steps used during training; also the length of the beta schedule.
    public let trainStepCount: Int
    /// Number of inference steps requested at initialization.
    public let inferenceStepCount: Int
    /// Beta value for every training step.
    public let betas: [Float]
    /// `1 - beta` for every training step.
    public let alphas: [Float]
    /// Running (cumulative) product of `alphas`.
    public let alphasCumProd: [Float]
    /// Training timesteps visited during inference, in the order they are stepped (descending).
    public let timeSteps: [Int]

    /// `sqrt(alphasCumProd)` for every training step.
    public let alpha_t: [Float]
    /// `sqrt(1 - alphasCumProd)` for every training step.
    public let sigma_t: [Float]
    /// `log(alpha_t) - log(sigma_t)` for every training step.
    public let lambda_t: [Float]

    /// Order of the solver; also the number of model outputs kept in `modelOutputs`.
    public let solverOrder = 2
    /// Number of steps taken so far, saturating at `solverOrder`.
    private(set) var lowerOrderStepped = 0

    /// Whether to use lower-order solvers in the final steps. Only valid for less than 15 inference steps.
    /// We empirically find this trick can stabilize the sampling of DPM-Solver, especially with 10 or fewer steps.
    public let useLowerOrderFinal = true

    /// The most recent converted model outputs, oldest first. Holds at most `solverOrder` (2) items.
    private(set) var modelOutputs: [MLShapedArray<Float32>] = []

    /// Create a scheduler that uses a second order DPM-Solver++ algorithm.
    ///
    /// - Parameters:
    ///   - stepCount: Number of inference steps to schedule
    ///   - trainStepCount: Number of training diffusion steps
    ///   - betaSchedule: Method to schedule betas from betaStart to betaEnd
    ///   - betaStart: The starting value of beta for inference
    ///   - betaEnd: The end value for beta for inference
    /// - Returns: A scheduler ready for its first step
    public init(
        stepCount: Int = 50,
        trainStepCount: Int = 1000,
        betaSchedule: BetaSchedule = .scaledLinear,
        betaStart: Float = 0.00085,
        betaEnd: Float = 0.012
    ) {
        self.trainStepCount = trainStepCount
        self.inferenceStepCount = stepCount

        switch betaSchedule {
        case .linear:
            self.betas = linspace(betaStart, betaEnd, trainStepCount)
        case .scaledLinear:
            // Interpolate linearly between the square roots, then square each value.
            self.betas = linspace(pow(betaStart, 0.5), pow(betaEnd, 0.5), trainStepCount).map({ $0 * $0 })
        }

        self.alphas = betas.map({ 1.0 - $0 })

        // Running product of the alphas. Iterating `indices.dropFirst()` instead of
        // `1..<count` keeps an empty schedule from trapping on an invalid range.
        var alphasCumProd = self.alphas
        for i in alphasCumProd.indices.dropFirst() {
            alphasCumProd[i] *= alphasCumProd[i - 1]
        }
        self.alphasCumProd = alphasCumProd

        // Currently we only support VP-type noise schedule
        self.alpha_t = vForce.sqrt(self.alphasCumProd)
        self.sigma_t = vForce.sqrt(vDSP.subtract([Float](repeating: 1, count: self.alphasCumProd.count), self.alphasCumProd))
        self.lambda_t = zip(self.alpha_t, self.sigma_t).map { alpha, sigma in log(alpha) - log(sigma) }

        // `stepCount + 1` evenly spaced points over [0, trainStepCount - 1]; the first point
        // is dropped and the rest are reversed so inference runs from high noise to low noise.
        self.timeSteps = linspace(0, Float(self.trainStepCount-1), stepCount+1).dropFirst().reversed().map { Int(round($0)) }
    }

    /// Convert the model output to the corresponding type the algorithm needs.
    /// This implementation is for second-order DPM-Solver++ assuming epsilon prediction.
    ///
    /// Each element is computed as `(sample - modelOutput * sigma_t[timestep]) / alpha_t[timestep]`.
    ///
    /// - Parameters:
    ///   - modelOutput: Raw model output; must have the same scalar count as `sample`
    ///   - timestep: Training timestep used to index `alpha_t` and `sigma_t`
    ///   - sample: The current sample
    /// - Returns: The converted output, with the shape of `modelOutput`
    func convertModelOutput(modelOutput: MLShapedArray<Float32>, timestep: Int, sample: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
        assert(modelOutput.scalarCount == sample.scalarCount)
        let scalarCount = modelOutput.scalarCount
        let (alpha_t, sigma_t) = (self.alpha_t[timestep], self.sigma_t[timestep])

        return MLShapedArray(unsafeUninitializedShape: modelOutput.shape) { scalars, _ in
            assert(scalars.count == scalarCount)
            modelOutput.withUnsafeShapedBufferPointer { modelOutput, _, _ in
                sample.withUnsafeShapedBufferPointer { sample, _, _ in
                    for i in 0 ..< scalarCount {
                        scalars.initializeElement(at: i, to: (sample[i] - modelOutput[i] * sigma_t) / alpha_t)
                    }
                }
            }
        }
    }

    /// One step for the first-order DPM-Solver (equivalent to DDIM).
    /// See https://arxiv.org/abs/2206.00927 for the detailed derivation.
    /// var names and code structure mostly follow https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
    ///
    /// - Parameters:
    ///   - modelOutput: Converted model output for the current timestep
    ///   - timestep: The current timestep (`s`)
    ///   - prevTimestep: The timestep being stepped to (`t`)
    ///   - sample: The current sample
    /// - Returns: The sample at `prevTimestep`
    func firstOrderUpdate(
        modelOutput: MLShapedArray<Float32>,
        timestep: Int,
        prevTimestep: Int,
        sample: MLShapedArray<Float32>
    ) -> MLShapedArray<Float32> {
        // The scalar coefficients are computed in Double.
        let (p_lambda_t, lambda_s) = (Double(lambda_t[prevTimestep]), Double(lambda_t[timestep]))
        let p_alpha_t = Double(alpha_t[prevTimestep])
        let (p_sigma_t, sigma_s) = (Double(sigma_t[prevTimestep]), Double(sigma_t[timestep]))
        let h = p_lambda_t - lambda_s

        // x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
        let x_t = weightedSum(
            [p_sigma_t / sigma_s, -p_alpha_t * (exp(-h) - 1)],
            [sample, modelOutput]
        )
        return x_t
    }

    /// One step for the second-order multistep DPM-Solver++ algorithm, using the midpoint method.
    /// var names and code structure mostly follow https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
    ///
    /// - Parameters:
    ///   - modelOutputs: Converted model outputs; the two entries selected by `[back: 1]` and `[back: 2]` are used
    ///   - timesteps: Timesteps corresponding to `modelOutputs`, selected the same way
    ///   - t: The timestep being stepped to
    ///   - sample: The current sample
    /// - Returns: The sample at timestep `t`
    func secondOrderUpdate(
        modelOutputs: [MLShapedArray<Float32>],
        timesteps: [Int],
        prevTimestep t: Int,
        sample: MLShapedArray<Float32>
    ) -> MLShapedArray<Float32> {
        // NOTE(review): `[back:]` is declared outside this file; presumably it indexes from the
        // end, making s0/m0 the most recent entries and s1/m1 the ones before them — confirm.
        let (s0, s1) = (timesteps[back: 1], timesteps[back: 2])
        let (m0, m1) = (modelOutputs[back: 1], modelOutputs[back: 2])
        let (p_lambda_t, lambda_s0, lambda_s1) = (Double(lambda_t[t]), Double(lambda_t[s0]), Double(lambda_t[s1]))
        let p_alpha_t = Double(alpha_t[t])
        let (p_sigma_t, sigma_s0) = (Double(sigma_t[t]), Double(sigma_t[s0]))
        let (h, h_0) = (p_lambda_t - lambda_s0, lambda_s0 - lambda_s1)
        let r0 = h_0 / h
        let D0 = m0

        // D1 = (1.0 / r0) * (m0 - m1)
        let D1 = weightedSum(
            [1/r0, -1/r0],
            [m0, m1]
        )

        // See https://arxiv.org/abs/2211.01095 for detailed derivations
        // x_t = (
        //     (sigma_t / sigma_s0) * sample
        //     - (alpha_t * (torch.exp(-h) - 1.0)) * D0
        //     - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
        // )
        let x_t = weightedSum(
            [p_sigma_t/sigma_s0, -p_alpha_t * (exp(-h) - 1), -0.5 * p_alpha_t * (exp(-h) - 1)],
            [sample, D0, D1]
        )
        return x_t
    }

    /// Advance the sample by one inference step.
    ///
    /// The model output is converted with `convertModelOutput` and appended to `modelOutputs`.
    /// The first-order update is used when any of the following holds; otherwise the
    /// second-order update is used:
    /// - no step has been taken yet,
    /// - `t` is the first entry of `timeSteps`, so there is no preceding timestep to pair it with,
    /// - fewer than 15 timesteps are scheduled and this is the last or second-to-last of them.
    ///
    /// - Parameters:
    ///   - output: The raw model output for timestep `t`
    ///   - t: The current timestep. A value not found in `timeSteps` is treated as the final step.
    ///   - sample: The current sample
    /// - Returns: The sample at the next timestep in `timeSteps`, or at timestep 0 after the final step
    public func step(output: MLShapedArray<Float32>, timeStep t: Int, sample: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
        let lastIndex = timeSteps.count - 1
        let stepIndex = timeSteps.firstIndex(of: t) ?? lastIndex
        let prevTimestep = stepIndex == lastIndex ? 0 : timeSteps[stepIndex + 1]

        let lowerOrderFinal = useLowerOrderFinal && stepIndex == lastIndex && timeSteps.count < 15
        // NOTE(review): the penultimate step also takes the first-order path here; confirm this
        // is intended for a second-order solver when comparing against the reference implementation.
        let lowerOrderSecond = useLowerOrderFinal && stepIndex == lastIndex - 1 && timeSteps.count < 15
        // The second-order update reads `timeSteps[stepIndex - 1]`. `lowerOrderStepped` is never
        // reset, so a call at index 0 after an earlier step (for example when the scheduler is
        // reused) would index out of bounds; use the first-order update in that case.
        let hasPrecedingTimestep = stepIndex > 0
        let lowerOrder = lowerOrderStepped < 1 || !hasPrecedingTimestep || lowerOrderFinal || lowerOrderSecond

        let modelOutput = convertModelOutput(modelOutput: output, timestep: t, sample: sample)
        // Keep only the `solverOrder` most recent outputs.
        if modelOutputs.count == solverOrder { modelOutputs.removeFirst() }
        modelOutputs.append(modelOutput)

        let prevSample: MLShapedArray<Float32>
        if lowerOrder {
            prevSample = firstOrderUpdate(modelOutput: modelOutput, timestep: t, prevTimestep: prevTimestep, sample: sample)
        } else {
            prevSample = secondOrderUpdate(
                modelOutputs: modelOutputs,
                timesteps: [timeSteps[stepIndex - 1], t],
                prevTimestep: prevTimestep,
                sample: sample
            )
        }

        if lowerOrderStepped < solverOrder {
            lowerOrderStepped += 1
        }

        return prevSample
    }
}