Added `sq` & `sqInPlace` to ‘Arithmetic.swift’

This commit is contained in:
Vincent Esche 2019-09-24 16:20:04 +02:00
parent e4c37fb282
commit 1a198b0396
2 changed files with 62 additions and 0 deletions

View File

@ -700,6 +700,36 @@ func powInPlace<X: UnsafeMutableMemoryAccessible>(_ lhs: inout X, _ rhs: Double)
return powInPlace(&lhs, rhs)
}
//// MARK: - Square
public func sq<L: UnsafeMemoryAccessible>(_ lhs: L) -> [Float] where L.Element == Float {
var results = Array(lhs)
sqInPlace(&results)
return results
}
public func sq<L: UnsafeMemoryAccessible>(_ lhs: L) -> [Double] where L.Element == Double {
var results = Array(lhs)
sqInPlace(&results)
return results
}
// MARK: - Square: In Place
func sqInPlace<L: UnsafeMutableMemoryAccessible>(_ lhs: inout L) where L.Element == Float {
let elementCount: vDSP_Length = numericCast(lhs.count)
withUnsafeMutableMemory(&lhs) { lm in
vDSP_vsq(lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), elementCount)
}
}
public func sqInPlace<L: UnsafeMutableMemoryAccessible>(_ lhs: inout L) where L.Element == Double {
let elementCount: vDSP_Length = numericCast(lhs.count)
withUnsafeMutableMemory(&lhs) { lm in
vDSP_vsqD(lm.pointer, numericCast(lm.stride), lm.pointer, numericCast(lm.stride), elementCount)
}
}
// MARK: - Square Root
/// Elemen-wise square root.

View File

@ -529,6 +529,38 @@ class ArithmeticTests: XCTestCase {
XCTAssertEqual(actual, expected, accuracy: 1e-8)
}
// MARK: - Square
func test_sq_array_float() {
typealias Scalar = Float
let lhs: [Scalar] = (1...n).map { Scalar($0) / Scalar(n) }
var actual: [Scalar] = []
measure {
actual = Surge.sq(lhs)
}
let expected = lhs.map { pow($0, 2.0) }
XCTAssertEqual(actual, expected, accuracy: 1e-5)
}
func test_sq_array_double() {
typealias Scalar = Double
let lhs: [Scalar] = (1...n).map { Scalar($0) / Scalar(n) }
var actual: [Scalar] = []
measure {
actual = Surge.sq(lhs)
}
let expected = lhs.map { pow($0, 2.0) }
XCTAssertEqual(actual, expected, accuracy: 1e-8)
}
// MARK: - Square Root
func test_sqrt_array_array_float() {