Removed need for `let elementCount: Int32 = numericCast(lhs.count)` in favor of `numericCast(lm.count)`
This commit is contained in:
parent
e557d48b45
commit
8f0925a721
|
@ -198,20 +198,18 @@ func muladd<L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: L, _ rh
|
|||
|
||||
func muladdInPlace<L: UnsafeMutableMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: inout L, _ rhs: R, _ alpha: Float) where L.Element == Float, R.Element == Float {
|
||||
precondition(lhs.count == rhs.count, "Collections must have the same size")
|
||||
let elementCount: Int32 = numericCast(lhs.count)
|
||||
lhs.withUnsafeMutableMemory { lhsMemory in
|
||||
rhs.withUnsafeMemory { rhsMemory in
|
||||
cblas_saxpy(elementCount, alpha, rhsMemory.pointer, numericCast(rhsMemory.stride), lhsMemory.pointer, numericCast(lhsMemory.stride))
|
||||
lhs.withUnsafeMutableMemory { lm in
|
||||
rhs.withUnsafeMemory { rm in
|
||||
cblas_saxpy(numericCast(lm.count), alpha, rm.pointer, numericCast(rm.stride), lm.pointer, numericCast(lm.stride))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func muladdInPlace<L: UnsafeMutableMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: inout L, _ rhs: R, _ alpha: Double) where L.Element == Double, R.Element == Double {
|
||||
precondition(lhs.count == rhs.count, "Collections must have the same size")
|
||||
let elementCount: Int32 = numericCast(lhs.count)
|
||||
lhs.withUnsafeMutableMemory { lhsMemory in
|
||||
rhs.withUnsafeMemory { rhsMemory in
|
||||
cblas_daxpy(elementCount, alpha, rhsMemory.pointer, numericCast(rhsMemory.stride), lhsMemory.pointer, numericCast(lhsMemory.stride))
|
||||
lhs.withUnsafeMutableMemory { lm in
|
||||
rhs.withUnsafeMemory { rm in
|
||||
cblas_daxpy(numericCast(lm.count), alpha, rm.pointer, numericCast(rm.stride), lm.pointer, numericCast(lm.stride))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -717,16 +715,14 @@ public func sq<L: UnsafeMemoryAccessible>(_ lhs: L) -> [Double] where L.Element
|
|||
// MARK: - Square: In Place
|
||||
|
||||
func sqInPlace<L: UnsafeMutableMemoryAccessible>(_ lhs: inout L) where L.Element == Float {
|
||||
let elementCount: vDSP_Length = numericCast(lhs.count)
|
||||
withUnsafeMutableMemory(&lhs) { lm in
|
||||
vDSP_vsq(lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), elementCount)
|
||||
vDSP_vsq(lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), numericCast(lm.count))
|
||||
}
|
||||
}
|
||||
|
||||
public func sqInPlace<L: UnsafeMutableMemoryAccessible>(_ lhs: inout L) where L.Element == Double {
|
||||
let elementCount: vDSP_Length = numericCast(lhs.count)
|
||||
withUnsafeMutableMemory(&lhs) { lm in
|
||||
vDSP_vsqD(lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), elementCount)
|
||||
vDSP_vsqD(lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), numericCast(lm.count))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -745,11 +741,11 @@ public func sqrt<C: UnsafeMemoryAccessible>(_ lhs: C) -> [Float] where C.Element
|
|||
///
|
||||
/// - Warning: does not support memory stride (assumes stride is 1).
|
||||
public func sqrt<MI: UnsafeMemoryAccessible, MO: UnsafeMutableMemoryAccessible>(_ lhs: MI, into results: inout MO) where MI.Element == Float, MO.Element == Float {
|
||||
return lhs.withUnsafeMemory { lhsMemory in
|
||||
return lhs.withUnsafeMemory { lm in
|
||||
results.withUnsafeMutableMemory { rm in
|
||||
precondition(lhsMemory.stride == 1 && rm.stride == 1, "sqrt doesn't support step values other than 1")
|
||||
precondition(rm.count >= lhsMemory.count, "`results` doesnt have enough capacity to store the results")
|
||||
vvsqrtf(rm.pointer, lhsMemory.pointer, [numericCast(lhsMemory.count)])
|
||||
precondition(lm.stride == 1 && rm.stride == 1, "sqrt doesn't support step values other than 1")
|
||||
precondition(rm.count >= lm.count, "`results` doesnt have enough capacity to store the results")
|
||||
vvsqrtf(rm.pointer, lm.pointer, [numericCast(lm.count)])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -767,11 +763,11 @@ public func sqrt<C: UnsafeMemoryAccessible>(_ lhs: C) -> [Double] where C.Elemen
|
|||
///
|
||||
/// - Warning: does not support memory stride (assumes stride is 1).
|
||||
public func sqrt<MI: UnsafeMemoryAccessible, MO: UnsafeMutableMemoryAccessible>(_ lhs: MI, into results: inout MO) where MI.Element == Double, MO.Element == Double {
|
||||
return lhs.withUnsafeMemory { lhsMemory in
|
||||
return lhs.withUnsafeMemory { lm in
|
||||
results.withUnsafeMutableMemory { rm in
|
||||
precondition(lhsMemory.stride == 1 && rm.stride == 1, "sqrt doesn't support step values other than 1")
|
||||
precondition(rm.count >= lhsMemory.count, "`results` doesnt have enough capacity to store the results")
|
||||
vvsqrt(rm.pointer, lhsMemory.pointer, [numericCast(lhsMemory.count)])
|
||||
precondition(lm.stride == 1 && rm.stride == 1, "sqrt doesn't support step values other than 1")
|
||||
precondition(rm.count >= lm.count, "`results` doesnt have enough capacity to store the results")
|
||||
vvsqrt(rm.pointer, lm.pointer, [numericCast(lm.count)])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -803,27 +799,23 @@ func sqrtInPlace<C: UnsafeMutableMemoryAccessible>(_ lhs: inout C) where C.Eleme
|
|||
// MARK: - Dot Product
|
||||
|
||||
public func dot<L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: L, _ rhs: R) -> Float where L.Element == Float, R.Element == Float {
|
||||
return withUnsafeMemory(lhs, rhs) { lhsMemory, rhsMemory in
|
||||
precondition(lhsMemory.count == rhsMemory.count, "Vectors must have equal count")
|
||||
|
||||
return withUnsafeMemory(lhs, rhs) { lm, rm in
|
||||
precondition(lm.count == rm.count, "Vectors must have equal count")
|
||||
var result: Float = 0.0
|
||||
withUnsafeMutablePointer(to: &result) { pointer in
|
||||
vDSP_dotpr(lhsMemory.pointer, numericCast(lhsMemory.stride), rhsMemory.pointer, numericCast(rhsMemory.stride), pointer, numericCast(lhsMemory.count))
|
||||
vDSP_dotpr(lm.pointer, numericCast(lm.stride), rm.pointer, numericCast(rm.stride), pointer, numericCast(lm.count))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
public func dot<L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: L, _ rhs: R) -> Double where L.Element == Double, R.Element == Double {
|
||||
return withUnsafeMemory(lhs, rhs) { lhsMemory, rhsMemory in
|
||||
precondition(lhsMemory.count == rhsMemory.count, "Vectors must have equal count")
|
||||
|
||||
return withUnsafeMemory(lhs, rhs) { lm, rm in
|
||||
precondition(lm.count == rm.count, "Vectors must have equal count")
|
||||
var result: Double = 0.0
|
||||
withUnsafeMutablePointer(to: &result) { pointer in
|
||||
vDSP_dotprD(lhsMemory.pointer, numericCast(lhsMemory.stride), rhsMemory.pointer, numericCast(rhsMemory.stride), pointer, numericCast(lhsMemory.count))
|
||||
vDSP_dotprD(lm.pointer, numericCast(lm.stride), rm.pointer, numericCast(rm.stride), pointer, numericCast(lm.count))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue