From e3508b0d04783041c791cf6238b13cd4f35eae73 Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Wed, 8 Jul 2020 09:10:32 +0100 Subject: [PATCH] Move DatagramChannel over to using recvmsg. (#1473) Motivation: When attempting to obtain metadata about a read, we need to use recvmsg in order to obtain that data. While we're using recvmmsg for vector reads right now, our scalar reads use recvfrom, which does not provide us with that metadata. While we're not extracting metadata right now, we may well do so in future, so we can apply it now. Modifications: - Move recvfrom usage to recvmsg. Result: We'll be able to extend recvmsg to extract metadata. --- Sources/NIO/BSDSocketAPI.swift | 6 +-- Sources/NIO/BSDSocketAPIPosix.swift | 12 +----- Sources/NIO/BSDSocketAPIWindows.swift | 12 +----- Sources/NIO/PipePair.swift | 5 ++- Sources/NIO/Socket.swift | 41 +++++++++++++++---- Sources/NIO/SocketChannel.swift | 6 ++- Sources/NIO/SocketProtocols.swift | 5 ++- Sources/NIO/System.swift | 5 ++- .../DatagramChannelTests+XCTest.swift | 6 +-- Tests/NIOTests/DatagramChannelTests.swift | 16 ++++---- 10 files changed, 64 insertions(+), 50 deletions(-) diff --git a/Sources/NIO/BSDSocketAPI.swift b/Sources/NIO/BSDSocketAPI.swift index 1fde11ea..263e2bcd 100644 --- a/Sources/NIO/BSDSocketAPI.swift +++ b/Sources/NIO/BSDSocketAPI.swift @@ -376,11 +376,7 @@ protocol _BSDSocketProtocol { buffer buf: UnsafeMutableRawPointer, length len: size_t) throws -> IOResult - static func recvfrom(socket s: NIOBSDSocket.Handle, - buffer buf: UnsafeMutableRawPointer, - length len: size_t, - address from: UnsafeMutablePointer, - address_len fromlen: UnsafeMutablePointer) throws -> IOResult + static func recvmsg(descriptor: CInt, msgHdr: UnsafeMutablePointer, flags: CInt) throws -> IOResult static func send(socket s: NIOBSDSocket.Handle, buffer buf: UnsafeRawPointer, diff --git a/Sources/NIO/BSDSocketAPIPosix.swift b/Sources/NIO/BSDSocketAPIPosix.swift index de4d42e3..4c13ae13 100644 --- a/Sources/NIO/BSDSocketAPIPosix.swift +++ b/Sources/NIO/BSDSocketAPIPosix.swift @@ -84,16 +84,8 @@ extension NIOBSDSocket { return try Posix.read(descriptor: s, pointer: buf, size: len) } - static func recvfrom(socket s: NIOBSDSocket.Handle, - buffer buf: UnsafeMutableRawPointer, - length len: size_t, - address from: UnsafeMutablePointer, - address_len fromlen: UnsafeMutablePointer) throws -> IOResult { - return try Posix.recvfrom(descriptor: s, - pointer: buf, - len: len, - addr: from, - addrlen: fromlen) + static func recvmsg(descriptor: CInt, msgHdr: UnsafeMutablePointer, flags: CInt) throws -> IOResult { + return try Posix.recvmsg(descriptor: descriptor, msgHdr: msgHdr, flags: flags) } static func send(socket s: NIOBSDSocket.Handle, diff --git a/Sources/NIO/BSDSocketAPIWindows.swift b/Sources/NIO/BSDSocketAPIWindows.swift index 6597370e..0807497b 100644 --- a/Sources/NIO/BSDSocketAPIWindows.swift +++ b/Sources/NIO/BSDSocketAPIWindows.swift @@ -161,16 +161,8 @@ extension NIOBSDSocket { } @inline(never) - static func recvfrom(socket s: NIOBSDSocket.Handle, - buffer buf: UnsafeMutableRawPointer, - length len: size_t, - address from: UnsafeMutablePointer, - address_len fromlen: UnsafeMutablePointer) throws -> IOResult { - let iResult: CInt = CNIOWindows_recvfrom(s, buf, CInt(len), 0, from, fromlen) - if iResult == SOCKET_ERROR { - throw IOError(winsock: WSAGetLastError(), reason: "recvfrom") - } - return .processed(size_t(iResult)) + static func recvmsg(descriptor: CInt, msgHdr: UnsafeMutablePointer, flags: CInt) throws -> IOResult { + fatalError("recvmsg not yet implemented on Windows") } @inline(never) diff --git a/Sources/NIO/PipePair.swift b/Sources/NIO/PipePair.swift index 69e0e9b3..a5b7d9e3 100644 --- a/Sources/NIO/PipePair.swift +++ b/Sources/NIO/PipePair.swift @@ -93,7 +93,10 @@ final class PipePair: SocketProtocol { } } - func recvfrom(pointer: UnsafeMutableRawBufferPointer, storage: inout sockaddr_storage, storageLen: inout socklen_t) throws -> IOResult { + func recvmsg(pointer: UnsafeMutableRawBufferPointer, + storage: inout sockaddr_storage, + storageLen: inout socklen_t, + controlBytes: inout Slice) throws -> IOResult { throw ChannelError.operationUnsupported } diff --git a/Sources/NIO/Socket.swift b/Sources/NIO/Socket.swift index f5d1ee56..5d2edbc4 100644 --- a/Sources/NIO/Socket.swift +++ b/Sources/NIO/Socket.swift @@ -182,21 +182,44 @@ typealias IOVector = iovec } } - /// Receive data from the socket. + /// Receive data from the socket, along with aditional control information. /// /// - parameters: /// - pointer: The pointer (and size) to the storage into which the data should be read. /// - storage: The address from which the data was received /// - storageLen: The size of the storage itself. - /// - returns: The `IOResult` which indicates how much data could be received and if the operation returned before all could be received (because the socket is in non-blocking mode). + /// - controlBytes: A region of a buffer into which control data can be written. This parameter will be modified on return to be + /// the slice of the data actually written into, if any. + /// - returns: The `IOResult` which indicates how much data could be received and if the operation returned before all the data could be received + /// (because the socket is in non-blocking mode) /// - throws: An `IOError` if the operation failed. - func recvfrom(pointer: UnsafeMutableRawBufferPointer, storage: inout sockaddr_storage, storageLen: inout socklen_t) throws -> IOResult<(Int)> { - return try withUnsafeHandle { fd in - try storage.withMutableSockAddr { (storagePtr, _) in - try NIOBSDSocket.recvfrom(socket: fd, buffer: pointer.baseAddress!, - length: pointer.count, - address: storagePtr, - address_len: &storageLen) + func recvmsg(pointer: UnsafeMutableRawBufferPointer, + storage: inout sockaddr_storage, + storageLen: inout socklen_t, + controlBytes: inout Slice) throws -> IOResult { + var vec = iovec(iov_base: pointer.baseAddress, iov_len: pointer.count) + let localControlBytePointer = UnsafeMutableRawBufferPointer(rebasing: controlBytes) + + return try withUnsafeMutablePointer(to: &vec) { vecPtr in + return try storage.withMutableSockAddr { (sockaddrPtr, _) in + var messageHeader = msghdr(msg_name: sockaddrPtr, + msg_namelen: storageLen, + msg_iov: vecPtr, + msg_iovlen: 1, + msg_control: localControlBytePointer.baseAddress, + msg_controllen: .init(localControlBytePointer.count), + msg_flags: 0) + defer { + // We need to write back the length of the message and the control bytes. + storageLen = messageHeader.msg_namelen + controlBytes = controlBytes.prefix(.init(messageHeader.msg_controllen)) + } + + return try withUnsafeMutablePointer(to: &messageHeader) { messageHeader in + return try withUnsafeHandle { fd in + return try NIOBSDSocket.recvmsg(descriptor: fd, msgHdr: messageHeader, flags: 0) + } + } } } } diff --git a/Sources/NIO/SocketChannel.swift b/Sources/NIO/SocketChannel.swift index 347ac24c..1cc563bc 100644 --- a/Sources/NIO/SocketChannel.swift +++ b/Sources/NIO/SocketChannel.swift @@ -490,6 +490,10 @@ final class DatagramChannel: BaseSocketChannel { var buffer = self.recvAllocator.buffer(allocator: self.allocator) var readResult = ReadResult.none + // Right now we don't actually ask for any control messages. We will eventually. + let controlBytes = UnsafeMutableRawBufferPointer(start: nil, count: 0) + var controlByteSlice = controlBytes[...] + for i in 1...self.maxMessagesPerRead { guard self.isOpen else { throw ChannelError.eof @@ -497,7 +501,7 @@ final class DatagramChannel: BaseSocketChannel { buffer.clear() let result = try buffer.withMutableWritePointer { - try self.socket.recvfrom(pointer: $0, storage: &rawAddress, storageLen: &rawAddressLength) + try self.socket.recvmsg(pointer: $0, storage: &rawAddress, storageLen: &rawAddressLength, controlBytes: &controlByteSlice) } switch result { case .processed(let bytesRead): diff --git a/Sources/NIO/SocketProtocols.swift b/Sources/NIO/SocketProtocols.swift index 7bdb20f5..3ca3a115 100644 --- a/Sources/NIO/SocketProtocols.swift +++ b/Sources/NIO/SocketProtocols.swift @@ -49,7 +49,10 @@ protocol SocketProtocol: BaseSocketProtocol { func read(pointer: UnsafeMutableRawBufferPointer) throws -> IOResult - func recvfrom(pointer: UnsafeMutableRawBufferPointer, storage: inout sockaddr_storage, storageLen: inout socklen_t) throws -> IOResult + func recvmsg(pointer: UnsafeMutableRawBufferPointer, + storage: inout sockaddr_storage, + storageLen: inout socklen_t, + controlBytes: inout Slice) throws -> IOResult func sendFile(fd: Int32, offset: Int, count: Int) throws -> IOResult diff --git a/Sources/NIO/System.swift b/Sources/NIO/System.swift index 7326f153..0241ad96 100644 --- a/Sources/NIO/System.swift +++ b/Sources/NIO/System.swift @@ -82,6 +82,7 @@ private let sysRecvFrom: @convention(c) (CInt, UnsafeMutableRawPointer?, CLong, private let sysWritev: @convention(c) (Int32, UnsafePointer?, CInt) -> CLong = writev #endif private let sysSendTo: @convention(c) (CInt, UnsafeRawPointer?, CLong, CInt, UnsafePointer?, socklen_t) -> CLong = sendto +private let sysRecvMsg: @convention(c) (CInt, UnsafeMutablePointer?, CInt) -> ssize_t = recvmsg private let sysDup: @convention(c) (CInt) -> CInt = dup private let sysGetpeername: @convention(c) (CInt, UnsafeMutablePointer?, UnsafeMutablePointer?) -> CInt = getpeername private let sysGetsockname: @convention(c) (CInt, UnsafeMutablePointer?, UnsafeMutablePointer?) -> CInt = getsockname @@ -373,9 +374,9 @@ internal enum Posix { } @inline(never) - public static func recvfrom(descriptor: CInt, pointer: UnsafeMutableRawPointer, len: size_t, addr: UnsafeMutablePointer, addrlen: UnsafeMutablePointer) throws -> IOResult { + public static func recvmsg(descriptor: CInt, msgHdr: UnsafeMutablePointer, flags: CInt) throws -> IOResult { return try syscall(blocking: true) { - sysRecvFrom(descriptor, pointer, len, 0, addr, addrlen) + sysRecvMsg(descriptor, msgHdr, flags) } } diff --git a/Tests/NIOTests/DatagramChannelTests+XCTest.swift b/Tests/NIOTests/DatagramChannelTests+XCTest.swift index f71c4365..6f556996 100644 --- a/Tests/NIOTests/DatagramChannelTests+XCTest.swift +++ b/Tests/NIOTests/DatagramChannelTests+XCTest.swift @@ -37,9 +37,9 @@ extension DatagramChannelTests { ("testLargeWritesFail", testLargeWritesFail), ("testOneLargeWriteDoesntPreventOthersWriting", testOneLargeWriteDoesntPreventOthersWriting), ("testClosingBeforeFlushFailsAllWrites", testClosingBeforeFlushFailsAllWrites), - ("testRecvFromFailsWithECONNREFUSED", testRecvFromFailsWithECONNREFUSED), - ("testRecvFromFailsWithENOMEM", testRecvFromFailsWithENOMEM), - ("testRecvFromFailsWithEFAULT", testRecvFromFailsWithEFAULT), + ("testRecvMsgFailsWithECONNREFUSED", testRecvMsgFailsWithECONNREFUSED), + ("testRecvMsgFailsWithENOMEM", testRecvMsgFailsWithENOMEM), + ("testRecvMsgFailsWithEFAULT", testRecvMsgFailsWithEFAULT), ("testRecvMmsgFailsWithECONNREFUSED", testRecvMmsgFailsWithECONNREFUSED), ("testRecvMmsgFailsWithENOMEM", testRecvMmsgFailsWithENOMEM), ("testRecvMmsgFailsWithEFAULT", testRecvMmsgFailsWithEFAULT), diff --git a/Tests/NIOTests/DatagramChannelTests.swift b/Tests/NIOTests/DatagramChannelTests.swift index 595e4651..2ea639a1 100644 --- a/Tests/NIOTests/DatagramChannelTests.swift +++ b/Tests/NIOTests/DatagramChannelTests.swift @@ -340,19 +340,19 @@ final class DatagramChannelTests: XCTestCase { } } - public func testRecvFromFailsWithECONNREFUSED() throws { - try assertRecvFromFails(error: ECONNREFUSED, active: true) + public func testRecvMsgFailsWithECONNREFUSED() throws { + try assertRecvMsgFails(error: ECONNREFUSED, active: true) } - public func testRecvFromFailsWithENOMEM() throws { - try assertRecvFromFails(error: ENOMEM, active: true) + public func testRecvMsgFailsWithENOMEM() throws { + try assertRecvMsgFails(error: ENOMEM, active: true) } - public func testRecvFromFailsWithEFAULT() throws { - try assertRecvFromFails(error: EFAULT, active: false) + public func testRecvMsgFailsWithEFAULT() throws { + try assertRecvMsgFails(error: EFAULT, active: false) } - private func assertRecvFromFails(error: Int32, active: Bool) throws { + private func assertRecvMsgFails(error: Int32, active: Bool) throws { final class RecvFromHandler: ChannelInboundHandler { typealias InboundIn = AddressedEnvelope typealias InboundOut = AddressedEnvelope @@ -386,7 +386,7 @@ final class DatagramChannelTests: XCTestCase { try super.init(protocolFamily: .inet, type: .datagram) } - override func recvfrom(pointer: UnsafeMutableRawBufferPointer, storage: inout sockaddr_storage, storageLen: inout socklen_t) throws -> IOResult<(Int)> { + override func recvmsg(pointer: UnsafeMutableRawBufferPointer, storage: inout sockaddr_storage, storageLen: inout socklen_t, controlBytes: inout Slice) throws -> IOResult<(Int)> { if let err = self.error { self.error = nil throw IOError(errnoCode: err, reason: "recvfrom")