Added `Matrix += Matrix` and turned `Matrix + Matrix` into a shallow wrapper around it

This commit is contained in:
Vincent Esche 2019-09-23 13:45:42 +02:00
parent d3f7a50ad9
commit 1c274b3ee7
1 changed files with 41 additions and 5 deletions

View File

@ -269,12 +269,10 @@ public func ==<T> (lhs: Matrix<T>, rhs: Matrix<T>) -> Bool {
public func add(_ lhs: Matrix<Float>, _ rhs: Matrix<Float>) -> Matrix<Float> {
precondition(lhs.rows == rhs.rows && lhs.columns == rhs.columns, "Matrix dimensions not compatible with addition")
var results = rhs
results.grid.withUnsafeMutableBufferPointer { pointer in
cblas_saxpy(Int32(lhs.grid.count), 1.0, lhs.grid, 1, pointer.baseAddress!, 1)
}
var result = lhs
result += rhs
return results
return result
}
public func add(_ lhs: Matrix<Double>, _ rhs: Matrix<Double>) -> Matrix<Double> {
@ -296,6 +294,44 @@ public func + (lhs: Matrix<Double>, rhs: Matrix<Double>) -> Matrix<Double> {
return add(lhs, rhs)
}
// MARK: - Addition: In Place
func addInPlace(_ lhs: inout Matrix<Float>, _ rhs: Matrix<Float>) {
precondition(lhs.rows == rhs.rows && lhs.columns == rhs.columns, "Matrix dimensions not compatible with addition")
let gridSize = Int32(lhs.grid.count)
let alpha: Float = 1.0
let stride: Int32 = 1
lhs.grid.withUnsafeMutableBufferPointer { lhsPointer in
rhs.grid.withUnsafeBufferPointer { rhsPointer in
cblas_saxpy(gridSize, alpha, rhsPointer.baseAddress!, stride, lhsPointer.baseAddress!, stride)
}
}
}
func addInPlace(_ lhs: inout Matrix<Double>, _ rhs: Matrix<Double>) {
precondition(lhs.rows == rhs.rows && lhs.columns == rhs.columns, "Matrix dimensions not compatible with addition")
let gridSize = Int32(lhs.grid.count)
let alpha: Double = 1.0
let stride: Int32 = 1
lhs.grid.withUnsafeMutableBufferPointer { lhsPointer in
rhs.grid.withUnsafeBufferPointer { rhsPointer in
cblas_daxpy(gridSize, alpha, rhsPointer.baseAddress!, stride, lhsPointer.baseAddress!, stride)
}
}
}
public func += (lhs: inout Matrix<Float>, rhs: Matrix<Float>) {
return addInPlace(&lhs, rhs)
}
public func += (lhs: inout Matrix<Double>, rhs: Matrix<Double>) {
return addInPlace(&lhs, rhs)
}
// MARK: - Subtraction
public func sub(_ lhs: Matrix<Float>, _ rhs: Matrix<Float>) -> Matrix<Float> {