Added `muladd`/`muladdInPlace` to ‘Arithmetic.swift’
This commit is contained in:
parent
e243b08fd0
commit
99efd386f7
|
@ -220,6 +220,42 @@ public func .-= <L: UnsafeMutableMemoryAccessible, R: UnsafeMemoryAccessible>(lh
|
|||
return subInPlace(&lhs, rhs)
|
||||
}
|
||||
|
||||
// MARK: - Multiply Addition
|
||||
|
||||
func muladd<L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: L, _ rhs: R, _ alpha: Float) -> [Float] where L.Element == Float, R.Element == Float {
|
||||
var results = Array(lhs)
|
||||
muladdInPlace(&results, rhs, alpha)
|
||||
return results
|
||||
}
|
||||
|
||||
func muladd<L: UnsafeMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: L, _ rhs: R, _ alpha: Double) -> [Double] where L.Element == Double, R.Element == Double {
|
||||
var results = Array(lhs)
|
||||
muladdInPlace(&results, rhs, alpha)
|
||||
return results
|
||||
}
|
||||
|
||||
// MARK: - Multiply Addition: In Place
|
||||
|
||||
func muladdInPlace<L: UnsafeMutableMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: inout L, _ rhs: R, _ alpha: Float) where L.Element == Float, R.Element == Float {
|
||||
precondition(lhs.count == rhs.count, "Collections must have the same size")
|
||||
let elementCount: Int32 = numericCast(lhs.count)
|
||||
lhs.withUnsafeMutableMemory { lhsMemory in
|
||||
rhs.withUnsafeMemory { rhsMemory in
|
||||
cblas_saxpy(elementCount, alpha, rhsMemory.pointer, numericCast(rhsMemory.stride), lhsMemory.pointer, numericCast(lhsMemory.stride))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func muladdInPlace<L: UnsafeMutableMemoryAccessible, R: UnsafeMemoryAccessible>(_ lhs: inout L, _ rhs: R, _ alpha: Double) where L.Element == Double, R.Element == Double {
|
||||
precondition(lhs.count == rhs.count, "Collections must have the same size")
|
||||
let elementCount: Int32 = numericCast(lhs.count)
|
||||
lhs.withUnsafeMutableMemory { lhsMemory in
|
||||
rhs.withUnsafeMemory { rhsMemory in
|
||||
cblas_daxpy(elementCount, alpha, rhsMemory.pointer, numericCast(rhsMemory.stride), lhsMemory.pointer, numericCast(lhsMemory.stride))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Multiplication
|
||||
|
||||
public func mul<L: UnsafeMemoryAccessible>(_ lhs: L, _ rhs: Float) -> [Float] where L.Element == Float {
|
||||
|
|
Loading…
Reference in New Issue