Added `Matrix += Matrix` and turned `Matrix + Matrix` into a shallow wrapper around it
This commit is contained in:
parent
d3f7a50ad9
commit
1c274b3ee7
|
@ -269,12 +269,10 @@ public func ==<T> (lhs: Matrix<T>, rhs: Matrix<T>) -> Bool {
|
|||
public func add(_ lhs: Matrix<Float>, _ rhs: Matrix<Float>) -> Matrix<Float> {
|
||||
precondition(lhs.rows == rhs.rows && lhs.columns == rhs.columns, "Matrix dimensions not compatible with addition")
|
||||
|
||||
var results = rhs
|
||||
results.grid.withUnsafeMutableBufferPointer { pointer in
|
||||
cblas_saxpy(Int32(lhs.grid.count), 1.0, lhs.grid, 1, pointer.baseAddress!, 1)
|
||||
}
|
||||
var result = lhs
|
||||
result += rhs
|
||||
|
||||
return results
|
||||
return result
|
||||
}
|
||||
|
||||
public func add(_ lhs: Matrix<Double>, _ rhs: Matrix<Double>) -> Matrix<Double> {
|
||||
|
@ -296,6 +294,44 @@ public func + (lhs: Matrix<Double>, rhs: Matrix<Double>) -> Matrix<Double> {
|
|||
return add(lhs, rhs)
|
||||
}
|
||||
|
||||
// MARK: - Addition: In Place
|
||||
|
||||
func addInPlace(_ lhs: inout Matrix<Float>, _ rhs: Matrix<Float>) {
|
||||
precondition(lhs.rows == rhs.rows && lhs.columns == rhs.columns, "Matrix dimensions not compatible with addition")
|
||||
|
||||
let gridSize = Int32(lhs.grid.count)
|
||||
let alpha: Float = 1.0
|
||||
let stride: Int32 = 1
|
||||
|
||||
lhs.grid.withUnsafeMutableBufferPointer { lhsPointer in
|
||||
rhs.grid.withUnsafeBufferPointer { rhsPointer in
|
||||
cblas_saxpy(gridSize, alpha, rhsPointer.baseAddress!, stride, lhsPointer.baseAddress!, stride)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func addInPlace(_ lhs: inout Matrix<Double>, _ rhs: Matrix<Double>) {
|
||||
precondition(lhs.rows == rhs.rows && lhs.columns == rhs.columns, "Matrix dimensions not compatible with addition")
|
||||
|
||||
let gridSize = Int32(lhs.grid.count)
|
||||
let alpha: Double = 1.0
|
||||
let stride: Int32 = 1
|
||||
|
||||
lhs.grid.withUnsafeMutableBufferPointer { lhsPointer in
|
||||
rhs.grid.withUnsafeBufferPointer { rhsPointer in
|
||||
cblas_daxpy(gridSize, alpha, rhsPointer.baseAddress!, stride, lhsPointer.baseAddress!, stride)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public func += (lhs: inout Matrix<Float>, rhs: Matrix<Float>) {
|
||||
return addInPlace(&lhs, rhs)
|
||||
}
|
||||
|
||||
public func += (lhs: inout Matrix<Double>, rhs: Matrix<Double>) {
|
||||
return addInPlace(&lhs, rhs)
|
||||
}
|
||||
|
||||
// MARK: - Subtraction
|
||||
|
||||
public func sub(_ lhs: Matrix<Float>, _ rhs: Matrix<Float>) -> Matrix<Float> {
|
||||
|
|
Loading…
Reference in New Issue