From 86ce6aae8b58ecdec01e7ce5eb65c8b4783e7775 Mon Sep 17 00:00:00 2001 From: Peter Adams <63288215+PeterAdams-A@users.noreply.github.com> Date: Fri, 17 Jul 2020 16:25:09 +0100 Subject: [PATCH] Explicit Congestion Notification for UDP (#1596) Motivation: Network congestion can be controlled without packet loss if explicit congestion notification messages are understood. Modifications: Receive single and vector paths to get notification and surface. Send paths for both single and vector sends. Round trip tests. Storage for control messages added in a reusable way. I have tested allocations and they seem to be fine but the allocation tests for vector read and write were not stable enough to sensibly add to CI. Result: Explicit Congestion Notifications can be both sent and received. --- Sources/NIO/ControlMessage.swift | 58 ++++++++++++ Sources/NIO/DatagramVectorReadManager.swift | 55 +++++++++-- .../NIO/PendingDatagramWritesManager.swift | 19 +++- Sources/NIO/PipePair.swift | 2 +- Sources/NIO/SelectableEventLoop.swift | 5 + Sources/NIO/Socket.swift | 22 +++-- Sources/NIO/SocketChannel.swift | 52 +++++++++-- Sources/NIO/SocketProtocols.swift | 2 +- .../NIOTests/ControlMessageTests+XCTest.swift | 1 + Tests/NIOTests/ControlMessageTests.swift | 27 ++++++ .../DatagramChannelTests+XCTest.swift | 6 ++ Tests/NIOTests/DatagramChannelTests.swift | 93 ++++++++++++++++++- .../PendingDatagramWritesManagerTests.swift | 10 +- 13 files changed, 318 insertions(+), 34 deletions(-) diff --git a/Sources/NIO/ControlMessage.swift b/Sources/NIO/ControlMessage.swift index 4cdca048..d6acb474 100644 --- a/Sources/NIO/ControlMessage.swift +++ b/Sources/NIO/ControlMessage.swift @@ -18,6 +18,53 @@ import CNIODarwin import CNIOLinux #endif +/// Memory for use as `cmsghdr` and associated data. +/// Supports multiple messages each with enough storage for multiple `cmsghdr` +struct UnsafeControlMessageStorage: Collection { + let bytesPerMessage: Int + var buffer: UnsafeMutableRawBufferPointer + + /// Initialise which includes allocating memory + /// parameter: + /// - bytesPerMessage: How many bytes have been allocated for each supported message. + /// - buffer: The memory allocated to use for control messages. + private init(bytesPerMessage: Int, buffer: UnsafeMutableRawBufferPointer) { + self.bytesPerMessage = bytesPerMessage + self.buffer = buffer + } + + /// Allocate new memory - Caller must call `deallocate` when no longer required. + /// parameter: + /// - msghdrCount: How many `msghdr` structures will be fed from this buffer - we assume 4 Int32 cmsgs for each. + static func allocate(msghdrCount: Int) -> UnsafeControlMessageStorage { + // Guess that 4 Int32 payload messages is enough for anyone. + let bytesPerMessage = Posix.cmsgSpace(payloadSize: MemoryLayout.stride) * 4 + let buffer = UnsafeMutableRawBufferPointer.allocate(byteCount: bytesPerMessage * msghdrCount, + alignment: MemoryLayout.alignment) + return UnsafeControlMessageStorage(bytesPerMessage: bytesPerMessage, buffer: buffer) + } + + mutating func deallocate() { + self.buffer.deallocate() + self.buffer = UnsafeMutableRawBufferPointer(start: UnsafeMutableRawPointer(bitPattern: 0xdeadbeef), count: 0) + } + + /// Get the part of the buffer for use with a message. + public subscript(position: Int) -> UnsafeMutableRawBufferPointer { + return UnsafeMutableRawBufferPointer( + rebasing: self.buffer[(position * self.bytesPerMessage)..<((position+1) * self.bytesPerMessage)]) + } + + var startIndex: Int { return 0 } + + var endIndex: Int { return self.buffer.count / self.bytesPerMessage } + + func index(after: Int) -> Int { + return after + 1 + } + +} + /// Representation of a `cmsghdr` and associated data. /// Unsafe as captures pointers and must not escape the scope where those pointers are valid. struct UnsafeControlMessage { @@ -87,6 +134,17 @@ extension UnsafeControlMessageCollection: Collection { } } +/// Small struct to link a buffer used for control bytes and the processing of those bytes. +struct UnsafeReceivedControlBytes { + var controlBytesBuffer: UnsafeMutableRawBufferPointer + /// Set when a message is received which is using the controlBytesBuffer - the lifetime will be tied to that of `controlBytesBuffer` + var receivedControlMessages: UnsafeControlMessageCollection? + + init(controlBytesBuffer: UnsafeMutableRawBufferPointer) { + self.controlBytesBuffer = controlBytesBuffer + } +} + /// Extract information from a collection of control messages. struct ControlMessageParser { var ecnValue: NIOExplicitCongestionNotificationState = .transportNotCapable // Default diff --git a/Sources/NIO/DatagramVectorReadManager.swift b/Sources/NIO/DatagramVectorReadManager.swift index b7817649..68c05b80 100644 --- a/Sources/NIO/DatagramVectorReadManager.swift +++ b/Sources/NIO/DatagramVectorReadManager.swift @@ -35,10 +35,12 @@ struct DatagramVectorReadManager { self.messageVector.deinitializeAndDeallocate() self.ioVector.deinitializeAndDeallocate() self.sockaddrVector.deinitializeAndDeallocate() + self.controlMessageStorage.deallocate() self.messageVector = .allocateAndInitialize(repeating: MMsgHdr(msg_hdr: msghdr(), msg_len: 0), count: newValue) self.ioVector = .allocateAndInitialize(repeating: IOVector(), count: newValue) self.sockaddrVector = .allocateAndInitialize(repeating: sockaddr_storage(), count: newValue) + self.controlMessageStorage = UnsafeControlMessageStorage.allocate(msghdrCount: newValue) } } @@ -51,13 +53,20 @@ struct DatagramVectorReadManager { /// The vector of sockaddr structures used for saving addresses. private var sockaddrVector: UnsafeMutableBufferPointer + /// Storage to use for `cmsghdr` data when reading messages. + private var controlMessageStorage: UnsafeControlMessageStorage + // FIXME(cory): Right now there's no good API for specifying the various parameters of multi-read, especially how // it should interact with RecvByteBufferAllocator. For now I'm punting on this to see if I can get it working, // but we should design it back. - fileprivate init(messageVector: UnsafeMutableBufferPointer, ioVector: UnsafeMutableBufferPointer, sockaddrVector: UnsafeMutableBufferPointer) { + fileprivate init(messageVector: UnsafeMutableBufferPointer, + ioVector: UnsafeMutableBufferPointer, + sockaddrVector: UnsafeMutableBufferPointer, + controlMessageStorage: UnsafeControlMessageStorage) { self.messageVector = messageVector self.ioVector = ioVector self.sockaddrVector = sockaddrVector + self.controlMessageStorage = controlMessageStorage } /// Performs a socket vector read. @@ -78,7 +87,10 @@ struct DatagramVectorReadManager { /// - parameters: /// - socket: The underlying socket from which to read. /// - buffer: The single large buffer into which reads will be written. - func readFromSocket(socket: Socket, buffer: inout ByteBuffer) throws -> ReadResult { + /// - reportExplicitCongestionNotifications: Should explicit congestion notifications be reported up using metadata. + func readFromSocket(socket: Socket, + buffer: inout ByteBuffer, + reportExplicitCongestionNotifications: Bool) throws -> ReadResult { assert(buffer.readerIndex == 0, "Buffer was not cleared between calls to readFromSocket!") let messageSize = buffer.capacity / self.messageCount @@ -89,14 +101,22 @@ struct DatagramVectorReadManager { // First we set up the iovec and save it off. self.ioVector[i] = IOVector(iov_base: bufferPointer.baseAddress! + (i * messageSize), iov_len: messageSize) + + let controlBytes: UnsafeMutableRawBufferPointer + if reportExplicitCongestionNotifications { + // This will be used in buildMessages below but should not be used beyond return of this function. + controlBytes = self.controlMessageStorage[i] + } else { + controlBytes = UnsafeMutableRawBufferPointer(start: nil, count: 0) + } // Next we set up the msghdr structure. This points into the other vectors. let msgHdr = msghdr(msg_name: self.sockaddrVector.baseAddress! + i , msg_namelen: socklen_t(MemoryLayout.size), msg_iov: self.ioVector.baseAddress! + i, msg_iovlen: 1, // This is weird, but each message gets only one array. Duh. - msg_control: nil, - msg_controllen: 0, + msg_control: controlBytes.baseAddress, + msg_controllen: .init(controlBytes.count), msg_flags: 0) self.messageVector[i] = MMsgHdr(msg_hdr: msgHdr, msg_len: 0) @@ -116,7 +136,8 @@ struct DatagramVectorReadManager { buffer.moveWriterIndex(to: messageSize * messagesProcessed) return self.buildMessages(messageCount: messagesProcessed, sliceSize: messageSize, - buffer: &buffer) + buffer: &buffer, + reportExplicitCongestionNotifications: reportExplicitCongestionNotifications) } } @@ -125,9 +146,13 @@ struct DatagramVectorReadManager { self.messageVector.deinitializeAndDeallocate() self.ioVector.deinitializeAndDeallocate() self.sockaddrVector.deinitializeAndDeallocate() + self.controlMessageStorage.deallocate() } - private func buildMessages(messageCount: Int, sliceSize: Int, buffer: inout ByteBuffer) -> ReadResult { + private func buildMessages(messageCount: Int, + sliceSize: Int, + buffer: inout ByteBuffer, + reportExplicitCongestionNotifications: Bool) -> ReadResult { var sliceOffset = buffer.readerIndex var totalReadSize = 0 @@ -148,9 +173,19 @@ struct DatagramVectorReadManager { // Next we extract the remote peer address. precondition(self.messageVector[i].msg_hdr.msg_namelen != 0, "Unexpected zero length peer name") let address: SocketAddress = self.sockaddrVector[i].convert() + + // Extract congestion information if requested. + let metadata: AddressedEnvelope.Metadata? + if reportExplicitCongestionNotifications { + let controlMessagesReceived = + UnsafeControlMessageCollection(messageHeader: self.messageVector[i].msg_hdr) + metadata = .init(from: controlMessagesReceived) + } else { + metadata = nil + } // Now we've finally constructed a useful AddressedEnvelope. We can store it in the results array temporarily. - results.append(AddressedEnvelope(remoteAddress: address, data: slice)) + results.append(AddressedEnvelope(remoteAddress: address, data: slice, metadata: metadata)) } // Ok, all built. Now we can return these values to the caller. @@ -167,8 +202,12 @@ extension DatagramVectorReadManager { let messageVector = UnsafeMutableBufferPointer.allocateAndInitialize(repeating: MMsgHdr(msg_hdr: msghdr(), msg_len: 0), count: messageCount) let ioVector = UnsafeMutableBufferPointer.allocateAndInitialize(repeating: IOVector(), count: messageCount) let sockaddrVector = UnsafeMutableBufferPointer.allocateAndInitialize(repeating: sockaddr_storage(), count: messageCount) + let controlMessageStorage = UnsafeControlMessageStorage.allocate(msghdrCount: messageCount) - return DatagramVectorReadManager(messageVector: messageVector, ioVector: ioVector, sockaddrVector: sockaddrVector) + return DatagramVectorReadManager(messageVector: messageVector, + ioVector: ioVector, + sockaddrVector: sockaddrVector, + controlMessageStorage: controlMessageStorage) } } diff --git a/Sources/NIO/PendingDatagramWritesManager.swift b/Sources/NIO/PendingDatagramWritesManager.swift index c7d45f60..dfcac0c3 100644 --- a/Sources/NIO/PendingDatagramWritesManager.swift +++ b/Sources/NIO/PendingDatagramWritesManager.swift @@ -66,8 +66,11 @@ private func doPendingDatagramWriteVectorOperation(pending: PendingDatagramWrite msgs: UnsafeMutableBufferPointer, addresses: UnsafeMutableBufferPointer, storageRefs: UnsafeMutableBufferPointer>, + controlMessageStorage: UnsafeControlMessageStorage, _ body: (UnsafeMutableBufferPointer) throws -> IOResult) throws -> IOResult { assert(msgs.count >= Socket.writevLimitIOVectors, "Insufficiently sized buffer for a maximal sendmmsg") + assert(controlMessageStorage.count >= Socket.writevLimitIOVectors, + "Insufficiently sized control message storage for a maximal sendmmsg") // the numbers of storage refs that we need to decrease later. var c = 0 @@ -100,12 +103,16 @@ private func doPendingDatagramWriteVectorOperation(pending: PendingDatagramWrite let addressLen = p.copySocketAddress(addresses.baseAddress! + c) iovecs[c] = iovec(iov_base: UnsafeMutableRawPointer(mutating: ptr.baseAddress!), iov_len: numericCast(toWriteForThisBuffer)) + var controlBytes = UnsafeOutboundControlBytes(controlBytes: controlMessageStorage[c]) + controlBytes.appendExplicitCongestionState(metadata: p.metadata, protocolFamily: p.address.protocol) + let controlMessageBytePointer = controlBytes.validControlBytes + let msg = msghdr(msg_name: addresses.baseAddress! + c, msg_namelen: addressLen, msg_iov: iovecs.baseAddress! + c, msg_iovlen: 1, - msg_control: nil, - msg_controllen: 0, + msg_control: controlMessageBytePointer.baseAddress, + msg_controllen: .init(controlMessageBytePointer.count), msg_flags: 0) msgs[c] = MMsgHdr(msg_hdr: msg, msg_len: CUnsignedInt(toWriteForThisBuffer)) } @@ -357,6 +364,8 @@ final class PendingDatagramWritesManager: PendingWritesManager { /// Storage for sockaddr structures. Only present on Linux because Darwin does not support gathering /// writes. private var addresses: UnsafeMutableBufferPointer + + private var controlMessageStorage: UnsafeControlMessageStorage private var state = PendingDatagramWritesState() @@ -375,14 +384,17 @@ final class PendingDatagramWritesManager: PendingWritesManager { /// - iovecs: A pre-allocated array of `IOVector` elements /// - addresses: A pre-allocated array of `sockaddr_storage` elements /// - storageRefs: A pre-allocated array of storage management tokens used to keep storage elements alive during a vector write operation + /// - controlMessageStorage: Pre-allocated memory for storing cmsghdr data during a vector write operation. init(msgs: UnsafeMutableBufferPointer, iovecs: UnsafeMutableBufferPointer, addresses: UnsafeMutableBufferPointer, - storageRefs: UnsafeMutableBufferPointer>) { + storageRefs: UnsafeMutableBufferPointer>, + controlMessageStorage: UnsafeControlMessageStorage) { self.msgs = msgs self.iovecs = iovecs self.addresses = addresses self.storageRefs = storageRefs + self.controlMessageStorage = controlMessageStorage } /// Mark the flush checkpoint. @@ -530,6 +542,7 @@ final class PendingDatagramWritesManager: PendingWritesManager { msgs: self.msgs, addresses: self.addresses, storageRefs: self.storageRefs, + controlMessageStorage: self.controlMessageStorage, { try vectorWriteOperation($0) }), messages: self.msgs) } diff --git a/Sources/NIO/PipePair.swift b/Sources/NIO/PipePair.swift index e0454347..7d0bc054 100644 --- a/Sources/NIO/PipePair.swift +++ b/Sources/NIO/PipePair.swift @@ -92,7 +92,7 @@ final class PipePair: SocketProtocol { func recvmsg(pointer: UnsafeMutableRawBufferPointer, storage: inout sockaddr_storage, storageLen: inout socklen_t, - controlBytes: inout Slice) throws -> IOResult { + controlBytes: inout UnsafeReceivedControlBytes) throws -> IOResult { throw ChannelError.operationUnsupported } diff --git a/Sources/NIO/SelectableEventLoop.swift b/Sources/NIO/SelectableEventLoop.swift index 01b34e72..05dba475 100644 --- a/Sources/NIO/SelectableEventLoop.swift +++ b/Sources/NIO/SelectableEventLoop.swift @@ -85,6 +85,9 @@ internal final class SelectableEventLoop: EventLoop { private let _addresses: UnsafeMutablePointer let msgs: UnsafeMutableBufferPointer let addresses: UnsafeMutableBufferPointer + + // Used for UDP control messages. + private(set) var controlMessageStorage: UnsafeControlMessageStorage /// Creates a new `SelectableEventLoop` instance that is tied to the given `pthread_t`. @@ -143,6 +146,7 @@ internal final class SelectableEventLoop: EventLoop { self._addresses = UnsafeMutablePointer.allocate(capacity: Socket.writevLimitIOVectors) self.msgs = UnsafeMutableBufferPointer(start: _msgs, count: Socket.writevLimitIOVectors) self.addresses = UnsafeMutableBufferPointer(start: _addresses, count: Socket.writevLimitIOVectors) + self.controlMessageStorage = UnsafeControlMessageStorage.allocate(msghdrCount: Socket.writevLimitIOVectors) // We will process 4096 tasks per while loop. self.tasksCopy.reserveCapacity(4096) self.canBeShutdownIndividually = canBeShutdownIndividually @@ -157,6 +161,7 @@ internal final class SelectableEventLoop: EventLoop { _storageRefs.deallocate() _msgs.deallocate() _addresses.deallocate() + self.controlMessageStorage.deallocate() } /// Is this `SelectableEventLoop` still open (ie. not shutting down or shut down) diff --git a/Sources/NIO/Socket.swift b/Sources/NIO/Socket.swift index 06cc43d8..dc524380 100644 --- a/Sources/NIO/Socket.swift +++ b/Sources/NIO/Socket.swift @@ -204,17 +204,15 @@ typealias IOVector = iovec /// - pointer: The pointer (and size) to the storage into which the data should be read. /// - storage: The address from which the data was received /// - storageLen: The size of the storage itself. - /// - controlBytes: A region of a buffer into which control data can be written. This parameter will be modified on return to be - /// the slice of the data actually written into, if any. + /// - controlBytes: A buffer in memory for use receiving control bytes. This parameter will be modified to hold any data actually received. /// - returns: The `IOResult` which indicates how much data could be received and if the operation returned before all the data could be received /// (because the socket is in non-blocking mode) /// - throws: An `IOError` if the operation failed. func recvmsg(pointer: UnsafeMutableRawBufferPointer, storage: inout sockaddr_storage, storageLen: inout socklen_t, - controlBytes: inout Slice) throws -> IOResult { + controlBytes: inout UnsafeReceivedControlBytes) throws -> IOResult { var vec = iovec(iov_base: pointer.baseAddress, iov_len: pointer.count) - let localControlBytePointer = UnsafeMutableRawBufferPointer(rebasing: controlBytes) return try withUnsafeMutablePointer(to: &vec) { vecPtr in return try storage.withMutableSockAddr { (sockaddrPtr, _) in @@ -222,20 +220,26 @@ typealias IOVector = iovec msg_namelen: storageLen, msg_iov: vecPtr, msg_iovlen: 1, - msg_control: localControlBytePointer.baseAddress, - msg_controllen: .init(localControlBytePointer.count), + msg_control: controlBytes.controlBytesBuffer.baseAddress, + msg_controllen: .init(controlBytes.controlBytesBuffer.count), msg_flags: 0) defer { - // We need to write back the length of the message and the control bytes. + // We need to write back the length of the message. storageLen = messageHeader.msg_namelen - controlBytes = controlBytes.prefix(.init(messageHeader.msg_controllen)) } - return try withUnsafeMutablePointer(to: &messageHeader) { messageHeader in + let result = try withUnsafeMutablePointer(to: &messageHeader) { messageHeader in return try withUnsafeHandle { fd in return try NIOBSDSocket.recvmsg(descriptor: fd, msgHdr: messageHeader, flags: 0) } } + + // Only look at the control bytes if all is good. + if case .processed = result { + controlBytes.receivedControlMessages = UnsafeControlMessageCollection(messageHeader: messageHeader) + } + + return result } } } diff --git a/Sources/NIO/SocketChannel.swift b/Sources/NIO/SocketChannel.swift index 1dca7f10..932862e7 100644 --- a/Sources/NIO/SocketChannel.swift +++ b/Sources/NIO/SocketChannel.swift @@ -325,6 +325,7 @@ final class ServerSocketChannel: BaseSocketChannel { /// /// Currently, it does not support connected mode which is well worth adding. final class DatagramChannel: BaseSocketChannel { + private var reportExplicitCongestionNotifications = false // Guard against re-entrance of flushNow() method. private let pendingWrites: PendingDatagramWritesManager @@ -372,7 +373,8 @@ final class DatagramChannel: BaseSocketChannel { self.pendingWrites = PendingDatagramWritesManager(msgs: eventLoop.msgs, iovecs: eventLoop.iovecs, addresses: eventLoop.addresses, - storageRefs: eventLoop.storageRefs) + storageRefs: eventLoop.storageRefs, + controlMessageStorage: eventLoop.controlMessageStorage) try super.init(socket: socket, parent: nil, @@ -386,7 +388,8 @@ final class DatagramChannel: BaseSocketChannel { self.pendingWrites = PendingDatagramWritesManager(msgs: eventLoop.msgs, iovecs: eventLoop.iovecs, addresses: eventLoop.addresses, - storageRefs: eventLoop.storageRefs) + storageRefs: eventLoop.storageRefs, + controlMessageStorage: eventLoop.controlMessageStorage) try super.init(socket: socket, parent: parent, eventLoop: eventLoop, recvAllocator: FixedSizeRecvByteBufferAllocator(capacity: 2048)) } @@ -415,10 +418,12 @@ final class DatagramChannel: BaseSocketChannel { let valueAsInt: Int32 = value as! Bool ? 1 : 0 switch self.localAddress?.protocol { case .some(.inet): + self.reportExplicitCongestionNotifications = true try self.socket.setOption(level: .ip, name: .ip_recv_tos, value: valueAsInt) case .some(.inet6): + self.reportExplicitCongestionNotifications = true try self.socket.setOption(level: .ipv6, name: .ipv6_recv_tclass, value: valueAsInt) @@ -490,9 +495,13 @@ final class DatagramChannel: BaseSocketChannel { var buffer = self.recvAllocator.buffer(allocator: self.allocator) var readResult = ReadResult.none - // Right now we don't actually ask for any control messages. We will eventually. - let controlBytes = UnsafeMutableRawBufferPointer(start: nil, count: 0) - var controlByteSlice = controlBytes[...] + // These control bytes must not escape the current call stack + let controlBytesBuffer: UnsafeMutableRawBufferPointer + if self.reportExplicitCongestionNotifications { + controlBytesBuffer = self.selectableEventLoop.controlMessageStorage[0] + } else { + controlBytesBuffer = UnsafeMutableRawBufferPointer(start: nil, count: 0) + } for i in 1...self.maxMessagesPerRead { guard self.isOpen else { @@ -500,8 +509,13 @@ final class DatagramChannel: BaseSocketChannel { } buffer.clear() + var controlBytes = UnsafeReceivedControlBytes(controlBytesBuffer: controlBytesBuffer) + let result = try buffer.withMutableWritePointer { - try self.socket.recvmsg(pointer: $0, storage: &rawAddress, storageLen: &rawAddressLength, controlBytes: &controlByteSlice) + try self.socket.recvmsg(pointer: $0, + storage: &rawAddress, + storageLen: &rawAddressLength, + controlBytes: &controlBytes) } switch result { case .processed(let bytesRead): @@ -510,7 +524,17 @@ final class DatagramChannel: BaseSocketChannel { let mayGrow = recvAllocator.record(actualReadBytes: bytesRead) readPending = false - let msg = AddressedEnvelope(remoteAddress: rawAddress.convert(), data: buffer) + let metadata: AddressedEnvelope.Metadata? + if self.reportExplicitCongestionNotifications, + let controlMessagesReceived = controlBytes.receivedControlMessages { + metadata = .init(from: controlMessagesReceived) + } else { + metadata = nil + } + + let msg = AddressedEnvelope(remoteAddress: rawAddress.convert(), + data: buffer, + metadata: metadata) assert(self.isActive) pipeline.fireChannelRead0(NIOAny(msg)) if mayGrow && i < maxMessagesPerRead { @@ -542,7 +566,10 @@ final class DatagramChannel: BaseSocketChannel { buffer.clear() // This force-unwrap is safe, as we checked whether this is nil in the caller. - let result = try vectorReadManager.readFromSocket(socket: self.socket, buffer: &buffer) + let result = try vectorReadManager.readFromSocket( + socket: self.socket, + buffer: &buffer, + reportExplicitCongestionNotifications: self.reportExplicitCongestionNotifications) switch result { case .some(let results, let totalRead): assert(self.isOpen) @@ -623,11 +650,16 @@ final class DatagramChannel: BaseSocketChannel { return .processed(0) } // normal write - let controlBytes = UnsafeMutableRawBufferPointer(start: nil, count: 0) + // Control bytes must not escape current stack. + var controlBytes = UnsafeOutboundControlBytes( + controlBytes: self.selectableEventLoop.controlMessageStorage[0]) + controlBytes.appendExplicitCongestionState(metadata: metadata, + protocolFamily: self.localAddress?.protocol) return try self.socket.sendmsg(pointer: ptr, destinationPtr: destinationPtr, destinationSize: destinationSize, - controlBytes: controlBytes) + controlBytes: controlBytes.validControlBytes) + }, vectorWriteOperation: { msgs in return try self.socket.sendmmsg(msgs: msgs) diff --git a/Sources/NIO/SocketProtocols.swift b/Sources/NIO/SocketProtocols.swift index e65e999c..921dad50 100644 --- a/Sources/NIO/SocketProtocols.swift +++ b/Sources/NIO/SocketProtocols.swift @@ -50,7 +50,7 @@ protocol SocketProtocol: BaseSocketProtocol { func recvmsg(pointer: UnsafeMutableRawBufferPointer, storage: inout sockaddr_storage, storageLen: inout socklen_t, - controlBytes: inout Slice) throws -> IOResult + controlBytes: inout UnsafeReceivedControlBytes) throws -> IOResult func sendmsg(pointer: UnsafeRawBufferPointer, destinationPtr: UnsafePointer, diff --git a/Tests/NIOTests/ControlMessageTests+XCTest.swift b/Tests/NIOTests/ControlMessageTests+XCTest.swift index 59063af2..1519fbf3 100644 --- a/Tests/NIOTests/ControlMessageTests+XCTest.swift +++ b/Tests/NIOTests/ControlMessageTests+XCTest.swift @@ -30,6 +30,7 @@ extension ControlMessageTests { ("testEmptyEncode", testEmptyEncode), ("testEncodeDecode1", testEncodeDecode1), ("testEncodeDecode2", testEncodeDecode2), + ("testStorageIndexing", testStorageIndexing), ] } } diff --git a/Tests/NIOTests/ControlMessageTests.swift b/Tests/NIOTests/ControlMessageTests.swift index 5a9d2f43..470e8e18 100644 --- a/Tests/NIOTests/ControlMessageTests.swift +++ b/Tests/NIOTests/ControlMessageTests.swift @@ -90,4 +90,31 @@ class ControlMessageTests: XCTestCase { } XCTAssertEqual(expected, decoded) } + + private func assertBuffersNonOverlapping(_ b1: UnsafeMutableRawBufferPointer, + _ b2: UnsafeMutableRawBufferPointer, + file: StaticString = #file, + line: UInt = #line) { + XCTAssert((b1.baseAddress! < b2.baseAddress! && (b1.baseAddress! + b1.count) <= b2.baseAddress!) || + (b2.baseAddress! < b1.baseAddress! && (b2.baseAddress! + b2.count) <= b1.baseAddress!), + file: (file), + line: line) + } + + func testStorageIndexing() { + var storage = UnsafeControlMessageStorage.allocate(msghdrCount: 3) + defer { + storage.deallocate() + } + // Check size + XCTAssertEqual(storage.count, 3) + // Buffers issued should not overlap. + assertBuffersNonOverlapping(storage[0], storage[1]) + assertBuffersNonOverlapping(storage[0], storage[2]) + assertBuffersNonOverlapping(storage[1], storage[2]) + // Buffers should have a suitable size. + XCTAssertGreaterThan(storage[0].count, MemoryLayout.stride) + XCTAssertGreaterThan(storage[1].count, MemoryLayout.stride) + XCTAssertGreaterThan(storage[2].count, MemoryLayout.stride) + } } diff --git a/Tests/NIOTests/DatagramChannelTests+XCTest.swift b/Tests/NIOTests/DatagramChannelTests+XCTest.swift index 6f556996..6fc3cee7 100644 --- a/Tests/NIOTests/DatagramChannelTests+XCTest.swift +++ b/Tests/NIOTests/DatagramChannelTests+XCTest.swift @@ -51,6 +51,12 @@ extension DatagramChannelTests { ("testMmsgWillTruncateWithoutChangeToAllocator", testMmsgWillTruncateWithoutChangeToAllocator), ("testRecvMmsgForMultipleCycles", testRecvMmsgForMultipleCycles), ("testSetGetEcnNotificationOption", testSetGetEcnNotificationOption), + ("testEcnSendReceiveIPV4", testEcnSendReceiveIPV4), + ("testEcnSendReceiveIPV6", testEcnSendReceiveIPV6), + ("testEcnSendReceiveIPV4VectorRead", testEcnSendReceiveIPV4VectorRead), + ("testEcnSendReceiveIPV6VectorRead", testEcnSendReceiveIPV6VectorRead), + ("testEcnSendReceiveIPV4VectorReadVectorWrite", testEcnSendReceiveIPV4VectorReadVectorWrite), + ("testEcnSendReceiveIPV6VectorReadVectorWrite", testEcnSendReceiveIPV6VectorReadVectorWrite), ] } } diff --git a/Tests/NIOTests/DatagramChannelTests.swift b/Tests/NIOTests/DatagramChannelTests.swift index 2ea639a1..61e507de 100644 --- a/Tests/NIOTests/DatagramChannelTests.swift +++ b/Tests/NIOTests/DatagramChannelTests.swift @@ -386,7 +386,11 @@ final class DatagramChannelTests: XCTestCase { try super.init(protocolFamily: .inet, type: .datagram) } - override func recvmsg(pointer: UnsafeMutableRawBufferPointer, storage: inout sockaddr_storage, storageLen: inout socklen_t, controlBytes: inout Slice) throws -> IOResult<(Int)> { + override func recvmsg(pointer: UnsafeMutableRawBufferPointer, + storage: inout sockaddr_storage, + storageLen: inout socklen_t, + controlBytes: inout UnsafeReceivedControlBytes) + throws -> IOResult<(Int)> { if let err = self.error { self.error = nil throw IOError(errnoCode: err, reason: "recvfrom") @@ -662,4 +666,91 @@ final class DatagramChannelTests: XCTestCase { XCTAssertFalse(try channel2.getOption(ChannelOptions.explicitCongestionNotification).wait()) } ()) } + + private func testEcnReceive(address: String, vectorRead: Bool, vectorSend: Bool) { + XCTAssertNoThrow(try { + let receiveBootstrap: DatagramBootstrap + if vectorRead { + receiveBootstrap = DatagramBootstrap(group: group) + .channelOption(ChannelOptions.datagramVectorReadMessageCount, value: 4) + } else { + receiveBootstrap = DatagramBootstrap(group: group) + } + + let receiveChannel = try receiveBootstrap + .channelOption(ChannelOptions.explicitCongestionNotification, value: true) + .channelInitializer { channel in + channel.pipeline.addHandler(DatagramReadRecorder(), name: "ByteReadRecorder") + } + .bind(host: address, port: 0) + .wait() + defer { + XCTAssertNoThrow(try receiveChannel.close().wait()) + } + let sendChannel = try DatagramBootstrap(group: group) + .bind(host: address, port: 0) + .wait() + defer { + XCTAssertNoThrow(try sendChannel.close().wait()) + } + + var buffer = sendChannel.allocator.buffer(capacity: 1) + buffer.writeRepeatingByte(0, count: 1) + let ecnStates: [NIOExplicitCongestionNotificationState] = [.transportNotCapable, + .congestionExperienced, + .transportCapableFlag0, + .transportCapableFlag1] + for ecnState in ecnStates { + let writeData = AddressedEnvelope(remoteAddress: receiveChannel.localAddress!, + data: buffer, + metadata: .init(ecnState: ecnState)) + // Sending extra data without flushing should trigger a vector send. + if (vectorSend) { + sendChannel.write(writeData, promise: nil) + } + try sendChannel.writeAndFlush(writeData).wait() + } + + let expectedReads = ecnStates.count * (vectorSend ? 2 : 1) + let reads = try receiveChannel.waitForDatagrams(count: expectedReads) + XCTAssertEqual(reads.count, expectedReads) + for readNumber in 0..] = Array(repeating: Unmanaged.passUnretained(o), count: Socket.writevLimitIOVectors + 1) var msgs: [MMsgHdr] = Array(repeating: MMsgHdr(), count: Socket.writevLimitIOVectors + 1) var addresses: [sockaddr_storage] = Array(repeating: sockaddr_storage(), count: Socket.writevLimitIOVectors + 1) + var controlMessageStorage = UnsafeControlMessageStorage.allocate(msghdrCount: Socket.writevLimitIOVectors) + defer { + controlMessageStorage.deallocate() + } /* put a canary value at the end */ iovecs[iovecs.count - 1] = iovec(iov_base: UnsafeMutableRawPointer(bitPattern: 0xdeadbee)!, iov_len: 0xdeadbee) try iovecs.withUnsafeMutableBufferPointer { iovecs in try managed.withUnsafeMutableBufferPointer { managed in try msgs.withUnsafeMutableBufferPointer { msgs in try addresses.withUnsafeMutableBufferPointer { addresses in - let pwm = NIO.PendingDatagramWritesManager(msgs: msgs, iovecs: iovecs, addresses: addresses, storageRefs: managed) + let pwm = NIO.PendingDatagramWritesManager(msgs: msgs, + iovecs: iovecs, + addresses: addresses, + storageRefs: managed, + controlMessageStorage: controlMessageStorage) XCTAssertTrue(pwm.isEmpty) XCTAssertTrue(pwm.isOpen) XCTAssertFalse(pwm.isFlushPending)