Removed need for `let elementCount: Int32 = numericCast(lhs.count)` in favor of `numericCast(lm.count)`
This commit is contained in:
parent
e557d48b45
commit
8f0925a721
|
@ -198,20 +198,18 @@ func muladd<L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: L, _ rh
|
||||||
|
|
||||||
func muladdInPlace<L: UnsafeMutableMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: inout L, _ rhs: R, _ alpha: Float) where L.Element == Float, R.Element == Float {
|
func muladdInPlace<L: UnsafeMutableMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: inout L, _ rhs: R, _ alpha: Float) where L.Element == Float, R.Element == Float {
|
||||||
precondition(lhs.count == rhs.count, "Collections must have the same size")
|
precondition(lhs.count == rhs.count, "Collections must have the same size")
|
||||||
let elementCount: Int32 = numericCast(lhs.count)
|
lhs.withUnsafeMutableMemory { lm in
|
||||||
lhs.withUnsafeMutableMemory { lhsMemory in
|
rhs.withUnsafeMemory { rm in
|
||||||
rhs.withUnsafeMemory { rhsMemory in
|
cblas_saxpy(numericCast(lm.count), alpha, rm.pointer, numericCast(rm.stride), lm.pointer, numericCast(lm.stride))
|
||||||
cblas_saxpy(elementCount, alpha, rhsMemory.pointer, numericCast(rhsMemory.stride), lhsMemory.pointer, numericCast(lhsMemory.stride))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func muladdInPlace<L: UnsafeMutableMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: inout L, _ rhs: R, _ alpha: Double) where L.Element == Double, R.Element == Double {
|
func muladdInPlace<L: UnsafeMutableMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: inout L, _ rhs: R, _ alpha: Double) where L.Element == Double, R.Element == Double {
|
||||||
precondition(lhs.count == rhs.count, "Collections must have the same size")
|
precondition(lhs.count == rhs.count, "Collections must have the same size")
|
||||||
let elementCount: Int32 = numericCast(lhs.count)
|
lhs.withUnsafeMutableMemory { lm in
|
||||||
lhs.withUnsafeMutableMemory { lhsMemory in
|
rhs.withUnsafeMemory { rm in
|
||||||
rhs.withUnsafeMemory { rhsMemory in
|
cblas_daxpy(numericCast(lm.count), alpha, rm.pointer, numericCast(rm.stride), lm.pointer, numericCast(lm.stride))
|
||||||
cblas_daxpy(elementCount, alpha, rhsMemory.pointer, numericCast(rhsMemory.stride), lhsMemory.pointer, numericCast(lhsMemory.stride))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -717,16 +715,14 @@ public func sq<L: UnsafeMemoryAccessible>(_ lhs: L) -> [Double] where L.Element
|
||||||
// MARK: - Square: In Place
|
// MARK: - Square: In Place
|
||||||
|
|
||||||
func sqInPlace<L: UnsafeMutableMemoryAccessible>(_ lhs: inout L) where L.Element == Float {
|
func sqInPlace<L: UnsafeMutableMemoryAccessible>(_ lhs: inout L) where L.Element == Float {
|
||||||
let elementCount: vDSP_Length = numericCast(lhs.count)
|
|
||||||
withUnsafeMutableMemory(&lhs) { lm in
|
withUnsafeMutableMemory(&lhs) { lm in
|
||||||
vDSP_vsq(lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), elementCount)
|
vDSP_vsq(lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), numericCast(lm.count))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public func sqInPlace<L: UnsafeMutableMemoryAccessible>(_ lhs: inout L) where L.Element == Double {
|
public func sqInPlace<L: UnsafeMutableMemoryAccessible>(_ lhs: inout L) where L.Element == Double {
|
||||||
let elementCount: vDSP_Length = numericCast(lhs.count)
|
|
||||||
withUnsafeMutableMemory(&lhs) { lm in
|
withUnsafeMutableMemory(&lhs) { lm in
|
||||||
vDSP_vsqD(lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), elementCount)
|
vDSP_vsqD(lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), numericCast(lm.count))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -745,11 +741,11 @@ public func sqrt<C: UnsafeMemoryAccessible>(_ lhs: C) -> [Float] where C.Element
|
||||||
///
|
///
|
||||||
/// - Warning: does not support memory stride (assumes stride is 1).
|
/// - Warning: does not support memory stride (assumes stride is 1).
|
||||||
public func sqrt<MI: UnsafeMemoryAccessible, MO: UnsafeMutableMemoryAccessible>(_ lhs: MI, into results: inout MO) where MI.Element == Float, MO.Element == Float {
|
public func sqrt<MI: UnsafeMemoryAccessible, MO: UnsafeMutableMemoryAccessible>(_ lhs: MI, into results: inout MO) where MI.Element == Float, MO.Element == Float {
|
||||||
return lhs.withUnsafeMemory { lhsMemory in
|
return lhs.withUnsafeMemory { lm in
|
||||||
results.withUnsafeMutableMemory { rm in
|
results.withUnsafeMutableMemory { rm in
|
||||||
precondition(lhsMemory.stride == 1 && rm.stride == 1, "sqrt doesn't support step values other than 1")
|
precondition(lm.stride == 1 && rm.stride == 1, "sqrt doesn't support step values other than 1")
|
||||||
precondition(rm.count >= lhsMemory.count, "`results` doesnt have enough capacity to store the results")
|
precondition(rm.count >= lm.count, "`results` doesnt have enough capacity to store the results")
|
||||||
vvsqrtf(rm.pointer, lhsMemory.pointer, [numericCast(lhsMemory.count)])
|
vvsqrtf(rm.pointer, lm.pointer, [numericCast(lm.count)])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -767,11 +763,11 @@ public func sqrt<C: UnsafeMemoryAccessible>(_ lhs: C) -> [Double] where C.Elemen
|
||||||
///
|
///
|
||||||
/// - Warning: does not support memory stride (assumes stride is 1).
|
/// - Warning: does not support memory stride (assumes stride is 1).
|
||||||
public func sqrt<MI: UnsafeMemoryAccessible, MO: UnsafeMutableMemoryAccessible>(_ lhs: MI, into results: inout MO) where MI.Element == Double, MO.Element == Double {
|
public func sqrt<MI: UnsafeMemoryAccessible, MO: UnsafeMutableMemoryAccessible>(_ lhs: MI, into results: inout MO) where MI.Element == Double, MO.Element == Double {
|
||||||
return lhs.withUnsafeMemory { lhsMemory in
|
return lhs.withUnsafeMemory { lm in
|
||||||
results.withUnsafeMutableMemory { rm in
|
results.withUnsafeMutableMemory { rm in
|
||||||
precondition(lhsMemory.stride == 1 && rm.stride == 1, "sqrt doesn't support step values other than 1")
|
precondition(lm.stride == 1 && rm.stride == 1, "sqrt doesn't support step values other than 1")
|
||||||
precondition(rm.count >= lhsMemory.count, "`results` doesnt have enough capacity to store the results")
|
precondition(rm.count >= lm.count, "`results` doesnt have enough capacity to store the results")
|
||||||
vvsqrt(rm.pointer, lhsMemory.pointer, [numericCast(lhsMemory.count)])
|
vvsqrt(rm.pointer, lm.pointer, [numericCast(lm.count)])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -803,27 +799,23 @@ func sqrtInPlace<C: UnsafeMutableMemoryAccessible>(_ lhs: inout C) where C.Eleme
|
||||||
// MARK: - Dot Product
|
// MARK: - Dot Product
|
||||||
|
|
||||||
public func dot<L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: L, _ rhs: R) -> Float where L.Element == Float, R.Element == Float {
|
public func dot<L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: L, _ rhs: R) -> Float where L.Element == Float, R.Element == Float {
|
||||||
return withUnsafeMemory(lhs, rhs) { lhsMemory, rhsMemory in
|
return withUnsafeMemory(lhs, rhs) { lm, rm in
|
||||||
precondition(lhsMemory.count == rhsMemory.count, "Vectors must have equal count")
|
precondition(lm.count == rm.count, "Vectors must have equal count")
|
||||||
|
|
||||||
var result: Float = 0.0
|
var result: Float = 0.0
|
||||||
withUnsafeMutablePointer(to: &result) { pointer in
|
withUnsafeMutablePointer(to: &result) { pointer in
|
||||||
vDSP_dotpr(lhsMemory.pointer, numericCast(lhsMemory.stride), rhsMemory.pointer, numericCast(rhsMemory.stride), pointer, numericCast(lhsMemory.count))
|
vDSP_dotpr(lm.pointer, numericCast(lm.stride), rm.pointer, numericCast(rm.stride), pointer, numericCast(lm.count))
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public func dot<L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: L, _ rhs: R) -> Double where L.Element == Double, R.Element == Double {
|
public func dot<L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: L, _ rhs: R) -> Double where L.Element == Double, R.Element == Double {
|
||||||
return withUnsafeMemory(lhs, rhs) { lhsMemory, rhsMemory in
|
return withUnsafeMemory(lhs, rhs) { lm, rm in
|
||||||
precondition(lhsMemory.count == rhsMemory.count, "Vectors must have equal count")
|
precondition(lm.count == rm.count, "Vectors must have equal count")
|
||||||
|
|
||||||
var result: Double = 0.0
|
var result: Double = 0.0
|
||||||
withUnsafeMutablePointer(to: &result) { pointer in
|
withUnsafeMutablePointer(to: &result) { pointer in
|
||||||
vDSP_dotprD(lhsMemory.pointer, numericCast(lhsMemory.stride), rhsMemory.pointer, numericCast(rhsMemory.stride), pointer, numericCast(lhsMemory.count))
|
vDSP_dotprD(lm.pointer, numericCast(lm.stride), rm.pointer, numericCast(rm.stride), pointer, numericCast(lm.count))
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue