Removed need for `let elementCount: Int32 = numericCast(lhs.count)` in favor of `numericCast(lm.count)`

This commit is contained in:
Vincent Esche 2019-09-24 17:06:51 +02:00
parent e557d48b45
commit 8f0925a721
1 changed files with 22 additions and 30 deletions

View File

@ -198,20 +198,18 @@ func muladd<L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: L, _ rh
func muladdInPlace<L: UnsafeMutableMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: inout L, _ rhs: R, _ alpha: Float) where L.Element == Float, R.Element == Float { func muladdInPlace<L: UnsafeMutableMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: inout L, _ rhs: R, _ alpha: Float) where L.Element == Float, R.Element == Float {
precondition(lhs.count == rhs.count, "Collections must have the same size") precondition(lhs.count == rhs.count, "Collections must have the same size")
let elementCount: Int32 = numericCast(lhs.count) lhs.withUnsafeMutableMemory { lm in
lhs.withUnsafeMutableMemory { lhsMemory in rhs.withUnsafeMemory { rm in
rhs.withUnsafeMemory { rhsMemory in cblas_saxpy(numericCast(lm.count), alpha, rm.pointer, numericCast(rm.stride), lm.pointer, numericCast(lm.stride))
cblas_saxpy(elementCount, alpha, rhsMemory.pointer, numericCast(rhsMemory.stride), lhsMemory.pointer, numericCast(lhsMemory.stride))
} }
} }
} }
func muladdInPlace<L: UnsafeMutableMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: inout L, _ rhs: R, _ alpha: Double) where L.Element == Double, R.Element == Double { func muladdInPlace<L: UnsafeMutableMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: inout L, _ rhs: R, _ alpha: Double) where L.Element == Double, R.Element == Double {
precondition(lhs.count == rhs.count, "Collections must have the same size") precondition(lhs.count == rhs.count, "Collections must have the same size")
let elementCount: Int32 = numericCast(lhs.count) lhs.withUnsafeMutableMemory { lm in
lhs.withUnsafeMutableMemory { lhsMemory in rhs.withUnsafeMemory { rm in
rhs.withUnsafeMemory { rhsMemory in cblas_daxpy(numericCast(lm.count), alpha, rm.pointer, numericCast(rm.stride), lm.pointer, numericCast(lm.stride))
cblas_daxpy(elementCount, alpha, rhsMemory.pointer, numericCast(rhsMemory.stride), lhsMemory.pointer, numericCast(lhsMemory.stride))
} }
} }
} }
@ -717,16 +715,14 @@ public func sq<L: UnsafeMemoryAccessible>(_ lhs: L) -> [Double] where L.Element
// MARK: - Square: In Place // MARK: - Square: In Place
func sqInPlace<L: UnsafeMutableMemoryAccessible>(_ lhs: inout L) where L.Element == Float { func sqInPlace<L: UnsafeMutableMemoryAccessible>(_ lhs: inout L) where L.Element == Float {
let elementCount: vDSP_Length = numericCast(lhs.count)
withUnsafeMutableMemory(&lhs) { lm in withUnsafeMutableMemory(&lhs) { lm in
vDSP_vsq(lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), elementCount) vDSP_vsq(lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), numericCast(lm.count))
} }
} }
public func sqInPlace<L: UnsafeMutableMemoryAccessible>(_ lhs: inout L) where L.Element == Double { public func sqInPlace<L: UnsafeMutableMemoryAccessible>(_ lhs: inout L) where L.Element == Double {
let elementCount: vDSP_Length = numericCast(lhs.count)
withUnsafeMutableMemory(&lhs) { lm in withUnsafeMutableMemory(&lhs) { lm in
vDSP_vsqD(lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), elementCount) vDSP_vsqD(lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), numericCast(lm.count))
} }
} }
@ -745,11 +741,11 @@ public func sqrt<C: UnsafeMemoryAccessible>(_ lhs: C) -> [Float] where C.Element
/// ///
/// - Warning: does not support memory stride (assumes stride is 1). /// - Warning: does not support memory stride (assumes stride is 1).
public func sqrt<MI: UnsafeMemoryAccessible, MO: UnsafeMutableMemoryAccessible>(_ lhs: MI, into results: inout MO) where MI.Element == Float, MO.Element == Float { public func sqrt<MI: UnsafeMemoryAccessible, MO: UnsafeMutableMemoryAccessible>(_ lhs: MI, into results: inout MO) where MI.Element == Float, MO.Element == Float {
return lhs.withUnsafeMemory { lhsMemory in return lhs.withUnsafeMemory { lm in
results.withUnsafeMutableMemory { rm in results.withUnsafeMutableMemory { rm in
precondition(lhsMemory.stride == 1 && rm.stride == 1, "sqrt doesn't support step values other than 1") precondition(lm.stride == 1 && rm.stride == 1, "sqrt doesn't support step values other than 1")
precondition(rm.count >= lhsMemory.count, "`results` doesnt have enough capacity to store the results") precondition(rm.count >= lm.count, "`results` doesnt have enough capacity to store the results")
vvsqrtf(rm.pointer, lhsMemory.pointer, [numericCast(lhsMemory.count)]) vvsqrtf(rm.pointer, lm.pointer, [numericCast(lm.count)])
} }
} }
} }
@ -767,11 +763,11 @@ public func sqrt<C: UnsafeMemoryAccessible>(_ lhs: C) -> [Double] where C.Elemen
/// ///
/// - Warning: does not support memory stride (assumes stride is 1). /// - Warning: does not support memory stride (assumes stride is 1).
public func sqrt<MI: UnsafeMemoryAccessible, MO: UnsafeMutableMemoryAccessible>(_ lhs: MI, into results: inout MO) where MI.Element == Double, MO.Element == Double { public func sqrt<MI: UnsafeMemoryAccessible, MO: UnsafeMutableMemoryAccessible>(_ lhs: MI, into results: inout MO) where MI.Element == Double, MO.Element == Double {
return lhs.withUnsafeMemory { lhsMemory in return lhs.withUnsafeMemory { lm in
results.withUnsafeMutableMemory { rm in results.withUnsafeMutableMemory { rm in
precondition(lhsMemory.stride == 1 && rm.stride == 1, "sqrt doesn't support step values other than 1") precondition(lm.stride == 1 && rm.stride == 1, "sqrt doesn't support step values other than 1")
precondition(rm.count >= lhsMemory.count, "`results` doesnt have enough capacity to store the results") precondition(rm.count >= lm.count, "`results` doesnt have enough capacity to store the results")
vvsqrt(rm.pointer, lhsMemory.pointer, [numericCast(lhsMemory.count)]) vvsqrt(rm.pointer, lm.pointer, [numericCast(lm.count)])
} }
} }
} }
@ -803,27 +799,23 @@ func sqrtInPlace<C: UnsafeMutableMemoryAccessible>(_ lhs: inout C) where C.Eleme
// MARK: - Dot Product // MARK: - Dot Product
public func dot<L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: L, _ rhs: R) -> Float where L.Element == Float, R.Element == Float { public func dot<L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: L, _ rhs: R) -> Float where L.Element == Float, R.Element == Float {
return withUnsafeMemory(lhs, rhs) { lhsMemory, rhsMemory in return withUnsafeMemory(lhs, rhs) { lm, rm in
precondition(lhsMemory.count == rhsMemory.count, "Vectors must have equal count") precondition(lm.count == rm.count, "Vectors must have equal count")
var result: Float = 0.0 var result: Float = 0.0
withUnsafeMutablePointer(to: &result) { pointer in withUnsafeMutablePointer(to: &result) { pointer in
vDSP_dotpr(lhsMemory.pointer, numericCast(lhsMemory.stride), rhsMemory.pointer, numericCast(rhsMemory.stride), pointer, numericCast(lhsMemory.count)) vDSP_dotpr(lm.pointer, numericCast(lm.stride), rm.pointer, numericCast(rm.stride), pointer, numericCast(lm.count))
} }
return result return result
} }
} }
public func dot<L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: L, _ rhs: R) -> Double where L.Element == Double, R.Element == Double { public func dot<L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: L, _ rhs: R) -> Double where L.Element == Double, R.Element == Double {
return withUnsafeMemory(lhs, rhs) { lhsMemory, rhsMemory in return withUnsafeMemory(lhs, rhs) { lm, rm in
precondition(lhsMemory.count == rhsMemory.count, "Vectors must have equal count") precondition(lm.count == rm.count, "Vectors must have equal count")
var result: Double = 0.0 var result: Double = 0.0
withUnsafeMutablePointer(to: &result) { pointer in withUnsafeMutablePointer(to: &result) { pointer in
vDSP_dotprD(lhsMemory.pointer, numericCast(lhsMemory.stride), rhsMemory.pointer, numericCast(rhsMemory.stride), pointer, numericCast(lhsMemory.count)) vDSP_dotprD(lm.pointer, numericCast(lm.stride), rm.pointer, numericCast(rm.stride), pointer, numericCast(lm.count))
} }
return result return result
} }
} }