diff --git a/Sources/NIO/SocketAddresses.swift b/Sources/NIO/SocketAddresses.swift index 74182867..42b972aa 100644 --- a/Sources/NIO/SocketAddresses.swift +++ b/Sources/NIO/SocketAddresses.swift @@ -258,3 +258,37 @@ public enum SocketAddress: CustomStringConvertible { } } +/// We define an extension on `SocketAddress` that gives it an elementwise equatable conformance, using +/// only the elements defined on the structure in their man pages (excluding lengths). +extension SocketAddress: Equatable { + public static func ==(lhs: SocketAddress, rhs: SocketAddress) -> Bool { + switch (lhs, rhs) { + case (.v4(let addr1), .v4(let addr2)): + return addr1.address.sin_family == addr2.address.sin_family && + addr1.address.sin_port == addr2.address.sin_port && + addr1.address.sin_addr.s_addr == addr2.address.sin_addr.s_addr + case (.v6(let addr1), .v6(let addr2)): + guard addr1.address.sin6_family == addr2.address.sin6_family && + addr1.address.sin6_port == addr2.address.sin6_port && + addr1.address.sin6_flowinfo == addr2.address.sin6_flowinfo && + addr1.address.sin6_scope_id == addr2.address.sin6_scope_id else { + return false + } + + var s6addr1 = addr1.address.sin6_addr + var s6addr2 = addr2.address.sin6_addr + return memcmp(&s6addr1, &s6addr2, MemoryLayout.size(ofValue: s6addr1)) == 0 + case (.unixDomainSocket(let addr1), .unixDomainSocket(let addr2)): + guard addr1.address.sun_family == addr2.address.sun_family else { + return false + } + + var sunpath1 = addr1.address.sun_path + var sunpath2 = addr2.address.sun_path + return memcmp(&sunpath1, &sunpath2, MemoryLayout.size(ofValue: sunpath1)) == 0 + default: + return false + } + } +} + diff --git a/Tests/NIOTests/SocketAddressTest+XCTest.swift b/Tests/NIOTests/SocketAddressTest+XCTest.swift index be847fe2..54923735 100644 --- a/Tests/NIOTests/SocketAddressTest+XCTest.swift +++ b/Tests/NIOTests/SocketAddressTest+XCTest.swift @@ -34,6 +34,10 @@ extension SocketAddressTest { ("testWithMutableAddressAllowsMutationWithoutPersistence", testWithMutableAddressAllowsMutationWithoutPersistence), ("testConvertingStorage", testConvertingStorage), ("testComparingSockaddrs", testComparingSockaddrs), + ("testEqualSocketAddresses", testEqualSocketAddresses), + ("testUnequalAddressesOnPort", testUnequalAddressesOnPort), + ("testUnequalOnAddress", testUnequalOnAddress), + ("testUnequalAcrossFamilies", testUnequalAcrossFamilies), ] } } diff --git a/Tests/NIOTests/SocketAddressTest.swift b/Tests/NIOTests/SocketAddressTest.swift index 3f4b13a0..a87e4a2d 100644 --- a/Tests/NIOTests/SocketAddressTest.swift +++ b/Tests/NIOTests/SocketAddressTest.swift @@ -259,4 +259,51 @@ class SocketAddressTest: XCTestCase { } } } + + func testEqualSocketAddresses() throws { + let first = try SocketAddress.ipAddress(string: "::1", port: 80) + let second = try SocketAddress.ipAddress(string: "00:00::1", port: 80) + let third = try SocketAddress.ipAddress(string: "127.0.0.1", port: 443) + let fourth = try SocketAddress.ipAddress(string: "127.0.0.1", port: 443) + let fifth = try SocketAddress.unixDomainSocketAddress(path: "/var/tmp") + let sixth = try SocketAddress.unixDomainSocketAddress(path: "/var/tmp") + + XCTAssertEqual(first, second) + XCTAssertEqual(third, fourth) + XCTAssertEqual(fifth, sixth) + } + + func testUnequalAddressesOnPort() throws { + let first = try SocketAddress.ipAddress(string: "::1", port: 80) + let second = try SocketAddress.ipAddress(string: "::1", port: 443) + let third = try SocketAddress.ipAddress(string: "127.0.0.1", port: 80) + let fourth = try SocketAddress.ipAddress(string: "127.0.0.1", port: 443) + + XCTAssertNotEqual(first, second) + XCTAssertNotEqual(third, fourth) + } + + func testUnequalOnAddress() throws { + let first = try SocketAddress.ipAddress(string: "::1", port: 80) + let second = try SocketAddress.ipAddress(string: "::2", port: 80) + let third = try SocketAddress.ipAddress(string: "127.0.0.1", port: 443) + let fourth = try SocketAddress.ipAddress(string: "127.0.0.2", port: 443) + let fifth = try SocketAddress.unixDomainSocketAddress(path: "/var/tmp") + let sixth = try SocketAddress.unixDomainSocketAddress(path: "/var/tmq") + + XCTAssertNotEqual(first, second) + XCTAssertNotEqual(third, fourth) + XCTAssertNotEqual(fifth, sixth) + } + + func testUnequalAcrossFamilies() throws { + let first = try SocketAddress.ipAddress(string: "::1", port: 80) + let second = try SocketAddress.ipAddress(string: "127.0.0.1", port: 80) + let third = try SocketAddress.unixDomainSocketAddress(path: "/var/tmp") + + XCTAssertNotEqual(first, second) + XCTAssertNotEqual(second, third) + // By the transitive property first != third, but let's protect against me being an idiot + XCTAssertNotEqual(third, first) + } }