Move DatagramChannel over to using recvmsg. (#1473)

Motivation:

When attempting to obtain metadata about a read, we need to use recvmsg
in order to obtain that data. While we're using recvmmsg for vector
reads right now, our scalar reads use recvfrom, which does not provide
us with that metadata.

While we're not extracting metadata right now, we may well do so in
future, so we can apply it now.

Modifications:

- Move recvfrom usage to recvmsg.

Result:

We'll be able to extend recvmsg to extract metadata.
This commit is contained in:
Cory Benfield 2020-07-08 09:10:32 +01:00 committed by GitHub
parent 5de1e41310
commit e3508b0d04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 64 additions and 50 deletions

View File

@ -376,11 +376,7 @@ protocol _BSDSocketProtocol {
buffer buf: UnsafeMutableRawPointer,
length len: size_t) throws -> IOResult<size_t>
static func recvfrom(socket s: NIOBSDSocket.Handle,
buffer buf: UnsafeMutableRawPointer,
length len: size_t,
address from: UnsafeMutablePointer<sockaddr>,
address_len fromlen: UnsafeMutablePointer<socklen_t>) throws -> IOResult<size_t>
static func recvmsg(descriptor: CInt, msgHdr: UnsafeMutablePointer<msghdr>, flags: CInt) throws -> IOResult<ssize_t>
static func send(socket s: NIOBSDSocket.Handle,
buffer buf: UnsafeRawPointer,

View File

@ -84,16 +84,8 @@ extension NIOBSDSocket {
return try Posix.read(descriptor: s, pointer: buf, size: len)
}
static func recvfrom(socket s: NIOBSDSocket.Handle,
buffer buf: UnsafeMutableRawPointer,
length len: size_t,
address from: UnsafeMutablePointer<sockaddr>,
address_len fromlen: UnsafeMutablePointer<socklen_t>) throws -> IOResult<size_t> {
return try Posix.recvfrom(descriptor: s,
pointer: buf,
len: len,
addr: from,
addrlen: fromlen)
static func recvmsg(descriptor: CInt, msgHdr: UnsafeMutablePointer<msghdr>, flags: CInt) throws -> IOResult<ssize_t> {
return try Posix.recvmsg(descriptor: descriptor, msgHdr: msgHdr, flags: flags)
}
static func send(socket s: NIOBSDSocket.Handle,

View File

@ -161,16 +161,8 @@ extension NIOBSDSocket {
}
@inline(never)
static func recvfrom(socket s: NIOBSDSocket.Handle,
buffer buf: UnsafeMutableRawPointer,
length len: size_t,
address from: UnsafeMutablePointer<sockaddr>,
address_len fromlen: UnsafeMutablePointer<socklen_t>) throws -> IOResult<size_t> {
let iResult: CInt = CNIOWindows_recvfrom(s, buf, CInt(len), 0, from, fromlen)
if iResult == SOCKET_ERROR {
throw IOError(winsock: WSAGetLastError(), reason: "recvfrom")
}
return .processed(size_t(iResult))
static func recvmsg(descriptor: CInt, msgHdr: UnsafeMutablePointer<msghdr>, flags: CInt) throws -> IOResult<ssize_t> {
fatalError("recvmsg not yet implemented on Windows")
}
@inline(never)

View File

@ -93,7 +93,10 @@ final class PipePair: SocketProtocol {
}
}
func recvfrom(pointer: UnsafeMutableRawBufferPointer, storage: inout sockaddr_storage, storageLen: inout socklen_t) throws -> IOResult<Int> {
func recvmsg(pointer: UnsafeMutableRawBufferPointer,
storage: inout sockaddr_storage,
storageLen: inout socklen_t,
controlBytes: inout Slice<UnsafeMutableRawBufferPointer>) throws -> IOResult<Int> {
throw ChannelError.operationUnsupported
}

View File

@ -182,21 +182,44 @@ typealias IOVector = iovec
}
}
/// Receive data from the socket.
/// Receive data from the socket, along with aditional control information.
///
/// - parameters:
/// - pointer: The pointer (and size) to the storage into which the data should be read.
/// - storage: The address from which the data was received
/// - storageLen: The size of the storage itself.
/// - returns: The `IOResult` which indicates how much data could be received and if the operation returned before all could be received (because the socket is in non-blocking mode).
/// - controlBytes: A region of a buffer into which control data can be written. This parameter will be modified on return to be
/// the slice of the data actually written into, if any.
/// - returns: The `IOResult` which indicates how much data could be received and if the operation returned before all the data could be received
/// (because the socket is in non-blocking mode)
/// - throws: An `IOError` if the operation failed.
func recvfrom(pointer: UnsafeMutableRawBufferPointer, storage: inout sockaddr_storage, storageLen: inout socklen_t) throws -> IOResult<(Int)> {
return try withUnsafeHandle { fd in
try storage.withMutableSockAddr { (storagePtr, _) in
try NIOBSDSocket.recvfrom(socket: fd, buffer: pointer.baseAddress!,
length: pointer.count,
address: storagePtr,
address_len: &storageLen)
func recvmsg(pointer: UnsafeMutableRawBufferPointer,
storage: inout sockaddr_storage,
storageLen: inout socklen_t,
controlBytes: inout Slice<UnsafeMutableRawBufferPointer>) throws -> IOResult<Int> {
var vec = iovec(iov_base: pointer.baseAddress, iov_len: pointer.count)
let localControlBytePointer = UnsafeMutableRawBufferPointer(rebasing: controlBytes)
return try withUnsafeMutablePointer(to: &vec) { vecPtr in
return try storage.withMutableSockAddr { (sockaddrPtr, _) in
var messageHeader = msghdr(msg_name: sockaddrPtr,
msg_namelen: storageLen,
msg_iov: vecPtr,
msg_iovlen: 1,
msg_control: localControlBytePointer.baseAddress,
msg_controllen: .init(localControlBytePointer.count),
msg_flags: 0)
defer {
// We need to write back the length of the message and the control bytes.
storageLen = messageHeader.msg_namelen
controlBytes = controlBytes.prefix(.init(messageHeader.msg_controllen))
}
return try withUnsafeMutablePointer(to: &messageHeader) { messageHeader in
return try withUnsafeHandle { fd in
return try NIOBSDSocket.recvmsg(descriptor: fd, msgHdr: messageHeader, flags: 0)
}
}
}
}
}

View File

@ -490,6 +490,10 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
var buffer = self.recvAllocator.buffer(allocator: self.allocator)
var readResult = ReadResult.none
// Right now we don't actually ask for any control messages. We will eventually.
let controlBytes = UnsafeMutableRawBufferPointer(start: nil, count: 0)
var controlByteSlice = controlBytes[...]
for i in 1...self.maxMessagesPerRead {
guard self.isOpen else {
throw ChannelError.eof
@ -497,7 +501,7 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
buffer.clear()
let result = try buffer.withMutableWritePointer {
try self.socket.recvfrom(pointer: $0, storage: &rawAddress, storageLen: &rawAddressLength)
try self.socket.recvmsg(pointer: $0, storage: &rawAddress, storageLen: &rawAddressLength, controlBytes: &controlByteSlice)
}
switch result {
case .processed(let bytesRead):

View File

@ -49,7 +49,10 @@ protocol SocketProtocol: BaseSocketProtocol {
func read(pointer: UnsafeMutableRawBufferPointer) throws -> IOResult<Int>
func recvfrom(pointer: UnsafeMutableRawBufferPointer, storage: inout sockaddr_storage, storageLen: inout socklen_t) throws -> IOResult<Int>
func recvmsg(pointer: UnsafeMutableRawBufferPointer,
storage: inout sockaddr_storage,
storageLen: inout socklen_t,
controlBytes: inout Slice<UnsafeMutableRawBufferPointer>) throws -> IOResult<Int>
func sendFile(fd: Int32, offset: Int, count: Int) throws -> IOResult<Int>

View File

@ -82,6 +82,7 @@ private let sysRecvFrom: @convention(c) (CInt, UnsafeMutableRawPointer?, CLong,
private let sysWritev: @convention(c) (Int32, UnsafePointer<iovec>?, CInt) -> CLong = writev
#endif
private let sysSendTo: @convention(c) (CInt, UnsafeRawPointer?, CLong, CInt, UnsafePointer<sockaddr>?, socklen_t) -> CLong = sendto
private let sysRecvMsg: @convention(c) (CInt, UnsafeMutablePointer<msghdr>?, CInt) -> ssize_t = recvmsg
private let sysDup: @convention(c) (CInt) -> CInt = dup
private let sysGetpeername: @convention(c) (CInt, UnsafeMutablePointer<sockaddr>?, UnsafeMutablePointer<socklen_t>?) -> CInt = getpeername
private let sysGetsockname: @convention(c) (CInt, UnsafeMutablePointer<sockaddr>?, UnsafeMutablePointer<socklen_t>?) -> CInt = getsockname
@ -373,9 +374,9 @@ internal enum Posix {
}
@inline(never)
public static func recvfrom(descriptor: CInt, pointer: UnsafeMutableRawPointer, len: size_t, addr: UnsafeMutablePointer<sockaddr>, addrlen: UnsafeMutablePointer<socklen_t>) throws -> IOResult<ssize_t> {
public static func recvmsg(descriptor: CInt, msgHdr: UnsafeMutablePointer<msghdr>, flags: CInt) throws -> IOResult<ssize_t> {
return try syscall(blocking: true) {
sysRecvFrom(descriptor, pointer, len, 0, addr, addrlen)
sysRecvMsg(descriptor, msgHdr, flags)
}
}

View File

@ -37,9 +37,9 @@ extension DatagramChannelTests {
("testLargeWritesFail", testLargeWritesFail),
("testOneLargeWriteDoesntPreventOthersWriting", testOneLargeWriteDoesntPreventOthersWriting),
("testClosingBeforeFlushFailsAllWrites", testClosingBeforeFlushFailsAllWrites),
("testRecvFromFailsWithECONNREFUSED", testRecvFromFailsWithECONNREFUSED),
("testRecvFromFailsWithENOMEM", testRecvFromFailsWithENOMEM),
("testRecvFromFailsWithEFAULT", testRecvFromFailsWithEFAULT),
("testRecvMsgFailsWithECONNREFUSED", testRecvMsgFailsWithECONNREFUSED),
("testRecvMsgFailsWithENOMEM", testRecvMsgFailsWithENOMEM),
("testRecvMsgFailsWithEFAULT", testRecvMsgFailsWithEFAULT),
("testRecvMmsgFailsWithECONNREFUSED", testRecvMmsgFailsWithECONNREFUSED),
("testRecvMmsgFailsWithENOMEM", testRecvMmsgFailsWithENOMEM),
("testRecvMmsgFailsWithEFAULT", testRecvMmsgFailsWithEFAULT),

View File

@ -340,19 +340,19 @@ final class DatagramChannelTests: XCTestCase {
}
}
public func testRecvFromFailsWithECONNREFUSED() throws {
try assertRecvFromFails(error: ECONNREFUSED, active: true)
public func testRecvMsgFailsWithECONNREFUSED() throws {
try assertRecvMsgFails(error: ECONNREFUSED, active: true)
}
public func testRecvFromFailsWithENOMEM() throws {
try assertRecvFromFails(error: ENOMEM, active: true)
public func testRecvMsgFailsWithENOMEM() throws {
try assertRecvMsgFails(error: ENOMEM, active: true)
}
public func testRecvFromFailsWithEFAULT() throws {
try assertRecvFromFails(error: EFAULT, active: false)
public func testRecvMsgFailsWithEFAULT() throws {
try assertRecvMsgFails(error: EFAULT, active: false)
}
private func assertRecvFromFails(error: Int32, active: Bool) throws {
private func assertRecvMsgFails(error: Int32, active: Bool) throws {
final class RecvFromHandler: ChannelInboundHandler {
typealias InboundIn = AddressedEnvelope<ByteBuffer>
typealias InboundOut = AddressedEnvelope<ByteBuffer>
@ -386,7 +386,7 @@ final class DatagramChannelTests: XCTestCase {
try super.init(protocolFamily: .inet, type: .datagram)
}
override func recvfrom(pointer: UnsafeMutableRawBufferPointer, storage: inout sockaddr_storage, storageLen: inout socklen_t) throws -> IOResult<(Int)> {
override func recvmsg(pointer: UnsafeMutableRawBufferPointer, storage: inout sockaddr_storage, storageLen: inout socklen_t, controlBytes: inout Slice<UnsafeMutableRawBufferPointer>) throws -> IOResult<(Int)> {
if let err = self.error {
self.error = nil
throw IOError(errnoCode: err, reason: "recvfrom")