diff --git a/Sources/Surge/Arithmetic/Arithmetic.swift b/Sources/Surge/Arithmetic/Arithmetic.swift index a4269d7..460a091 100644 --- a/Sources/Surge/Arithmetic/Arithmetic.swift +++ b/Sources/Surge/Arithmetic/Arithmetic.swift @@ -198,20 +198,18 @@ func muladd(_ lhs: L, _ rh func muladdInPlace(_ lhs: inout L, _ rhs: R, _ alpha: Float) where L.Element == Float, R.Element == Float { precondition(lhs.count == rhs.count, "Collections must have the same size") - let elementCount: Int32 = numericCast(lhs.count) - lhs.withUnsafeMutableMemory { lhsMemory in - rhs.withUnsafeMemory { rhsMemory in - cblas_saxpy(elementCount, alpha, rhsMemory.pointer, numericCast(rhsMemory.stride), lhsMemory.pointer, numericCast(lhsMemory.stride)) + lhs.withUnsafeMutableMemory { lm in + rhs.withUnsafeMemory { rm in + cblas_saxpy(numericCast(lm.count), alpha, rm.pointer, numericCast(rm.stride), lm.pointer, numericCast(lm.stride)) } } } func muladdInPlace(_ lhs: inout L, _ rhs: R, _ alpha: Double) where L.Element == Double, R.Element == Double { precondition(lhs.count == rhs.count, "Collections must have the same size") - let elementCount: Int32 = numericCast(lhs.count) - lhs.withUnsafeMutableMemory { lhsMemory in - rhs.withUnsafeMemory { rhsMemory in - cblas_daxpy(elementCount, alpha, rhsMemory.pointer, numericCast(rhsMemory.stride), lhsMemory.pointer, numericCast(lhsMemory.stride)) + lhs.withUnsafeMutableMemory { lm in + rhs.withUnsafeMemory { rm in + cblas_daxpy(numericCast(lm.count), alpha, rm.pointer, numericCast(rm.stride), lm.pointer, numericCast(lm.stride)) } } } @@ -717,16 +715,14 @@ public func sq(_ lhs: L) -> [Double] where L.Element // MARK: - Square: In Place func sqInPlace(_ lhs: inout L) where L.Element == Float { - let elementCount: vDSP_Length = numericCast(lhs.count) withUnsafeMutableMemory(&lhs) { lm in - vDSP_vsq(lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), elementCount) + vDSP_vsq(lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), numericCast(lm.count)) } } public func sqInPlace(_ lhs: inout L) where L.Element == Double { - let elementCount: vDSP_Length = numericCast(lhs.count) withUnsafeMutableMemory(&lhs) { lm in - vDSP_vsqD(lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), elementCount) + vDSP_vsqD(lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), numericCast(lm.count)) } } @@ -745,11 +741,11 @@ public func sqrt(_ lhs: C) -> [Float] where C.Element /// /// - Warning: does not support memory stride (assumes stride is 1). public func sqrt(_ lhs: MI, into results: inout MO) where MI.Element == Float, MO.Element == Float { - return lhs.withUnsafeMemory { lhsMemory in + return lhs.withUnsafeMemory { lm in results.withUnsafeMutableMemory { rm in - precondition(lhsMemory.stride == 1 && rm.stride == 1, "sqrt doesn't support step values other than 1") - precondition(rm.count >= lhsMemory.count, "`results` doesnt have enough capacity to store the results") - vvsqrtf(rm.pointer, lhsMemory.pointer, [numericCast(lhsMemory.count)]) + precondition(lm.stride == 1 && rm.stride == 1, "sqrt doesn't support step values other than 1") + precondition(rm.count >= lm.count, "`results` doesnt have enough capacity to store the results") + vvsqrtf(rm.pointer, lm.pointer, [numericCast(lm.count)]) } } } @@ -767,11 +763,11 @@ public func sqrt(_ lhs: C) -> [Double] where C.Elemen /// /// - Warning: does not support memory stride (assumes stride is 1). public func sqrt(_ lhs: MI, into results: inout MO) where MI.Element == Double, MO.Element == Double { - return lhs.withUnsafeMemory { lhsMemory in + return lhs.withUnsafeMemory { lm in results.withUnsafeMutableMemory { rm in - precondition(lhsMemory.stride == 1 && rm.stride == 1, "sqrt doesn't support step values other than 1") - precondition(rm.count >= lhsMemory.count, "`results` doesnt have enough capacity to store the results") - vvsqrt(rm.pointer, lhsMemory.pointer, [numericCast(lhsMemory.count)]) + precondition(lm.stride == 1 && rm.stride == 1, "sqrt doesn't support step values other than 1") + precondition(rm.count >= lm.count, "`results` doesnt have enough capacity to store the results") + vvsqrt(rm.pointer, lm.pointer, [numericCast(lm.count)]) } } } @@ -803,27 +799,23 @@ func sqrtInPlace(_ lhs: inout C) where C.Eleme // MARK: - Dot Product public func dot(_ lhs: L, _ rhs: R) -> Float where L.Element == Float, R.Element == Float { - return withUnsafeMemory(lhs, rhs) { lhsMemory, rhsMemory in - precondition(lhsMemory.count == rhsMemory.count, "Vectors must have equal count") - + return withUnsafeMemory(lhs, rhs) { lm, rm in + precondition(lm.count == rm.count, "Vectors must have equal count") var result: Float = 0.0 withUnsafeMutablePointer(to: &result) { pointer in - vDSP_dotpr(lhsMemory.pointer, numericCast(lhsMemory.stride), rhsMemory.pointer, numericCast(rhsMemory.stride), pointer, numericCast(lhsMemory.count)) + vDSP_dotpr(lm.pointer, numericCast(lm.stride), rm.pointer, numericCast(rm.stride), pointer, numericCast(lm.count)) } - return result } } public func dot(_ lhs: L, _ rhs: R) -> Double where L.Element == Double, R.Element == Double { - return withUnsafeMemory(lhs, rhs) { lhsMemory, rhsMemory in - precondition(lhsMemory.count == rhsMemory.count, "Vectors must have equal count") - + return withUnsafeMemory(lhs, rhs) { lm, rm in + precondition(lm.count == rm.count, "Vectors must have equal count") var result: Double = 0.0 withUnsafeMutablePointer(to: &result) { pointer in - vDSP_dotprD(lhsMemory.pointer, numericCast(lhsMemory.stride), rhsMemory.pointer, numericCast(rhsMemory.stride), pointer, numericCast(lhsMemory.count)) + vDSP_dotprD(lm.pointer, numericCast(lm.stride), rm.pointer, numericCast(rm.stride), pointer, numericCast(lm.count)) } - return result } }