From 88e8cb3ef181167047d0298ca9b0c289fab99e39 Mon Sep 17 00:00:00 2001 From: Vincent Esche Date: Tue, 24 Sep 2019 14:09:04 +0200 Subject: [PATCH] =?UTF-8?q?Turned=20`add=E2=80=A6`/`sub=E2=80=A6`=20into?= =?UTF-8?q?=20thin=20wrappers=20of=20`muladd=E2=80=A6`=20(in=20=E2=80=98Ar?= =?UTF-8?q?ithmetic.swift=E2=80=99)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Sources/Surge/Arithmetic/Arithmetic.swift | 76 +++++------------------ 1 file changed, 14 insertions(+), 62 deletions(-) diff --git a/Sources/Surge/Arithmetic/Arithmetic.swift b/Sources/Surge/Arithmetic/Arithmetic.swift index 0caea49..a13cae1 100644 --- a/Sources/Surge/Arithmetic/Arithmetic.swift +++ b/Sources/Surge/Arithmetic/Arithmetic.swift @@ -41,15 +41,17 @@ public func + (lhs: L, rhs: Double) -> [Double] where // MARK: - Addition: In Place func addInPlace(_ lhs: inout L, _ rhs: Float) where L.Element == Float { + var scalar = rhs + lhs.withUnsafeMutableMemory { lm in - var scalar = rhs vDSP_vsadd(lm.pointer, numericCast(lm.stride), &scalar, lm.pointer, numericCast(lm.stride), numericCast(lm.count)) } } func addInPlace(_ lhs: inout L, _ rhs: Double) where L.Element == Double { + var scalar = rhs + lhs.withUnsafeMutableMemory { lm in - var scalar = rhs vDSP_vsaddD(lm.pointer, numericCast(lm.stride), &scalar, lm.pointer, numericCast(lm.stride), numericCast(lm.count)) } } @@ -65,25 +67,11 @@ public func +=(lhs: inout L, rhs: Double) wher // MARK: - Element-wise Addition public func add(_ lhs: L, _ rhs: R) -> [Float] where L.Element == Float, R.Element == Float { - precondition(lhs.count == rhs.count, "Collections must have the same size") - var results = [Float](rhs) - lhs.withUnsafeMemory { lhsMemory in - results.withUnsafeMutableBufferPointer { bufferPointer in - cblas_saxpy(numericCast(lhsMemory.count), 1.0, lhsMemory.pointer, numericCast(lhsMemory.stride), bufferPointer.baseAddress, 1) - } - } - return results + muladd(lhs, rhs, 1.0) } public func add(_ lhs: L, _ rhs: R) -> [Double] where L.Element == Double, R.Element == Double { - precondition(lhs.count == rhs.count, "Collections must have the same size") - var results = [Double](rhs) - lhs.withUnsafeMemory { lhsMemory in - results.withUnsafeMutableBufferPointer { bufferPointer in - cblas_daxpy(numericCast(lhsMemory.count), 1.0, lhsMemory.pointer, numericCast(lhsMemory.stride), bufferPointer.baseAddress, 1) - } - } - return results + muladd(lhs, rhs, 1.0) } public func .+ (lhs: L, rhs: R) -> [Float] where L.Element == Float, R.Element == Float { @@ -97,19 +85,11 @@ public func .+ (lhs: L, rh // MARK: - Element-wise Addition: In Place func addInPlace(_ lhs: inout L, _ rhs: R) where L.Element == Float, R.Element == Float { - lhs.withUnsafeMutableMemory { lm in - rhs.withUnsafeMemory { rm in - vDSP_vadd(lm.pointer, numericCast(lm.stride), rm.pointer, numericCast(rm.stride), lm.pointer, numericCast(lm.stride), numericCast(lm.count)) - } - } + muladdInPlace(&lhs, rhs, 1.0) } func addInPlace(_ lhs: inout L, _ rhs: R) where L.Element == Double, R.Element == Double { - lhs.withUnsafeMutableMemory { lm in - rhs.withUnsafeMemory { rm in - vDSP_vaddD(lm.pointer, numericCast(lm.stride), rm.pointer, numericCast(rm.stride), lm.pointer, numericCast(lm.stride), numericCast(lm.count)) - } - } + muladdInPlace(&lhs, rhs, 1.0) } public func .+= (lhs: inout L, rhs: R) where L.Element == Float, R.Element == Float { @@ -141,17 +121,11 @@ public func - (lhs: L, rhs: Double) -> [Double] where // MARK: - Subtraction: In Place func subInPlace(_ lhs: inout L, _ rhs: Float) where L.Element == Float { - lhs.withUnsafeMutableMemory { lm in - var scalar = -rhs - vDSP_vsadd(lm.pointer, numericCast(lm.stride), &scalar, lm.pointer, numericCast(lm.stride), numericCast(lm.count)) - } + addInPlace(&lhs, -rhs) } func subInPlace(_ lhs: inout L, _ rhs: Double) where L.Element == Double { - lhs.withUnsafeMutableMemory { lm in - var scalar = -rhs - vDSP_vsaddD(lm.pointer, numericCast(lm.stride), &scalar, lm.pointer, numericCast(lm.stride), numericCast(lm.count)) - } + addInPlace(&lhs, -rhs) } public func -=(lhs: inout L, rhs: Float) where L.Element == Float { @@ -165,25 +139,11 @@ public func -=(lhs: inout L, rhs: Double) wher // MARK: - Element-wise Subtraction public func sub(_ lhs: L, _ rhs: R) -> [Float] where L.Element == Float, R.Element == Float { - precondition(lhs.count == rhs.count, "Collections must have the same size") - var results = [Float](rhs) - lhs.withUnsafeMemory { lhsMemory in - results.withUnsafeMutableBufferPointer { bufferPointer in - catlas_saxpby(numericCast(lhsMemory.count), 1.0, lhsMemory.pointer, numericCast(lhsMemory.stride), -1, bufferPointer.baseAddress, 1) - } - } - return results + return muladd(lhs, rhs, -1.0) } public func sub(_ lhs: L, _ rhs: R) -> [Double] where L.Element == Double, R.Element == Double { - precondition(lhs.count == rhs.count, "Collections must have the same size") - var results = [Double](rhs) - lhs.withUnsafeMemory { lhsMemory in - results.withUnsafeMutableBufferPointer { bufferPointer in - catlas_daxpby(numericCast(lhsMemory.count), 1.0, lhsMemory.pointer, numericCast(lhsMemory.stride), -1, bufferPointer.baseAddress, 1) - } - } - return results + return muladd(lhs, rhs, -1.0) } public func .- (lhs: L, rhs: R) -> [Float] where L.Element == Float, R.Element == Float { @@ -197,19 +157,11 @@ public func .- (lhs: L, rh // MARK: - Element-wise Subtraction: In Place func subInPlace(_ lhs: inout L, _ rhs: R) where L.Element == Float, R.Element == Float { - lhs.withUnsafeMutableMemory { lm in - rhs.withUnsafeMemory { rm in - vDSP_vsub(rm.pointer, numericCast(rm.stride), lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), numericCast(lm.count)) - } - } + muladdInPlace(&lhs, rhs, -1.0) } func subInPlace(_ lhs: inout L, _ rhs: R) where L.Element == Double, R.Element == Double { - lhs.withUnsafeMutableMemory { lm in - rhs.withUnsafeMemory { rm in - vDSP_vsubD(rm.pointer, numericCast(rm.stride), lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), numericCast(lm.count)) - } - } + muladdInPlace(&lhs, rhs, -1.0) } public func .-= (lhs: inout L, rhs: R) where L.Element == Float, R.Element == Float {