From 1c274b3ee78efc413a439de5e9a65ad9b658571c Mon Sep 17 00:00:00 2001 From: Vincent Esche Date: Mon, 23 Sep 2019 13:45:42 +0200 Subject: [PATCH] Added `Matrix += Matrix` and turned `Matrix + Matrix` into a shallow wrapper around it --- Sources/Surge/Linear Algebra/Matrix.swift | 46 ++++++++++++++++++++--- 1 file changed, 41 insertions(+), 5 deletions(-) diff --git a/Sources/Surge/Linear Algebra/Matrix.swift b/Sources/Surge/Linear Algebra/Matrix.swift index aa3bc61..ce7227e 100644 --- a/Sources/Surge/Linear Algebra/Matrix.swift +++ b/Sources/Surge/Linear Algebra/Matrix.swift @@ -269,12 +269,10 @@ public func == (lhs: Matrix, rhs: Matrix) -> Bool { public func add(_ lhs: Matrix, _ rhs: Matrix) -> Matrix { precondition(lhs.rows == rhs.rows && lhs.columns == rhs.columns, "Matrix dimensions not compatible with addition") - var results = rhs - results.grid.withUnsafeMutableBufferPointer { pointer in - cblas_saxpy(Int32(lhs.grid.count), 1.0, lhs.grid, 1, pointer.baseAddress!, 1) - } + var result = lhs + result += rhs - return results + return result } public func add(_ lhs: Matrix, _ rhs: Matrix) -> Matrix { @@ -296,6 +294,44 @@ public func + (lhs: Matrix, rhs: Matrix) -> Matrix { return add(lhs, rhs) } +// MARK: - Addition: In Place + +func addInPlace(_ lhs: inout Matrix, _ rhs: Matrix) { + precondition(lhs.rows == rhs.rows && lhs.columns == rhs.columns, "Matrix dimensions not compatible with addition") + + let gridSize = Int32(lhs.grid.count) + let alpha: Float = 1.0 + let stride: Int32 = 1 + + lhs.grid.withUnsafeMutableBufferPointer { lhsPointer in + rhs.grid.withUnsafeBufferPointer { rhsPointer in + cblas_saxpy(gridSize, alpha, rhsPointer.baseAddress!, stride, lhsPointer.baseAddress!, stride) + } + } +} + +func addInPlace(_ lhs: inout Matrix, _ rhs: Matrix) { + precondition(lhs.rows == rhs.rows && lhs.columns == rhs.columns, "Matrix dimensions not compatible with addition") + + let gridSize = Int32(lhs.grid.count) + let alpha: Double = 1.0 + let stride: Int32 = 1 + + lhs.grid.withUnsafeMutableBufferPointer { lhsPointer in + rhs.grid.withUnsafeBufferPointer { rhsPointer in + cblas_daxpy(gridSize, alpha, rhsPointer.baseAddress!, stride, lhsPointer.baseAddress!, stride) + } + } +} + +public func += (lhs: inout Matrix, rhs: Matrix) { + return addInPlace(&lhs, rhs) +} + +public func += (lhs: inout Matrix, rhs: Matrix) { + return addInPlace(&lhs, rhs) +} + // MARK: - Subtraction public func sub(_ lhs: Matrix, _ rhs: Matrix) -> Matrix {