From 99efd386f760b279d085a0ea4f4b50822343d8d6 Mon Sep 17 00:00:00 2001 From: Vincent Esche Date: Tue, 24 Sep 2019 14:11:31 +0200 Subject: [PATCH] =?UTF-8?q?Added=20`muladd`/`muladdInPlace`=20to=20?= =?UTF-8?q?=E2=80=98Arithmetic.swift=E2=80=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Sources/Surge/Arithmetic/Arithmetic.swift | 36 +++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/Sources/Surge/Arithmetic/Arithmetic.swift b/Sources/Surge/Arithmetic/Arithmetic.swift index de7d69d..0caea49 100644 --- a/Sources/Surge/Arithmetic/Arithmetic.swift +++ b/Sources/Surge/Arithmetic/Arithmetic.swift @@ -220,6 +220,42 @@ public func .-= (lh return subInPlace(&lhs, rhs) } +// MARK: - Multiply Addition + +func muladd(_ lhs: L, _ rhs: R, _ alpha: Float) -> [Float] where L.Element == Float, R.Element == Float { + var results = Array(lhs) + muladdInPlace(&results, rhs, alpha) + return results +} + +func muladd(_ lhs: L, _ rhs: R, _ alpha: Double) -> [Double] where L.Element == Double, R.Element == Double { + var results = Array(lhs) + muladdInPlace(&results, rhs, alpha) + return results +} + +// MARK: - Multiply Addition: In Place + +func muladdInPlace(_ lhs: inout L, _ rhs: R, _ alpha: Float) where L.Element == Float, R.Element == Float { + precondition(lhs.count == rhs.count, "Collections must have the same size") + let elementCount: Int32 = numericCast(lhs.count) + lhs.withUnsafeMutableMemory { lhsMemory in + rhs.withUnsafeMemory { rhsMemory in + cblas_saxpy(elementCount, alpha, rhsMemory.pointer, numericCast(rhsMemory.stride), lhsMemory.pointer, numericCast(lhsMemory.stride)) + } + } +} + +func muladdInPlace(_ lhs: inout L, _ rhs: R, _ alpha: Double) where L.Element == Double, R.Element == Double { + precondition(lhs.count == rhs.count, "Collections must have the same size") + let elementCount: Int32 = numericCast(lhs.count) + lhs.withUnsafeMutableMemory { lhsMemory in + rhs.withUnsafeMemory { rhsMemory in + cblas_daxpy(elementCount, alpha, rhsMemory.pointer, numericCast(rhsMemory.stride), lhsMemory.pointer, numericCast(lhsMemory.stride)) + } + } +} + // MARK: - Multiplication public func mul(_ lhs: L, _ rhs: Float) -> [Float] where L.Element == Float {