From 20e63d07dc1cc689aa671b3561b5954c974927ee Mon Sep 17 00:00:00 2001 From: Saleem Abdulrasool Date: Mon, 19 Oct 2020 08:18:57 -0700 Subject: [PATCH] NIO: extract control message handling into a separate protocol (#1678) These are part of the BSD sockets APIs, but not directly related to the socket interfaces. Create an extension point to permit the different platforms to shuffle their implementation into place. This provides a nicer spelling for the functions and enables the codepaths on Windows as well. --- Sources/CNIOWindows/include/CNIOWindows.h | 11 +++++ Sources/CNIOWindows/shim.c | 24 ++++++++++ Sources/NIO/BSDSocketAPI.swift | 45 ++++++++++++++++-- Sources/NIO/BSDSocketAPIPosix.swift | 56 +++++++++++++++++++++++ Sources/NIO/BSDSocketAPIWindows.swift | 38 +++++++++++++++ Sources/NIO/ControlMessage.swift | 14 +++--- Sources/NIO/System.swift | 52 --------------------- Tests/NIOTests/SystemTest.swift | 12 ++--- 8 files changed, 183 insertions(+), 69 deletions(-) diff --git a/Sources/CNIOWindows/include/CNIOWindows.h b/Sources/CNIOWindows/include/CNIOWindows.h index 5a5cc171..3978b6de 100644 --- a/Sources/CNIOWindows/include/CNIOWindows.h +++ b/Sources/CNIOWindows/include/CNIOWindows.h @@ -19,6 +19,7 @@ #include #include +#include #define NIO(name) CNIOWindows_ ## name @@ -95,6 +96,16 @@ int NIO(sendmmsg)(SOCKET s, NIO(mmsghdr) *msgvec, unsigned int vlen, int flags); int NIO(recvmmsg)(SOCKET s, NIO(mmsghdr) *msgvec, unsigned int vlen, int flags, struct timespec *timeout); + +const void *NIO(CMSG_DATA)(const WSACMSGHDR *); +void *NIO(CMSG_DATA_MUTABLE)(LPWSACMSGHDR); + +WSACMSGHDR *NIO(CMSG_FIRSTHDR)(const WSAMSG *); +WSACMSGHDR *NIO(CMSG_NXTHDR)(const WSAMSG *, LPWSACMSGHDR); + +size_t NIO(CMSG_LEN)(size_t); +size_t NIO(CMSG_SPACE)(size_t); + #undef NIO #endif diff --git a/Sources/CNIOWindows/shim.c b/Sources/CNIOWindows/shim.c index fbb91686..b3c85240 100644 --- a/Sources/CNIOWindows/shim.c +++ b/Sources/CNIOWindows/shim.c @@ -31,4 +31,28 @@ int CNIOWindows_recvmmsg(SOCKET s, CNIOWindows_mmsghdr *msgvec, abort(); } +const void *CNIOWindows_CMSG_DATA(const WSACMSGHDR *pcmsg) { + return WSA_CMSG_DATA(pcmsg); +} + +void *CNIOWindows_CMSG_DATA_MUTABLE(LPWSACMSGHDR pcmsg) { + return WSA_CMSG_DATA(pcmsg); +} + +WSACMSGHDR *CNIOWindows_CMSG_FIRSTHDR(const WSAMSG *msg) { + return WSA_CMSG_FIRSTHDR(msg); +} + +WSACMSGHDR *CNIOWindows_CMSG_NXTHDR(const WSAMSG *msg, LPWSACMSGHDR cmsg) { + return WSA_CMSG_NXTHDR(msg, cmsg); +} + +size_t CNIOWindows_CMSG_LEN(size_t length) { + return WSA_CMSG_LEN(length); +} + +size_t CNIOWindows_CMSG_SPACE(size_t length) { + return WSA_CMSG_SPACE(length); +} + #endif diff --git a/Sources/NIO/BSDSocketAPI.swift b/Sources/NIO/BSDSocketAPI.swift index b16353f6..da2f0630 100644 --- a/Sources/NIO/BSDSocketAPI.swift +++ b/Sources/NIO/BSDSocketAPI.swift @@ -59,7 +59,13 @@ import let WinSDK.TCP_NODELAY import struct WinSDK.SOCKET +import struct WinSDK.WSACMSGHDR +import struct WinSDK.WSAMSG + import struct WinSDK.socklen_t + +internal typealias msghdr = WSAMSG +internal typealias cmsghdr = WSACMSGHDR #endif protocol _SocketShutdownProtocol { @@ -402,9 +408,11 @@ extension NIOBSDSocket.Option { /// The requested UDS path exists and has wrong type (not a socket). public struct UnixDomainSocketPathWrongType: Error {} -/// This protocol defines the methods that are expected to be found on `NIOBSDSocket`. While defined as a protocol -/// there is no expectation that any object other than `NIOBSDSocket` will implement this protocol: instead, this protocol -/// acts as a reference for what new supported operating systems must implement. +/// This protocol defines the methods that are expected to be found on +/// `NIOBSDSocket`. While defined as a protocol there is no expectation that any +/// object other than `NIOBSDSocket` will implement this protocol: instead, this +/// protocol acts as a reference for what new supported operating systems must +/// implement. protocol _BSDSocketProtocol { static func accept(socket s: NIOBSDSocket.Handle, address addr: UnsafeMutablePointer?, @@ -526,5 +534,34 @@ protocol _BSDSocketProtocol { static func cleanupUnixDomainSocket(atPath path: String) throws } -/// If this extension is hitting a compile error, your platform is missing one of the functions defined above! +/// If this extension is hitting a compile error, your platform is missing one +/// of the functions defined above! extension NIOBSDSocket: _BSDSocketProtocol { } + +/// This protocol defines the methods that are expected to be found on +/// `NIOBSDControlMessage`. While defined as a protocol there is no expectation +/// that any object other than `NIOBSDControlMessage` will implement this +/// protocol: instead, this protocol acts as a reference for what new supported +/// operating systems must implement. +protocol _BSDSocketControlMessageProtocol { + static func firstHeader(inside msghdr: UnsafePointer) + -> UnsafeMutablePointer? + + static func nextHeader(inside msghdr: UnsafeMutablePointer, + after: UnsafeMutablePointer) + -> UnsafeMutablePointer? + + static func data(for header: UnsafePointer) + -> UnsafeRawBufferPointer? + + static func data(for header: UnsafeMutablePointer) + -> UnsafeMutableRawBufferPointer? + + static func length(payloadSize: size_t) -> size_t + + static func space(payloadSize: size_t) -> size_t +} + +/// If this extension is hitting a compile error, your platform is missing one +/// of the functions defined above! +enum NIOBSDSocketControlMessage: _BSDSocketControlMessageProtocol { } diff --git a/Sources/NIO/BSDSocketAPIPosix.swift b/Sources/NIO/BSDSocketAPIPosix.swift index 5b44f3a0..25faee46 100644 --- a/Sources/NIO/BSDSocketAPIPosix.swift +++ b/Sources/NIO/BSDSocketAPIPosix.swift @@ -226,4 +226,60 @@ extension NIOBSDSocket { } } +#if os(iOS) || os(macOS) || os(tvOS) || os(watchOS) +import CNIODarwin +private let CMSG_FIRSTHDR = CNIODarwin_CMSG_FIRSTHDR +private let CMSG_NXTHDR = CNIODarwin_CMSG_NXTHDR +private let CMSG_DATA = CNIODarwin_CMSG_DATA +private let CMSG_DATA_MUTABLE = CNIODarwin_CMSG_DATA_MUTABLE +private let CMSG_SPACE = CNIODarwin_CMSG_SPACE +private let CMSG_LEN = CNIODarwin_CMSG_LEN +#else +import CNIOLinux +private let CMSG_FIRSTHDR = CNIOLinux_CMSG_FIRSTHDR +private let CMSG_NXTHDR = CNIOLinux_CMSG_NXTHDR +private let CMSG_DATA = CNIOLinux_CMSG_DATA +private let CMSG_DATA_MUTABLE = CNIOLinux_CMSG_DATA_MUTABLE +private let CMSG_SPACE = CNIOLinux_CMSG_SPACE +private let CMSG_LEN = CNIOLinux_CMSG_LEN +#endif + +// MARK: _BSDSocketControlMessageProtocol implementation +extension NIOBSDSocketControlMessage { + static func firstHeader(inside msghdr: UnsafePointer) + -> UnsafeMutablePointer? { + return CMSG_FIRSTHDR(msghdr) + } + + static func nextHeader(inside msghdr: UnsafeMutablePointer, + after: UnsafeMutablePointer) + -> UnsafeMutablePointer? { + return CMSG_NXTHDR(msghdr, after) + } + + static func data(for header: UnsafePointer) + -> UnsafeRawBufferPointer? { + let data = CMSG_DATA(header) + let length = + size_t(header.pointee.cmsg_len) - NIOBSDSocketControlMessage.length(payloadSize: 0) + return UnsafeRawBufferPointer(start: data, count: Int(length)) + } + + static func data(for header: UnsafeMutablePointer) + -> UnsafeMutableRawBufferPointer? { + let data = CMSG_DATA_MUTABLE(header) + let length = + size_t(header.pointee.cmsg_len) - NIOBSDSocketControlMessage.length(payloadSize: 0) + return UnsafeMutableRawBufferPointer(start: data, count: Int(length)) + } + + static func length(payloadSize: size_t) -> size_t { + return CMSG_LEN(payloadSize) + } + + static func space(payloadSize: size_t) -> size_t { + return CMSG_SPACE(payloadSize) + } +} + #endif diff --git a/Sources/NIO/BSDSocketAPIWindows.swift b/Sources/NIO/BSDSocketAPIWindows.swift index 4a1f229e..cac6c8b8 100644 --- a/Sources/NIO/BSDSocketAPIWindows.swift +++ b/Sources/NIO/BSDSocketAPIWindows.swift @@ -429,4 +429,42 @@ extension NIOBSDSocket { } } } + +// MARK: _BSDSocketControlMessageProtocol implementation +extension NIOBSDSocketControlMessage { + static func firstHeader(inside msghdr: UnsafePointer) + -> UnsafeMutablePointer? { + return CNIOWindows_CMSG_FIRSTHDR(msghdr) + } + + static func nextHeader(inside msghdr: UnsafeMutablePointer, + after: UnsafeMutablePointer) + -> UnsafeMutablePointer? { + return CNIOWindows_CMSG_NXTHDR(msghdr, after) + } + + static func data(for header: UnsafePointer) + -> UnsafeRawBufferPointer? { + let data = CNIOWindows_CMSG_DATA(header) + let length = + size_t(header.pointee.cmsg_len) - NIOBSDSocketControlMessage.length(payloadSize: 0) + return UnsafeRawBufferPointer(start: data, count: Int(length)) + } + + static func data(for header: UnsafeMutablePointer) + -> UnsafeMutableRawBufferPointer? { + let data = CNIOWindows_CMSG_DATA_MUTABLE(header) + let length = + size_t(header.pointee.cmsg_len) - NIOBSDSocketControlMessage.length(payloadSize: 0) + return UnsafeMutableRawBufferPointer(start: data, count: Int(length)) + } + + static func length(payloadSize: size_t) -> size_t { + return CNIOWindows_CMSG_LEN(payloadSize) + } + + static func space(payloadSize: size_t) -> size_t { + return CNIOWindows_CMSG_SPACE(payloadSize) + } +} #endif diff --git a/Sources/NIO/ControlMessage.swift b/Sources/NIO/ControlMessage.swift index 6ec87d56..efefb102 100644 --- a/Sources/NIO/ControlMessage.swift +++ b/Sources/NIO/ControlMessage.swift @@ -38,7 +38,7 @@ struct UnsafeControlMessageStorage: Collection { /// - msghdrCount: How many `msghdr` structures will be fed from this buffer - we assume 4 Int32 cmsgs for each. static func allocate(msghdrCount: Int) -> UnsafeControlMessageStorage { // Guess that 4 Int32 payload messages is enough for anyone. - let bytesPerMessage = Posix.cmsgSpace(payloadSize: MemoryLayout.stride) * 4 + let bytesPerMessage = NIOBSDSocketControlMessage.space(payloadSize: MemoryLayout.stride) * 4 let buffer = UnsafeMutableRawBufferPointer.allocate(byteCount: bytesPerMessage * msghdrCount, alignment: MemoryLayout.alignment) return UnsafeControlMessageStorage(bytesPerMessage: bytesPerMessage, buffer: buffer) @@ -111,7 +111,7 @@ extension UnsafeControlMessageCollection: Collection { var startIndex: Index { var messageHeader = self.messageHeader return withUnsafePointer(to: &messageHeader) { messageHeaderPtr in - let firstCMsg = Posix.cmsgFirstHeader(inside: messageHeaderPtr) + let firstCMsg = NIOBSDSocketControlMessage.firstHeader(inside: messageHeaderPtr) return Index(cmsgPointer: firstCMsg) } } @@ -121,7 +121,7 @@ extension UnsafeControlMessageCollection: Collection { func index(after: Index) -> Index { var msgHdr = messageHeader return withUnsafeMutablePointer(to: &msgHdr) { messageHeaderPtr in - return Index(cmsgPointer: Posix.cmsgNextHeader(inside: messageHeaderPtr, + return Index(cmsgPointer: NIOBSDSocketControlMessage.nextHeader(inside: messageHeaderPtr, after: after.cmsgPointer!)) } } @@ -130,7 +130,7 @@ extension UnsafeControlMessageCollection: Collection { let cmsg = position.cmsgPointer! return UnsafeControlMessage(level: cmsg.pointee.cmsg_level, type: cmsg.pointee.cmsg_type, - data: Posix.cmsgData(for: cmsg)) + data: NIOBSDSocketControlMessage.data(for: cmsg)) } } @@ -241,7 +241,7 @@ struct UnsafeOutboundControlBytes { payload: PayloadType) { let writableBuffer = UnsafeMutableRawBufferPointer(rebasing: self.controlBytes[writePosition...]) - let requiredSize = Posix.cmsgSpace(payloadSize: MemoryLayout.stride(ofValue: payload)) + let requiredSize = NIOBSDSocketControlMessage.space(payloadSize: MemoryLayout.stride(ofValue: payload)) precondition(writableBuffer.count >= requiredSize, "Insufficient size for cmsghdr and data") let bufferBase = writableBuffer.baseAddress! @@ -249,9 +249,9 @@ struct UnsafeOutboundControlBytes { let cmsghdrPtr = bufferBase.bindMemory(to: cmsghdr.self, capacity: 1) cmsghdrPtr.pointee.cmsg_level = level cmsghdrPtr.pointee.cmsg_type = type - cmsghdrPtr.pointee.cmsg_len = .init(Posix.cmsgLen(payloadSize: MemoryLayout.size(ofValue: payload))) + cmsghdrPtr.pointee.cmsg_len = .init(NIOBSDSocketControlMessage.length(payloadSize: MemoryLayout.size(ofValue: payload))) - let dataPointer = Posix.cmsgData(for: cmsghdrPtr)! + let dataPointer = NIOBSDSocketControlMessage.data(for: cmsghdrPtr)! precondition(dataPointer.count >= MemoryLayout.stride) dataPointer.storeBytes(of: payload, as: PayloadType.self) diff --git a/Sources/NIO/System.swift b/Sources/NIO/System.swift index fd83776a..1622ad54 100644 --- a/Sources/NIO/System.swift +++ b/Sources/NIO/System.swift @@ -106,15 +106,6 @@ private let sysStat: @convention(c) (UnsafePointer, UnsafeMutablePointer< private let sysUnlink: @convention(c) (UnsafePointer) -> CInt = unlink private let sysSendMmsg: @convention(c) (CInt, UnsafeMutablePointer?, CUnsignedInt, CInt) -> CInt = CNIOLinux_sendmmsg private let sysRecvMmsg: @convention(c) (CInt, UnsafeMutablePointer?, CUnsignedInt, CInt, UnsafeMutablePointer?) -> CInt = CNIOLinux_recvmmsg -private let sysCmsgFirstHdr: @convention(c) (UnsafePointer?) -> UnsafeMutablePointer? = - CNIOLinux_CMSG_FIRSTHDR -private let sysCmsgNxtHdr: @convention(c) (UnsafeMutablePointer?, UnsafeMutablePointer?) -> - UnsafeMutablePointer? = CNIOLinux_CMSG_NXTHDR -private let sysCmsgData: @convention(c) (UnsafePointer?) -> UnsafeRawPointer? = CNIOLinux_CMSG_DATA -private let sysCmsgDataMutable: @convention(c) (UnsafeMutablePointer?) -> UnsafeMutableRawPointer? = - CNIOLinux_CMSG_DATA_MUTABLE -private let sysCmsgSpace: @convention(c) (size_t) -> size_t = CNIOLinux_CMSG_SPACE -private let sysCmsgLen: @convention(c) (size_t) -> size_t = CNIOLinux_CMSG_LEN #elseif os(macOS) || os(iOS) || os(watchOS) || os(tvOS) private let sysFstat: @convention(c) (CInt, UnsafeMutablePointer?) -> CInt = fstat private let sysStat: @convention(c) (UnsafePointer?, UnsafeMutablePointer?) -> CInt = stat @@ -122,16 +113,6 @@ private let sysUnlink: @convention(c) (UnsafePointer?) -> CInt = unlink private let sysKevent = kevent private let sysSendMmsg: @convention(c) (CInt, UnsafeMutablePointer?, CUnsignedInt, CInt) -> CInt = CNIODarwin_sendmmsg private let sysRecvMmsg: @convention(c) (CInt, UnsafeMutablePointer?, CUnsignedInt, CInt, UnsafeMutablePointer?) -> CInt = CNIODarwin_recvmmsg -private let sysCmsgFirstHdr: @convention(c) (UnsafePointer?) -> UnsafeMutablePointer? = - CNIODarwin_CMSG_FIRSTHDR -private let sysCmsgNxtHdr: @convention(c) (UnsafePointer?, UnsafePointer?) -> - UnsafeMutablePointer? = CNIODarwin_CMSG_NXTHDR -private let sysCmsgData: @convention(c) (UnsafePointer?) -> UnsafeRawPointer? = - CNIODarwin_CMSG_DATA -private let sysCmsgDataMutable: @convention(c) (UnsafeMutablePointer?) -> UnsafeMutableRawPointer? = - CNIODarwin_CMSG_DATA_MUTABLE -private let sysCmsgSpace: @convention(c) (size_t) -> size_t = CNIODarwin_CMSG_SPACE -private let sysCmsgLen: @convention(c) (size_t) -> size_t = CNIODarwin_CMSG_LEN #endif private func isUnacceptableErrno(_ code: Int32) -> Bool { @@ -556,39 +537,6 @@ internal enum Posix { sysSocketpair(domain.rawValue, type.rawValue, `protocol`, socketVector) } } - - static func cmsgFirstHeader(inside msghdr: UnsafePointer) -> UnsafeMutablePointer? { - return sysCmsgFirstHdr(msghdr) - } - - static func cmsgNextHeader(inside msghdr: UnsafeMutablePointer, - after: UnsafeMutablePointer) -> UnsafeMutablePointer? { - return sysCmsgNxtHdr(msghdr, after) - } - - static func cmsgData(for header: UnsafePointer) -> UnsafeRawBufferPointer? { - let dataPointer = sysCmsgData(header) - // Linux and Darwin use different types for cmsg_len. - let length = size_t(header.pointee.cmsg_len) - cmsgLen(payloadSize: 0) - let buffer = UnsafeRawBufferPointer(start: dataPointer, count: Int(length)) - return buffer - } - - static func cmsgData(for header: UnsafeMutablePointer) -> UnsafeMutableRawBufferPointer? { - let dataPointer = sysCmsgDataMutable(header) - // Linux and Darwin use different types for cmsg_len. - let length = size_t(header.pointee.cmsg_len) - cmsgLen(payloadSize: 0) - let buffer = UnsafeMutableRawBufferPointer(start: dataPointer, count: Int(length)) - return buffer - } - - static func cmsgLen(payloadSize: size_t) -> size_t { - return sysCmsgLen(payloadSize) - } - - static func cmsgSpace(payloadSize: size_t) -> size_t { - return sysCmsgSpace(payloadSize) - } #endif } diff --git a/Tests/NIOTests/SystemTest.swift b/Tests/NIOTests/SystemTest.swift index 3ff315f5..0a7fd85e 100644 --- a/Tests/NIOTests/SystemTest.swift +++ b/Tests/NIOTests/SystemTest.swift @@ -92,7 +92,7 @@ class SystemTest: XCTestCase { msgHdr.msg_controllen = .init(pCmsgHdr.count) withUnsafePointer(to: msgHdr) { pMsgHdr in - let result = Posix.cmsgFirstHeader(inside: pMsgHdr) + let result = NIOBSDSocketControlMessage.firstHeader(inside: pMsgHdr) XCTAssertEqual(pCmsgHdr.baseAddress, result) } } @@ -106,11 +106,11 @@ class SystemTest: XCTestCase { msgHdr.msg_controllen = .init(pCmsgHdr.count) withUnsafeMutablePointer(to: &msgHdr) { pMsgHdr in - let first = Posix.cmsgFirstHeader(inside: pMsgHdr) - let second = Posix.cmsgNextHeader(inside: pMsgHdr, after: first!) + let first = NIOBSDSocketControlMessage.firstHeader(inside: pMsgHdr) + let second = NIOBSDSocketControlMessage.nextHeader(inside: pMsgHdr, after: first!) let expectedSecondStart = pCmsgHdr.baseAddress! + SystemTest.cmsghdr_secondStartPosition XCTAssertEqual(expectedSecondStart, second!) - let third = Posix.cmsgNextHeader(inside: pMsgHdr, after: second!) + let third = NIOBSDSocketControlMessage.nextHeader(inside: pMsgHdr, after: second!) XCTAssertEqual(third, nil) } } @@ -124,8 +124,8 @@ class SystemTest: XCTestCase { msgHdr.msg_controllen = .init(pCmsgHdr.count) withUnsafePointer(to: msgHdr) { pMsgHdr in - let first = Posix.cmsgFirstHeader(inside: pMsgHdr) - let firstData = Posix.cmsgData(for: first!) + let first = NIOBSDSocketControlMessage.firstHeader(inside: pMsgHdr) + let firstData = NIOBSDSocketControlMessage.data(for: first!) let expecedFirstData = UnsafeRawBufferPointer( rebasing: pCmsgHdr[SystemTest.cmsghdr_firstDataStart..<( SystemTest.cmsghdr_firstDataStart + SystemTest.cmsghdr_firstDataCount)])