Removed need for `let elementCount: Int32 = numericCast(lhs.count)` in favor of `numericCast(lm.count)`

This commit is contained in:
Vincent Esche 2019-09-24 17:06:51 +02:00
parent e557d48b45
commit 8f0925a721
1 changed files with 22 additions and 30 deletions

View File

@ -198,20 +198,18 @@ func muladd<L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: L, _ rh
func muladdInPlace<L: UnsafeMutableMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: inout L, _ rhs: R, _ alpha: Float) where L.Element == Float, R.Element == Float {
precondition(lhs.count == rhs.count, "Collections must have the same size")
let elementCount: Int32 = numericCast(lhs.count)
lhs.withUnsafeMutableMemory { lhsMemory in
rhs.withUnsafeMemory { rhsMemory in
cblas_saxpy(elementCount, alpha, rhsMemory.pointer, numericCast(rhsMemory.stride), lhsMemory.pointer, numericCast(lhsMemory.stride))
lhs.withUnsafeMutableMemory { lm in
rhs.withUnsafeMemory { rm in
cblas_saxpy(numericCast(lm.count), alpha, rm.pointer, numericCast(rm.stride), lm.pointer, numericCast(lm.stride))
}
}
}
func muladdInPlace<L: UnsafeMutableMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: inout L, _ rhs: R, _ alpha: Double) where L.Element == Double, R.Element == Double {
precondition(lhs.count == rhs.count, "Collections must have the same size")
let elementCount: Int32 = numericCast(lhs.count)
lhs.withUnsafeMutableMemory { lhsMemory in
rhs.withUnsafeMemory { rhsMemory in
cblas_daxpy(elementCount, alpha, rhsMemory.pointer, numericCast(rhsMemory.stride), lhsMemory.pointer, numericCast(lhsMemory.stride))
lhs.withUnsafeMutableMemory { lm in
rhs.withUnsafeMemory { rm in
cblas_daxpy(numericCast(lm.count), alpha, rm.pointer, numericCast(rm.stride), lm.pointer, numericCast(lm.stride))
}
}
}
@ -717,16 +715,14 @@ public func sq<L: UnsafeMemoryAccessible>(_ lhs: L) -> [Double] where L.Element
// MARK: - Square: In Place
func sqInPlace<L: UnsafeMutableMemoryAccessible>(_ lhs: inout L) where L.Element == Float {
let elementCount: vDSP_Length = numericCast(lhs.count)
withUnsafeMutableMemory(&lhs) { lm in
vDSP_vsq(lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), elementCount)
vDSP_vsq(lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), numericCast(lm.count))
}
}
public func sqInPlace<L: UnsafeMutableMemoryAccessible>(_ lhs: inout L) where L.Element == Double {
let elementCount: vDSP_Length = numericCast(lhs.count)
withUnsafeMutableMemory(&lhs) { lm in
vDSP_vsqD(lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), elementCount)
vDSP_vsqD(lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), numericCast(lm.count))
}
}
@ -745,11 +741,11 @@ public func sqrt<C: UnsafeMemoryAccessible>(_ lhs: C) -> [Float] where C.Element
///
/// - Warning: does not support memory stride (assumes stride is 1).
public func sqrt<MI: UnsafeMemoryAccessible, MO: UnsafeMutableMemoryAccessible>(_ lhs: MI, into results: inout MO) where MI.Element == Float, MO.Element == Float {
return lhs.withUnsafeMemory { lhsMemory in
return lhs.withUnsafeMemory { lm in
results.withUnsafeMutableMemory { rm in
precondition(lhsMemory.stride == 1 && rm.stride == 1, "sqrt doesn't support step values other than 1")
precondition(rm.count >= lhsMemory.count, "`results` doesnt have enough capacity to store the results")
vvsqrtf(rm.pointer, lhsMemory.pointer, [numericCast(lhsMemory.count)])
precondition(lm.stride == 1 && rm.stride == 1, "sqrt doesn't support step values other than 1")
precondition(rm.count >= lm.count, "`results` doesnt have enough capacity to store the results")
vvsqrtf(rm.pointer, lm.pointer, [numericCast(lm.count)])
}
}
}
@ -767,11 +763,11 @@ public func sqrt<C: UnsafeMemoryAccessible>(_ lhs: C) -> [Double] where C.Elemen
///
/// - Warning: does not support memory stride (assumes stride is 1).
public func sqrt<MI: UnsafeMemoryAccessible, MO: UnsafeMutableMemoryAccessible>(_ lhs: MI, into results: inout MO) where MI.Element == Double, MO.Element == Double {
return lhs.withUnsafeMemory { lhsMemory in
return lhs.withUnsafeMemory { lm in
results.withUnsafeMutableMemory { rm in
precondition(lhsMemory.stride == 1 && rm.stride == 1, "sqrt doesn't support step values other than 1")
precondition(rm.count >= lhsMemory.count, "`results` doesnt have enough capacity to store the results")
vvsqrt(rm.pointer, lhsMemory.pointer, [numericCast(lhsMemory.count)])
precondition(lm.stride == 1 && rm.stride == 1, "sqrt doesn't support step values other than 1")
precondition(rm.count >= lm.count, "`results` doesnt have enough capacity to store the results")
vvsqrt(rm.pointer, lm.pointer, [numericCast(lm.count)])
}
}
}
@ -803,27 +799,23 @@ func sqrtInPlace<C: UnsafeMutableMemoryAccessible>(_ lhs: inout C) where C.Eleme
// MARK: - Dot Product
public func dot<L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: L, _ rhs: R) -> Float where L.Element == Float, R.Element == Float {
return withUnsafeMemory(lhs, rhs) { lhsMemory, rhsMemory in
precondition(lhsMemory.count == rhsMemory.count, "Vectors must have equal count")
return withUnsafeMemory(lhs, rhs) { lm, rm in
precondition(lm.count == rm.count, "Vectors must have equal count")
var result: Float = 0.0
withUnsafeMutablePointer(to: &result) { pointer in
vDSP_dotpr(lhsMemory.pointer, numericCast(lhsMemory.stride), rhsMemory.pointer, numericCast(rhsMemory.stride), pointer, numericCast(lhsMemory.count))
vDSP_dotpr(lm.pointer, numericCast(lm.stride), rm.pointer, numericCast(rm.stride), pointer, numericCast(lm.count))
}
return result
}
}
public func dot<L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: L, _ rhs: R) -> Double where L.Element == Double, R.Element == Double {
return withUnsafeMemory(lhs, rhs) { lhsMemory, rhsMemory in
precondition(lhsMemory.count == rhsMemory.count, "Vectors must have equal count")
return withUnsafeMemory(lhs, rhs) { lm, rm in
precondition(lm.count == rm.count, "Vectors must have equal count")
var result: Double = 0.0
withUnsafeMutablePointer(to: &result) { pointer in
vDSP_dotprD(lhsMemory.pointer, numericCast(lhsMemory.stride), rhsMemory.pointer, numericCast(rhsMemory.stride), pointer, numericCast(lhsMemory.count))
vDSP_dotprD(lm.pointer, numericCast(lm.stride), rm.pointer, numericCast(rm.stride), pointer, numericCast(lm.count))
}
return result
}
}