NIO: extract control message handling into a separate protocol (#1678)

These are part of the BSD sockets APIs, but not directly related to the
socket interfaces.  Create an extension point to permit the different
platforms to shuffle their implementation into place.  This provides a
nicer spelling for the functions and enables the codepaths on Windows as
well.
This commit is contained in:
Saleem Abdulrasool 2020-10-19 08:18:57 -07:00 committed by GitHub
parent 1cdbf763fb
commit 20e63d07dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 183 additions and 69 deletions

View File

@ -19,6 +19,7 @@
#include <WinSock2.h>
#include <time.h>
#include <stdint.h>
#define NIO(name) CNIOWindows_ ## name
@ -95,6 +96,16 @@ int NIO(sendmmsg)(SOCKET s, NIO(mmsghdr) *msgvec, unsigned int vlen, int flags);
int NIO(recvmmsg)(SOCKET s, NIO(mmsghdr) *msgvec, unsigned int vlen, int flags,
struct timespec *timeout);
const void *NIO(CMSG_DATA)(const WSACMSGHDR *);
void *NIO(CMSG_DATA_MUTABLE)(LPWSACMSGHDR);
WSACMSGHDR *NIO(CMSG_FIRSTHDR)(const WSAMSG *);
WSACMSGHDR *NIO(CMSG_NXTHDR)(const WSAMSG *, LPWSACMSGHDR);
size_t NIO(CMSG_LEN)(size_t);
size_t NIO(CMSG_SPACE)(size_t);
#undef NIO
#endif

View File

@ -31,4 +31,28 @@ int CNIOWindows_recvmmsg(SOCKET s, CNIOWindows_mmsghdr *msgvec,
abort();
}
const void *CNIOWindows_CMSG_DATA(const WSACMSGHDR *pcmsg) {
return WSA_CMSG_DATA(pcmsg);
}
void *CNIOWindows_CMSG_DATA_MUTABLE(LPWSACMSGHDR pcmsg) {
return WSA_CMSG_DATA(pcmsg);
}
WSACMSGHDR *CNIOWindows_CMSG_FIRSTHDR(const WSAMSG *msg) {
return WSA_CMSG_FIRSTHDR(msg);
}
WSACMSGHDR *CNIOWindows_CMSG_NXTHDR(const WSAMSG *msg, LPWSACMSGHDR cmsg) {
return WSA_CMSG_NXTHDR(msg, cmsg);
}
size_t CNIOWindows_CMSG_LEN(size_t length) {
return WSA_CMSG_LEN(length);
}
size_t CNIOWindows_CMSG_SPACE(size_t length) {
return WSA_CMSG_SPACE(length);
}
#endif

View File

@ -59,7 +59,13 @@ import let WinSDK.TCP_NODELAY
import struct WinSDK.SOCKET
import struct WinSDK.WSACMSGHDR
import struct WinSDK.WSAMSG
import struct WinSDK.socklen_t
internal typealias msghdr = WSAMSG
internal typealias cmsghdr = WSACMSGHDR
#endif
protocol _SocketShutdownProtocol {
@ -402,9 +408,11 @@ extension NIOBSDSocket.Option {
/// The requested UDS path exists and has wrong type (not a socket).
public struct UnixDomainSocketPathWrongType: Error {}
/// This protocol defines the methods that are expected to be found on `NIOBSDSocket`. While defined as a protocol
/// there is no expectation that any object other than `NIOBSDSocket` will implement this protocol: instead, this protocol
/// acts as a reference for what new supported operating systems must implement.
/// This protocol defines the methods that are expected to be found on
/// `NIOBSDSocket`. While defined as a protocol there is no expectation that any
/// object other than `NIOBSDSocket` will implement this protocol: instead, this
/// protocol acts as a reference for what new supported operating systems must
/// implement.
protocol _BSDSocketProtocol {
static func accept(socket s: NIOBSDSocket.Handle,
address addr: UnsafeMutablePointer<sockaddr>?,
@ -526,5 +534,34 @@ protocol _BSDSocketProtocol {
static func cleanupUnixDomainSocket(atPath path: String) throws
}
/// If this extension is hitting a compile error, your platform is missing one of the functions defined above!
/// If this extension is hitting a compile error, your platform is missing one
/// of the functions defined above!
extension NIOBSDSocket: _BSDSocketProtocol { }
/// This protocol defines the methods that are expected to be found on
/// `NIOBSDControlMessage`. While defined as a protocol there is no expectation
/// that any object other than `NIOBSDControlMessage` will implement this
/// protocol: instead, this protocol acts as a reference for what new supported
/// operating systems must implement.
protocol _BSDSocketControlMessageProtocol {
static func firstHeader(inside msghdr: UnsafePointer<msghdr>)
-> UnsafeMutablePointer<cmsghdr>?
static func nextHeader(inside msghdr: UnsafeMutablePointer<msghdr>,
after: UnsafeMutablePointer<cmsghdr>)
-> UnsafeMutablePointer<cmsghdr>?
static func data(for header: UnsafePointer<cmsghdr>)
-> UnsafeRawBufferPointer?
static func data(for header: UnsafeMutablePointer<cmsghdr>)
-> UnsafeMutableRawBufferPointer?
static func length(payloadSize: size_t) -> size_t
static func space(payloadSize: size_t) -> size_t
}
/// If this extension is hitting a compile error, your platform is missing one
/// of the functions defined above!
enum NIOBSDSocketControlMessage: _BSDSocketControlMessageProtocol { }

View File

@ -226,4 +226,60 @@ extension NIOBSDSocket {
}
}
#if os(iOS) || os(macOS) || os(tvOS) || os(watchOS)
import CNIODarwin
private let CMSG_FIRSTHDR = CNIODarwin_CMSG_FIRSTHDR
private let CMSG_NXTHDR = CNIODarwin_CMSG_NXTHDR
private let CMSG_DATA = CNIODarwin_CMSG_DATA
private let CMSG_DATA_MUTABLE = CNIODarwin_CMSG_DATA_MUTABLE
private let CMSG_SPACE = CNIODarwin_CMSG_SPACE
private let CMSG_LEN = CNIODarwin_CMSG_LEN
#else
import CNIOLinux
private let CMSG_FIRSTHDR = CNIOLinux_CMSG_FIRSTHDR
private let CMSG_NXTHDR = CNIOLinux_CMSG_NXTHDR
private let CMSG_DATA = CNIOLinux_CMSG_DATA
private let CMSG_DATA_MUTABLE = CNIOLinux_CMSG_DATA_MUTABLE
private let CMSG_SPACE = CNIOLinux_CMSG_SPACE
private let CMSG_LEN = CNIOLinux_CMSG_LEN
#endif
// MARK: _BSDSocketControlMessageProtocol implementation
extension NIOBSDSocketControlMessage {
static func firstHeader(inside msghdr: UnsafePointer<msghdr>)
-> UnsafeMutablePointer<cmsghdr>? {
return CMSG_FIRSTHDR(msghdr)
}
static func nextHeader(inside msghdr: UnsafeMutablePointer<msghdr>,
after: UnsafeMutablePointer<cmsghdr>)
-> UnsafeMutablePointer<cmsghdr>? {
return CMSG_NXTHDR(msghdr, after)
}
static func data(for header: UnsafePointer<cmsghdr>)
-> UnsafeRawBufferPointer? {
let data = CMSG_DATA(header)
let length =
size_t(header.pointee.cmsg_len) - NIOBSDSocketControlMessage.length(payloadSize: 0)
return UnsafeRawBufferPointer(start: data, count: Int(length))
}
static func data(for header: UnsafeMutablePointer<cmsghdr>)
-> UnsafeMutableRawBufferPointer? {
let data = CMSG_DATA_MUTABLE(header)
let length =
size_t(header.pointee.cmsg_len) - NIOBSDSocketControlMessage.length(payloadSize: 0)
return UnsafeMutableRawBufferPointer(start: data, count: Int(length))
}
static func length(payloadSize: size_t) -> size_t {
return CMSG_LEN(payloadSize)
}
static func space(payloadSize: size_t) -> size_t {
return CMSG_SPACE(payloadSize)
}
}
#endif

View File

@ -429,4 +429,42 @@ extension NIOBSDSocket {
}
}
}
// MARK: _BSDSocketControlMessageProtocol implementation
extension NIOBSDSocketControlMessage {
static func firstHeader(inside msghdr: UnsafePointer<msghdr>)
-> UnsafeMutablePointer<cmsghdr>? {
return CNIOWindows_CMSG_FIRSTHDR(msghdr)
}
static func nextHeader(inside msghdr: UnsafeMutablePointer<msghdr>,
after: UnsafeMutablePointer<cmsghdr>)
-> UnsafeMutablePointer<cmsghdr>? {
return CNIOWindows_CMSG_NXTHDR(msghdr, after)
}
static func data(for header: UnsafePointer<cmsghdr>)
-> UnsafeRawBufferPointer? {
let data = CNIOWindows_CMSG_DATA(header)
let length =
size_t(header.pointee.cmsg_len) - NIOBSDSocketControlMessage.length(payloadSize: 0)
return UnsafeRawBufferPointer(start: data, count: Int(length))
}
static func data(for header: UnsafeMutablePointer<cmsghdr>)
-> UnsafeMutableRawBufferPointer? {
let data = CNIOWindows_CMSG_DATA_MUTABLE(header)
let length =
size_t(header.pointee.cmsg_len) - NIOBSDSocketControlMessage.length(payloadSize: 0)
return UnsafeMutableRawBufferPointer(start: data, count: Int(length))
}
static func length(payloadSize: size_t) -> size_t {
return CNIOWindows_CMSG_LEN(payloadSize)
}
static func space(payloadSize: size_t) -> size_t {
return CNIOWindows_CMSG_SPACE(payloadSize)
}
}
#endif

View File

@ -38,7 +38,7 @@ struct UnsafeControlMessageStorage: Collection {
/// - msghdrCount: How many `msghdr` structures will be fed from this buffer - we assume 4 Int32 cmsgs for each.
static func allocate(msghdrCount: Int) -> UnsafeControlMessageStorage {
// Guess that 4 Int32 payload messages is enough for anyone.
let bytesPerMessage = Posix.cmsgSpace(payloadSize: MemoryLayout<Int32>.stride) * 4
let bytesPerMessage = NIOBSDSocketControlMessage.space(payloadSize: MemoryLayout<Int32>.stride) * 4
let buffer = UnsafeMutableRawBufferPointer.allocate(byteCount: bytesPerMessage * msghdrCount,
alignment: MemoryLayout<cmsghdr>.alignment)
return UnsafeControlMessageStorage(bytesPerMessage: bytesPerMessage, buffer: buffer)
@ -111,7 +111,7 @@ extension UnsafeControlMessageCollection: Collection {
var startIndex: Index {
var messageHeader = self.messageHeader
return withUnsafePointer(to: &messageHeader) { messageHeaderPtr in
let firstCMsg = Posix.cmsgFirstHeader(inside: messageHeaderPtr)
let firstCMsg = NIOBSDSocketControlMessage.firstHeader(inside: messageHeaderPtr)
return Index(cmsgPointer: firstCMsg)
}
}
@ -121,7 +121,7 @@ extension UnsafeControlMessageCollection: Collection {
func index(after: Index) -> Index {
var msgHdr = messageHeader
return withUnsafeMutablePointer(to: &msgHdr) { messageHeaderPtr in
return Index(cmsgPointer: Posix.cmsgNextHeader(inside: messageHeaderPtr,
return Index(cmsgPointer: NIOBSDSocketControlMessage.nextHeader(inside: messageHeaderPtr,
after: after.cmsgPointer!))
}
}
@ -130,7 +130,7 @@ extension UnsafeControlMessageCollection: Collection {
let cmsg = position.cmsgPointer!
return UnsafeControlMessage(level: cmsg.pointee.cmsg_level,
type: cmsg.pointee.cmsg_type,
data: Posix.cmsgData(for: cmsg))
data: NIOBSDSocketControlMessage.data(for: cmsg))
}
}
@ -241,7 +241,7 @@ struct UnsafeOutboundControlBytes {
payload: PayloadType) {
let writableBuffer = UnsafeMutableRawBufferPointer(rebasing: self.controlBytes[writePosition...])
let requiredSize = Posix.cmsgSpace(payloadSize: MemoryLayout.stride(ofValue: payload))
let requiredSize = NIOBSDSocketControlMessage.space(payloadSize: MemoryLayout.stride(ofValue: payload))
precondition(writableBuffer.count >= requiredSize, "Insufficient size for cmsghdr and data")
let bufferBase = writableBuffer.baseAddress!
@ -249,9 +249,9 @@ struct UnsafeOutboundControlBytes {
let cmsghdrPtr = bufferBase.bindMemory(to: cmsghdr.self, capacity: 1)
cmsghdrPtr.pointee.cmsg_level = level
cmsghdrPtr.pointee.cmsg_type = type
cmsghdrPtr.pointee.cmsg_len = .init(Posix.cmsgLen(payloadSize: MemoryLayout.size(ofValue: payload)))
cmsghdrPtr.pointee.cmsg_len = .init(NIOBSDSocketControlMessage.length(payloadSize: MemoryLayout.size(ofValue: payload)))
let dataPointer = Posix.cmsgData(for: cmsghdrPtr)!
let dataPointer = NIOBSDSocketControlMessage.data(for: cmsghdrPtr)!
precondition(dataPointer.count >= MemoryLayout<PayloadType>.stride)
dataPointer.storeBytes(of: payload, as: PayloadType.self)

View File

@ -106,15 +106,6 @@ private let sysStat: @convention(c) (UnsafePointer<CChar>, UnsafeMutablePointer<
private let sysUnlink: @convention(c) (UnsafePointer<CChar>) -> CInt = unlink
private let sysSendMmsg: @convention(c) (CInt, UnsafeMutablePointer<CNIOLinux_mmsghdr>?, CUnsignedInt, CInt) -> CInt = CNIOLinux_sendmmsg
private let sysRecvMmsg: @convention(c) (CInt, UnsafeMutablePointer<CNIOLinux_mmsghdr>?, CUnsignedInt, CInt, UnsafeMutablePointer<timespec>?) -> CInt = CNIOLinux_recvmmsg
private let sysCmsgFirstHdr: @convention(c) (UnsafePointer<msghdr>?) -> UnsafeMutablePointer<cmsghdr>? =
CNIOLinux_CMSG_FIRSTHDR
private let sysCmsgNxtHdr: @convention(c) (UnsafeMutablePointer<msghdr>?, UnsafeMutablePointer<cmsghdr>?) ->
UnsafeMutablePointer<cmsghdr>? = CNIOLinux_CMSG_NXTHDR
private let sysCmsgData: @convention(c) (UnsafePointer<cmsghdr>?) -> UnsafeRawPointer? = CNIOLinux_CMSG_DATA
private let sysCmsgDataMutable: @convention(c) (UnsafeMutablePointer<cmsghdr>?) -> UnsafeMutableRawPointer? =
CNIOLinux_CMSG_DATA_MUTABLE
private let sysCmsgSpace: @convention(c) (size_t) -> size_t = CNIOLinux_CMSG_SPACE
private let sysCmsgLen: @convention(c) (size_t) -> size_t = CNIOLinux_CMSG_LEN
#elseif os(macOS) || os(iOS) || os(watchOS) || os(tvOS)
private let sysFstat: @convention(c) (CInt, UnsafeMutablePointer<stat>?) -> CInt = fstat
private let sysStat: @convention(c) (UnsafePointer<CChar>?, UnsafeMutablePointer<stat>?) -> CInt = stat
@ -122,16 +113,6 @@ private let sysUnlink: @convention(c) (UnsafePointer<CChar>?) -> CInt = unlink
private let sysKevent = kevent
private let sysSendMmsg: @convention(c) (CInt, UnsafeMutablePointer<CNIODarwin_mmsghdr>?, CUnsignedInt, CInt) -> CInt = CNIODarwin_sendmmsg
private let sysRecvMmsg: @convention(c) (CInt, UnsafeMutablePointer<CNIODarwin_mmsghdr>?, CUnsignedInt, CInt, UnsafeMutablePointer<timespec>?) -> CInt = CNIODarwin_recvmmsg
private let sysCmsgFirstHdr: @convention(c) (UnsafePointer<msghdr>?) -> UnsafeMutablePointer<cmsghdr>? =
CNIODarwin_CMSG_FIRSTHDR
private let sysCmsgNxtHdr: @convention(c) (UnsafePointer<msghdr>?, UnsafePointer<cmsghdr>?) ->
UnsafeMutablePointer<cmsghdr>? = CNIODarwin_CMSG_NXTHDR
private let sysCmsgData: @convention(c) (UnsafePointer<cmsghdr>?) -> UnsafeRawPointer? =
CNIODarwin_CMSG_DATA
private let sysCmsgDataMutable: @convention(c) (UnsafeMutablePointer<cmsghdr>?) -> UnsafeMutableRawPointer? =
CNIODarwin_CMSG_DATA_MUTABLE
private let sysCmsgSpace: @convention(c) (size_t) -> size_t = CNIODarwin_CMSG_SPACE
private let sysCmsgLen: @convention(c) (size_t) -> size_t = CNIODarwin_CMSG_LEN
#endif
private func isUnacceptableErrno(_ code: Int32) -> Bool {
@ -556,39 +537,6 @@ internal enum Posix {
sysSocketpair(domain.rawValue, type.rawValue, `protocol`, socketVector)
}
}
static func cmsgFirstHeader(inside msghdr: UnsafePointer<msghdr>) -> UnsafeMutablePointer<cmsghdr>? {
return sysCmsgFirstHdr(msghdr)
}
static func cmsgNextHeader(inside msghdr: UnsafeMutablePointer<msghdr>,
after: UnsafeMutablePointer<cmsghdr>) -> UnsafeMutablePointer<cmsghdr>? {
return sysCmsgNxtHdr(msghdr, after)
}
static func cmsgData(for header: UnsafePointer<cmsghdr>) -> UnsafeRawBufferPointer? {
let dataPointer = sysCmsgData(header)
// Linux and Darwin use different types for cmsg_len.
let length = size_t(header.pointee.cmsg_len) - cmsgLen(payloadSize: 0)
let buffer = UnsafeRawBufferPointer(start: dataPointer, count: Int(length))
return buffer
}
static func cmsgData(for header: UnsafeMutablePointer<cmsghdr>) -> UnsafeMutableRawBufferPointer? {
let dataPointer = sysCmsgDataMutable(header)
// Linux and Darwin use different types for cmsg_len.
let length = size_t(header.pointee.cmsg_len) - cmsgLen(payloadSize: 0)
let buffer = UnsafeMutableRawBufferPointer(start: dataPointer, count: Int(length))
return buffer
}
static func cmsgLen(payloadSize: size_t) -> size_t {
return sysCmsgLen(payloadSize)
}
static func cmsgSpace(payloadSize: size_t) -> size_t {
return sysCmsgSpace(payloadSize)
}
#endif
}

View File

@ -92,7 +92,7 @@ class SystemTest: XCTestCase {
msgHdr.msg_controllen = .init(pCmsgHdr.count)
withUnsafePointer(to: msgHdr) { pMsgHdr in
let result = Posix.cmsgFirstHeader(inside: pMsgHdr)
let result = NIOBSDSocketControlMessage.firstHeader(inside: pMsgHdr)
XCTAssertEqual(pCmsgHdr.baseAddress, result)
}
}
@ -106,11 +106,11 @@ class SystemTest: XCTestCase {
msgHdr.msg_controllen = .init(pCmsgHdr.count)
withUnsafeMutablePointer(to: &msgHdr) { pMsgHdr in
let first = Posix.cmsgFirstHeader(inside: pMsgHdr)
let second = Posix.cmsgNextHeader(inside: pMsgHdr, after: first!)
let first = NIOBSDSocketControlMessage.firstHeader(inside: pMsgHdr)
let second = NIOBSDSocketControlMessage.nextHeader(inside: pMsgHdr, after: first!)
let expectedSecondStart = pCmsgHdr.baseAddress! + SystemTest.cmsghdr_secondStartPosition
XCTAssertEqual(expectedSecondStart, second!)
let third = Posix.cmsgNextHeader(inside: pMsgHdr, after: second!)
let third = NIOBSDSocketControlMessage.nextHeader(inside: pMsgHdr, after: second!)
XCTAssertEqual(third, nil)
}
}
@ -124,8 +124,8 @@ class SystemTest: XCTestCase {
msgHdr.msg_controllen = .init(pCmsgHdr.count)
withUnsafePointer(to: msgHdr) { pMsgHdr in
let first = Posix.cmsgFirstHeader(inside: pMsgHdr)
let firstData = Posix.cmsgData(for: first!)
let first = NIOBSDSocketControlMessage.firstHeader(inside: pMsgHdr)
let firstData = NIOBSDSocketControlMessage.data(for: first!)
let expecedFirstData = UnsafeRawBufferPointer(
rebasing: pCmsgHdr[SystemTest.cmsghdr_firstDataStart..<(
SystemTest.cmsghdr_firstDataStart + SystemTest.cmsghdr_firstDataCount)])