Turned `add…`/`sub…` into thin wrappers of `muladd…` (in ‘Arithmetic.swift’)

This commit is contained in:
Vincent Esche 2019-09-24 14:09:04 +02:00
parent 99efd386f7
commit 88e8cb3ef1
1 changed files with 14 additions and 62 deletions

View File

@ -41,15 +41,17 @@ public func + <L: UnsafeMemoryAccessible>(lhs: L, rhs: Double) -> [Double] where
// MARK: - Addition: In Place
func addInPlace<L: UnsafeMutableMemoryAccessible>(_ lhs: inout L, _ rhs: Float) where L.Element == Float {
lhs.withUnsafeMutableMemory { lm in
var scalar = rhs
lhs.withUnsafeMutableMemory { lm in
vDSP_vsadd(lm.pointer, numericCast(lm.stride), &scalar, lm.pointer, numericCast(lm.stride), numericCast(lm.count))
}
}
func addInPlace<L: UnsafeMutableMemoryAccessible>(_ lhs: inout L, _ rhs: Double) where L.Element == Double {
lhs.withUnsafeMutableMemory { lm in
var scalar = rhs
lhs.withUnsafeMutableMemory { lm in
vDSP_vsaddD(lm.pointer, numericCast(lm.stride), &scalar, lm.pointer, numericCast(lm.stride), numericCast(lm.count))
}
}
@ -65,25 +67,11 @@ public func +=<L: UnsafeMutableMemoryAccessible>(lhs: inout L, rhs: Double) wher
// MARK: - Element-wise Addition
public func add<L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: L, _ rhs: R) -> [Float] where L.Element == Float, R.Element == Float {
precondition(lhs.count == rhs.count, "Collections must have the same size")
var results = [Float](rhs)
lhs.withUnsafeMemory { lhsMemory in
results.withUnsafeMutableBufferPointer { bufferPointer in
cblas_saxpy(numericCast(lhsMemory.count), 1.0, lhsMemory.pointer, numericCast(lhsMemory.stride), bufferPointer.baseAddress, 1)
}
}
return results
muladd(lhs, rhs, 1.0)
}
public func add<L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: L, _ rhs: R) -> [Double] where L.Element == Double, R.Element == Double {
precondition(lhs.count == rhs.count, "Collections must have the same size")
var results = [Double](rhs)
lhs.withUnsafeMemory { lhsMemory in
results.withUnsafeMutableBufferPointer { bufferPointer in
cblas_daxpy(numericCast(lhsMemory.count), 1.0, lhsMemory.pointer, numericCast(lhsMemory.stride), bufferPointer.baseAddress, 1)
}
}
return results
muladd(lhs, rhs, 1.0)
}
public func .+ <L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(lhs: L, rhs: R) -> [Float] where L.Element == Float, R.Element == Float {
@ -97,19 +85,11 @@ public func .+ <L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(lhs: L, rh
// MARK: - Element-wise Addition: In Place
func addInPlace<L: UnsafeMutableMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: inout L, _ rhs: R) where L.Element == Float, R.Element == Float {
lhs.withUnsafeMutableMemory { lm in
rhs.withUnsafeMemory { rm in
vDSP_vadd(lm.pointer, numericCast(lm.stride), rm.pointer, numericCast(rm.stride), lm.pointer, numericCast(lm.stride), numericCast(lm.count))
}
}
muladdInPlace(&lhs, rhs, 1.0)
}
func addInPlace<L: UnsafeMutableMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: inout L, _ rhs: R) where L.Element == Double, R.Element == Double {
lhs.withUnsafeMutableMemory { lm in
rhs.withUnsafeMemory { rm in
vDSP_vaddD(lm.pointer, numericCast(lm.stride), rm.pointer, numericCast(rm.stride), lm.pointer, numericCast(lm.stride), numericCast(lm.count))
}
}
muladdInPlace(&lhs, rhs, 1.0)
}
public func .+= <L: UnsafeMutableMemoryAccessible, R: UnsafeMemoryAccessible>(lhs: inout L, rhs: R) where L.Element == Float, R.Element == Float {
@ -141,17 +121,11 @@ public func - <L: UnsafeMemoryAccessible>(lhs: L, rhs: Double) -> [Double] where
// MARK: - Subtraction: In Place
func subInPlace<L: UnsafeMutableMemoryAccessible>(_ lhs: inout L, _ rhs: Float) where L.Element == Float {
lhs.withUnsafeMutableMemory { lm in
var scalar = -rhs
vDSP_vsadd(lm.pointer, numericCast(lm.stride), &scalar, lm.pointer, numericCast(lm.stride), numericCast(lm.count))
}
addInPlace(&lhs, -rhs)
}
func subInPlace<L: UnsafeMutableMemoryAccessible>(_ lhs: inout L, _ rhs: Double) where L.Element == Double {
lhs.withUnsafeMutableMemory { lm in
var scalar = -rhs
vDSP_vsaddD(lm.pointer, numericCast(lm.stride), &scalar, lm.pointer, numericCast(lm.stride), numericCast(lm.count))
}
addInPlace(&lhs, -rhs)
}
public func -=<L: UnsafeMutableMemoryAccessible>(lhs: inout L, rhs: Float) where L.Element == Float {
@ -165,25 +139,11 @@ public func -=<L: UnsafeMutableMemoryAccessible>(lhs: inout L, rhs: Double) wher
// MARK: - Element-wise Subtraction
public func sub<L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: L, _ rhs: R) -> [Float] where L.Element == Float, R.Element == Float {
precondition(lhs.count == rhs.count, "Collections must have the same size")
var results = [Float](rhs)
lhs.withUnsafeMemory { lhsMemory in
results.withUnsafeMutableBufferPointer { bufferPointer in
catlas_saxpby(numericCast(lhsMemory.count), 1.0, lhsMemory.pointer, numericCast(lhsMemory.stride), -1, bufferPointer.baseAddress, 1)
}
}
return results
return muladd(lhs, rhs, -1.0)
}
public func sub<L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: L, _ rhs: R) -> [Double] where L.Element == Double, R.Element == Double {
precondition(lhs.count == rhs.count, "Collections must have the same size")
var results = [Double](rhs)
lhs.withUnsafeMemory { lhsMemory in
results.withUnsafeMutableBufferPointer { bufferPointer in
catlas_daxpby(numericCast(lhsMemory.count), 1.0, lhsMemory.pointer, numericCast(lhsMemory.stride), -1, bufferPointer.baseAddress, 1)
}
}
return results
return muladd(lhs, rhs, -1.0)
}
public func .- <L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(lhs: L, rhs: R) -> [Float] where L.Element == Float, R.Element == Float {
@ -197,19 +157,11 @@ public func .- <L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(lhs: L, rh
// MARK: - Element-wise Subtraction: In Place
func subInPlace<L: UnsafeMutableMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: inout L, _ rhs: R) where L.Element == Float, R.Element == Float {
lhs.withUnsafeMutableMemory { lm in
rhs.withUnsafeMemory { rm in
vDSP_vsub(rm.pointer, numericCast(rm.stride), lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), numericCast(lm.count))
}
}
muladdInPlace(&lhs, rhs, -1.0)
}
func subInPlace<L: UnsafeMutableMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: inout L, _ rhs: R) where L.Element == Double, R.Element == Double {
lhs.withUnsafeMutableMemory { lm in
rhs.withUnsafeMemory { rm in
vDSP_vsubD(rm.pointer, numericCast(rm.stride), lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), numericCast(lm.count))
}
}
muladdInPlace(&lhs, rhs, -1.0)
}
public func .-= <L: UnsafeMutableMemoryAccessible, R: UnsafeMemoryAccessible>(lhs: inout L, rhs: R) where L.Element == Float, R.Element == Float {