diff --git a/Sources/CNIOWindows/include/CNIOWindows.h b/Sources/CNIOWindows/include/CNIOWindows.h index 5a5cc171..3978b6de 100644 --- a/Sources/CNIOWindows/include/CNIOWindows.h +++ b/Sources/CNIOWindows/include/CNIOWindows.h @@ -19,6 +19,7 @@ #include #include +#include #define NIO(name) CNIOWindows_ ## name @@ -95,6 +96,16 @@ int NIO(sendmmsg)(SOCKET s, NIO(mmsghdr) *msgvec, unsigned int vlen, int flags); int NIO(recvmmsg)(SOCKET s, NIO(mmsghdr) *msgvec, unsigned int vlen, int flags, struct timespec *timeout); + +const void *NIO(CMSG_DATA)(const WSACMSGHDR *); +void *NIO(CMSG_DATA_MUTABLE)(LPWSACMSGHDR); + +WSACMSGHDR *NIO(CMSG_FIRSTHDR)(const WSAMSG *); +WSACMSGHDR *NIO(CMSG_NXTHDR)(const WSAMSG *, LPWSACMSGHDR); + +size_t NIO(CMSG_LEN)(size_t); +size_t NIO(CMSG_SPACE)(size_t); + #undef NIO #endif diff --git a/Sources/CNIOWindows/shim.c b/Sources/CNIOWindows/shim.c index fbb91686..b3c85240 100644 --- a/Sources/CNIOWindows/shim.c +++ b/Sources/CNIOWindows/shim.c @@ -31,4 +31,28 @@ int CNIOWindows_recvmmsg(SOCKET s, CNIOWindows_mmsghdr *msgvec, abort(); } +const void *CNIOWindows_CMSG_DATA(const WSACMSGHDR *pcmsg) { + return WSA_CMSG_DATA(pcmsg); +} + +void *CNIOWindows_CMSG_DATA_MUTABLE(LPWSACMSGHDR pcmsg) { + return WSA_CMSG_DATA(pcmsg); +} + +WSACMSGHDR *CNIOWindows_CMSG_FIRSTHDR(const WSAMSG *msg) { + return WSA_CMSG_FIRSTHDR(msg); +} + +WSACMSGHDR *CNIOWindows_CMSG_NXTHDR(const WSAMSG *msg, LPWSACMSGHDR cmsg) { + return WSA_CMSG_NXTHDR(msg, cmsg); +} + +size_t CNIOWindows_CMSG_LEN(size_t length) { + return WSA_CMSG_LEN(length); +} + +size_t CNIOWindows_CMSG_SPACE(size_t length) { + return WSA_CMSG_SPACE(length); +} + #endif diff --git a/Sources/NIO/BSDSocketAPI.swift b/Sources/NIO/BSDSocketAPI.swift index b16353f6..da2f0630 100644 --- a/Sources/NIO/BSDSocketAPI.swift +++ b/Sources/NIO/BSDSocketAPI.swift @@ -59,7 +59,13 @@ import let WinSDK.TCP_NODELAY import struct WinSDK.SOCKET +import struct WinSDK.WSACMSGHDR +import struct WinSDK.WSAMSG + import struct WinSDK.socklen_t + +internal typealias msghdr = WSAMSG +internal typealias cmsghdr = WSACMSGHDR #endif protocol _SocketShutdownProtocol { @@ -402,9 +408,11 @@ extension NIOBSDSocket.Option { /// The requested UDS path exists and has wrong type (not a socket). public struct UnixDomainSocketPathWrongType: Error {} -/// This protocol defines the methods that are expected to be found on `NIOBSDSocket`. While defined as a protocol -/// there is no expectation that any object other than `NIOBSDSocket` will implement this protocol: instead, this protocol -/// acts as a reference for what new supported operating systems must implement. +/// This protocol defines the methods that are expected to be found on +/// `NIOBSDSocket`. While defined as a protocol there is no expectation that any +/// object other than `NIOBSDSocket` will implement this protocol: instead, this +/// protocol acts as a reference for what new supported operating systems must +/// implement. protocol _BSDSocketProtocol { static func accept(socket s: NIOBSDSocket.Handle, address addr: UnsafeMutablePointer?, @@ -526,5 +534,34 @@ protocol _BSDSocketProtocol { static func cleanupUnixDomainSocket(atPath path: String) throws } -/// If this extension is hitting a compile error, your platform is missing one of the functions defined above! +/// If this extension is hitting a compile error, your platform is missing one +/// of the functions defined above! extension NIOBSDSocket: _BSDSocketProtocol { } + +/// This protocol defines the methods that are expected to be found on +/// `NIOBSDControlMessage`. While defined as a protocol there is no expectation +/// that any object other than `NIOBSDControlMessage` will implement this +/// protocol: instead, this protocol acts as a reference for what new supported +/// operating systems must implement. +protocol _BSDSocketControlMessageProtocol { + static func firstHeader(inside msghdr: UnsafePointer) + -> UnsafeMutablePointer? + + static func nextHeader(inside msghdr: UnsafeMutablePointer, + after: UnsafeMutablePointer) + -> UnsafeMutablePointer? + + static func data(for header: UnsafePointer) + -> UnsafeRawBufferPointer? + + static func data(for header: UnsafeMutablePointer) + -> UnsafeMutableRawBufferPointer? + + static func length(payloadSize: size_t) -> size_t + + static func space(payloadSize: size_t) -> size_t +} + +/// If this extension is hitting a compile error, your platform is missing one +/// of the functions defined above! +enum NIOBSDSocketControlMessage: _BSDSocketControlMessageProtocol { } diff --git a/Sources/NIO/BSDSocketAPIPosix.swift b/Sources/NIO/BSDSocketAPIPosix.swift index 5b44f3a0..25faee46 100644 --- a/Sources/NIO/BSDSocketAPIPosix.swift +++ b/Sources/NIO/BSDSocketAPIPosix.swift @@ -226,4 +226,60 @@ extension NIOBSDSocket { } } +#if os(iOS) || os(macOS) || os(tvOS) || os(watchOS) +import CNIODarwin +private let CMSG_FIRSTHDR = CNIODarwin_CMSG_FIRSTHDR +private let CMSG_NXTHDR = CNIODarwin_CMSG_NXTHDR +private let CMSG_DATA = CNIODarwin_CMSG_DATA +private let CMSG_DATA_MUTABLE = CNIODarwin_CMSG_DATA_MUTABLE +private let CMSG_SPACE = CNIODarwin_CMSG_SPACE +private let CMSG_LEN = CNIODarwin_CMSG_LEN +#else +import CNIOLinux +private let CMSG_FIRSTHDR = CNIOLinux_CMSG_FIRSTHDR +private let CMSG_NXTHDR = CNIOLinux_CMSG_NXTHDR +private let CMSG_DATA = CNIOLinux_CMSG_DATA +private let CMSG_DATA_MUTABLE = CNIOLinux_CMSG_DATA_MUTABLE +private let CMSG_SPACE = CNIOLinux_CMSG_SPACE +private let CMSG_LEN = CNIOLinux_CMSG_LEN +#endif + +// MARK: _BSDSocketControlMessageProtocol implementation +extension NIOBSDSocketControlMessage { + static func firstHeader(inside msghdr: UnsafePointer) + -> UnsafeMutablePointer? { + return CMSG_FIRSTHDR(msghdr) + } + + static func nextHeader(inside msghdr: UnsafeMutablePointer, + after: UnsafeMutablePointer) + -> UnsafeMutablePointer? { + return CMSG_NXTHDR(msghdr, after) + } + + static func data(for header: UnsafePointer) + -> UnsafeRawBufferPointer? { + let data = CMSG_DATA(header) + let length = + size_t(header.pointee.cmsg_len) - NIOBSDSocketControlMessage.length(payloadSize: 0) + return UnsafeRawBufferPointer(start: data, count: Int(length)) + } + + static func data(for header: UnsafeMutablePointer) + -> UnsafeMutableRawBufferPointer? { + let data = CMSG_DATA_MUTABLE(header) + let length = + size_t(header.pointee.cmsg_len) - NIOBSDSocketControlMessage.length(payloadSize: 0) + return UnsafeMutableRawBufferPointer(start: data, count: Int(length)) + } + + static func length(payloadSize: size_t) -> size_t { + return CMSG_LEN(payloadSize) + } + + static func space(payloadSize: size_t) -> size_t { + return CMSG_SPACE(payloadSize) + } +} + #endif diff --git a/Sources/NIO/BSDSocketAPIWindows.swift b/Sources/NIO/BSDSocketAPIWindows.swift index 4a1f229e..cac6c8b8 100644 --- a/Sources/NIO/BSDSocketAPIWindows.swift +++ b/Sources/NIO/BSDSocketAPIWindows.swift @@ -429,4 +429,42 @@ extension NIOBSDSocket { } } } + +// MARK: _BSDSocketControlMessageProtocol implementation +extension NIOBSDSocketControlMessage { + static func firstHeader(inside msghdr: UnsafePointer) + -> UnsafeMutablePointer? { + return CNIOWindows_CMSG_FIRSTHDR(msghdr) + } + + static func nextHeader(inside msghdr: UnsafeMutablePointer, + after: UnsafeMutablePointer) + -> UnsafeMutablePointer? { + return CNIOWindows_CMSG_NXTHDR(msghdr, after) + } + + static func data(for header: UnsafePointer) + -> UnsafeRawBufferPointer? { + let data = CNIOWindows_CMSG_DATA(header) + let length = + size_t(header.pointee.cmsg_len) - NIOBSDSocketControlMessage.length(payloadSize: 0) + return UnsafeRawBufferPointer(start: data, count: Int(length)) + } + + static func data(for header: UnsafeMutablePointer) + -> UnsafeMutableRawBufferPointer? { + let data = CNIOWindows_CMSG_DATA_MUTABLE(header) + let length = + size_t(header.pointee.cmsg_len) - NIOBSDSocketControlMessage.length(payloadSize: 0) + return UnsafeMutableRawBufferPointer(start: data, count: Int(length)) + } + + static func length(payloadSize: size_t) -> size_t { + return CNIOWindows_CMSG_LEN(payloadSize) + } + + static func space(payloadSize: size_t) -> size_t { + return CNIOWindows_CMSG_SPACE(payloadSize) + } +} #endif diff --git a/Sources/NIO/ControlMessage.swift b/Sources/NIO/ControlMessage.swift index 6ec87d56..efefb102 100644 --- a/Sources/NIO/ControlMessage.swift +++ b/Sources/NIO/ControlMessage.swift @@ -38,7 +38,7 @@ struct UnsafeControlMessageStorage: Collection { /// - msghdrCount: How many `msghdr` structures will be fed from this buffer - we assume 4 Int32 cmsgs for each. static func allocate(msghdrCount: Int) -> UnsafeControlMessageStorage { // Guess that 4 Int32 payload messages is enough for anyone. - let bytesPerMessage = Posix.cmsgSpace(payloadSize: MemoryLayout.stride) * 4 + let bytesPerMessage = NIOBSDSocketControlMessage.space(payloadSize: MemoryLayout.stride) * 4 let buffer = UnsafeMutableRawBufferPointer.allocate(byteCount: bytesPerMessage * msghdrCount, alignment: MemoryLayout.alignment) return UnsafeControlMessageStorage(bytesPerMessage: bytesPerMessage, buffer: buffer) @@ -111,7 +111,7 @@ extension UnsafeControlMessageCollection: Collection { var startIndex: Index { var messageHeader = self.messageHeader return withUnsafePointer(to: &messageHeader) { messageHeaderPtr in - let firstCMsg = Posix.cmsgFirstHeader(inside: messageHeaderPtr) + let firstCMsg = NIOBSDSocketControlMessage.firstHeader(inside: messageHeaderPtr) return Index(cmsgPointer: firstCMsg) } } @@ -121,7 +121,7 @@ extension UnsafeControlMessageCollection: Collection { func index(after: Index) -> Index { var msgHdr = messageHeader return withUnsafeMutablePointer(to: &msgHdr) { messageHeaderPtr in - return Index(cmsgPointer: Posix.cmsgNextHeader(inside: messageHeaderPtr, + return Index(cmsgPointer: NIOBSDSocketControlMessage.nextHeader(inside: messageHeaderPtr, after: after.cmsgPointer!)) } } @@ -130,7 +130,7 @@ extension UnsafeControlMessageCollection: Collection { let cmsg = position.cmsgPointer! return UnsafeControlMessage(level: cmsg.pointee.cmsg_level, type: cmsg.pointee.cmsg_type, - data: Posix.cmsgData(for: cmsg)) + data: NIOBSDSocketControlMessage.data(for: cmsg)) } } @@ -241,7 +241,7 @@ struct UnsafeOutboundControlBytes { payload: PayloadType) { let writableBuffer = UnsafeMutableRawBufferPointer(rebasing: self.controlBytes[writePosition...]) - let requiredSize = Posix.cmsgSpace(payloadSize: MemoryLayout.stride(ofValue: payload)) + let requiredSize = NIOBSDSocketControlMessage.space(payloadSize: MemoryLayout.stride(ofValue: payload)) precondition(writableBuffer.count >= requiredSize, "Insufficient size for cmsghdr and data") let bufferBase = writableBuffer.baseAddress! @@ -249,9 +249,9 @@ struct UnsafeOutboundControlBytes { let cmsghdrPtr = bufferBase.bindMemory(to: cmsghdr.self, capacity: 1) cmsghdrPtr.pointee.cmsg_level = level cmsghdrPtr.pointee.cmsg_type = type - cmsghdrPtr.pointee.cmsg_len = .init(Posix.cmsgLen(payloadSize: MemoryLayout.size(ofValue: payload))) + cmsghdrPtr.pointee.cmsg_len = .init(NIOBSDSocketControlMessage.length(payloadSize: MemoryLayout.size(ofValue: payload))) - let dataPointer = Posix.cmsgData(for: cmsghdrPtr)! + let dataPointer = NIOBSDSocketControlMessage.data(for: cmsghdrPtr)! precondition(dataPointer.count >= MemoryLayout.stride) dataPointer.storeBytes(of: payload, as: PayloadType.self) diff --git a/Sources/NIO/System.swift b/Sources/NIO/System.swift index fd83776a..1622ad54 100644 --- a/Sources/NIO/System.swift +++ b/Sources/NIO/System.swift @@ -106,15 +106,6 @@ private let sysStat: @convention(c) (UnsafePointer, UnsafeMutablePointer< private let sysUnlink: @convention(c) (UnsafePointer) -> CInt = unlink private let sysSendMmsg: @convention(c) (CInt, UnsafeMutablePointer?, CUnsignedInt, CInt) -> CInt = CNIOLinux_sendmmsg private let sysRecvMmsg: @convention(c) (CInt, UnsafeMutablePointer?, CUnsignedInt, CInt, UnsafeMutablePointer?) -> CInt = CNIOLinux_recvmmsg -private let sysCmsgFirstHdr: @convention(c) (UnsafePointer?) -> UnsafeMutablePointer? = - CNIOLinux_CMSG_FIRSTHDR -private let sysCmsgNxtHdr: @convention(c) (UnsafeMutablePointer?, UnsafeMutablePointer?) -> - UnsafeMutablePointer? = CNIOLinux_CMSG_NXTHDR -private let sysCmsgData: @convention(c) (UnsafePointer?) -> UnsafeRawPointer? = CNIOLinux_CMSG_DATA -private let sysCmsgDataMutable: @convention(c) (UnsafeMutablePointer?) -> UnsafeMutableRawPointer? = - CNIOLinux_CMSG_DATA_MUTABLE -private let sysCmsgSpace: @convention(c) (size_t) -> size_t = CNIOLinux_CMSG_SPACE -private let sysCmsgLen: @convention(c) (size_t) -> size_t = CNIOLinux_CMSG_LEN #elseif os(macOS) || os(iOS) || os(watchOS) || os(tvOS) private let sysFstat: @convention(c) (CInt, UnsafeMutablePointer?) -> CInt = fstat private let sysStat: @convention(c) (UnsafePointer?, UnsafeMutablePointer?) -> CInt = stat @@ -122,16 +113,6 @@ private let sysUnlink: @convention(c) (UnsafePointer?) -> CInt = unlink private let sysKevent = kevent private let sysSendMmsg: @convention(c) (CInt, UnsafeMutablePointer?, CUnsignedInt, CInt) -> CInt = CNIODarwin_sendmmsg private let sysRecvMmsg: @convention(c) (CInt, UnsafeMutablePointer?, CUnsignedInt, CInt, UnsafeMutablePointer?) -> CInt = CNIODarwin_recvmmsg -private let sysCmsgFirstHdr: @convention(c) (UnsafePointer?) -> UnsafeMutablePointer? = - CNIODarwin_CMSG_FIRSTHDR -private let sysCmsgNxtHdr: @convention(c) (UnsafePointer?, UnsafePointer?) -> - UnsafeMutablePointer? = CNIODarwin_CMSG_NXTHDR -private let sysCmsgData: @convention(c) (UnsafePointer?) -> UnsafeRawPointer? = - CNIODarwin_CMSG_DATA -private let sysCmsgDataMutable: @convention(c) (UnsafeMutablePointer?) -> UnsafeMutableRawPointer? = - CNIODarwin_CMSG_DATA_MUTABLE -private let sysCmsgSpace: @convention(c) (size_t) -> size_t = CNIODarwin_CMSG_SPACE -private let sysCmsgLen: @convention(c) (size_t) -> size_t = CNIODarwin_CMSG_LEN #endif private func isUnacceptableErrno(_ code: Int32) -> Bool { @@ -556,39 +537,6 @@ internal enum Posix { sysSocketpair(domain.rawValue, type.rawValue, `protocol`, socketVector) } } - - static func cmsgFirstHeader(inside msghdr: UnsafePointer) -> UnsafeMutablePointer? { - return sysCmsgFirstHdr(msghdr) - } - - static func cmsgNextHeader(inside msghdr: UnsafeMutablePointer, - after: UnsafeMutablePointer) -> UnsafeMutablePointer? { - return sysCmsgNxtHdr(msghdr, after) - } - - static func cmsgData(for header: UnsafePointer) -> UnsafeRawBufferPointer? { - let dataPointer = sysCmsgData(header) - // Linux and Darwin use different types for cmsg_len. - let length = size_t(header.pointee.cmsg_len) - cmsgLen(payloadSize: 0) - let buffer = UnsafeRawBufferPointer(start: dataPointer, count: Int(length)) - return buffer - } - - static func cmsgData(for header: UnsafeMutablePointer) -> UnsafeMutableRawBufferPointer? { - let dataPointer = sysCmsgDataMutable(header) - // Linux and Darwin use different types for cmsg_len. - let length = size_t(header.pointee.cmsg_len) - cmsgLen(payloadSize: 0) - let buffer = UnsafeMutableRawBufferPointer(start: dataPointer, count: Int(length)) - return buffer - } - - static func cmsgLen(payloadSize: size_t) -> size_t { - return sysCmsgLen(payloadSize) - } - - static func cmsgSpace(payloadSize: size_t) -> size_t { - return sysCmsgSpace(payloadSize) - } #endif } diff --git a/Tests/NIOTests/SystemTest.swift b/Tests/NIOTests/SystemTest.swift index 3ff315f5..0a7fd85e 100644 --- a/Tests/NIOTests/SystemTest.swift +++ b/Tests/NIOTests/SystemTest.swift @@ -92,7 +92,7 @@ class SystemTest: XCTestCase { msgHdr.msg_controllen = .init(pCmsgHdr.count) withUnsafePointer(to: msgHdr) { pMsgHdr in - let result = Posix.cmsgFirstHeader(inside: pMsgHdr) + let result = NIOBSDSocketControlMessage.firstHeader(inside: pMsgHdr) XCTAssertEqual(pCmsgHdr.baseAddress, result) } } @@ -106,11 +106,11 @@ class SystemTest: XCTestCase { msgHdr.msg_controllen = .init(pCmsgHdr.count) withUnsafeMutablePointer(to: &msgHdr) { pMsgHdr in - let first = Posix.cmsgFirstHeader(inside: pMsgHdr) - let second = Posix.cmsgNextHeader(inside: pMsgHdr, after: first!) + let first = NIOBSDSocketControlMessage.firstHeader(inside: pMsgHdr) + let second = NIOBSDSocketControlMessage.nextHeader(inside: pMsgHdr, after: first!) let expectedSecondStart = pCmsgHdr.baseAddress! + SystemTest.cmsghdr_secondStartPosition XCTAssertEqual(expectedSecondStart, second!) - let third = Posix.cmsgNextHeader(inside: pMsgHdr, after: second!) + let third = NIOBSDSocketControlMessage.nextHeader(inside: pMsgHdr, after: second!) XCTAssertEqual(third, nil) } } @@ -124,8 +124,8 @@ class SystemTest: XCTestCase { msgHdr.msg_controllen = .init(pCmsgHdr.count) withUnsafePointer(to: msgHdr) { pMsgHdr in - let first = Posix.cmsgFirstHeader(inside: pMsgHdr) - let firstData = Posix.cmsgData(for: first!) + let first = NIOBSDSocketControlMessage.firstHeader(inside: pMsgHdr) + let firstData = NIOBSDSocketControlMessage.data(for: first!) let expecedFirstData = UnsafeRawBufferPointer( rebasing: pCmsgHdr[SystemTest.cmsghdr_firstDataStart..<( SystemTest.cmsghdr_firstDataStart + SystemTest.cmsghdr_firstDataCount)])