From e7e83d6aa4cc491d9416a40ea3aad7a0cc8cc778 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Thu, 6 Apr 2023 13:26:32 +0100 Subject: [PATCH] Land `NIOAsyncChannel` as SPI (#2397) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Land `NIOAsyncChannel` as SPI # Motivation We want to provide bridges from NIO `Channel`s to Swift Concurrency. In previous PRs, we already landed the building blocks namely `NIOAsyncSequenceProducer` and `NIOAsyncWriter`. These two types are highly performant bridges between synchronous and asynchronous code that respect back-pressure. The next step is to build convenience methods that wrap a `Channel` with these two types. # Modification This PR adds a new type called `NIOAsyncChannel` that is capable of wrapping a `Channel`. This is done by adding two handlers to the channel pipeline that are bridging to the `NIOAsyncSequenceProducer` and `NIOAsyncWriter`. The new `NIOAsyncChannel` type exposes three properties. The underlying `Channel`, a `NIOAsyncChannelInboundStream` and a `NIOAsyncChannelOutboundWriter`. Using these three types the user a able to read/write into the channel using `async` methods. Importantly, we are landing all of this behind the `@_spi(AsyncChannel`. This allows us to merge PRs while we are still working on the remaining parts such as protocol negotiation. # Result We have the first part necessary for our async bridges. Follow up PRs will include the following things: 1. Bootstrap support 2. Protocol negotiation support 3. Example with documentation * Add AsyncSequence bridge to NIOAsyncChannelOutboundWriter * Code review * Prefix temporary spi public method * Rename writeAndFlush to write --- Package.swift | 2 +- .../NIOCore/AsyncChannel/AsyncChannel.swift | 133 +++++ .../AsyncChannelInboundStream.swift | 90 +++ ...ncChannelInboundStreamChannelHandler.swift | 252 ++++++++ .../AsyncChannelOutboundWriter.swift | 93 +++ .../AsyncChannelOutboundWriterHandler.swift | 175 ++++++ .../NIOCore/AsyncChannel/CloseRatchet.swift | 94 +++ .../AsyncChannel/AsyncChannelTests.swift | 561 ++++++++++++++++++ Tests/NIOCoreTests/XCTest+AsyncAwait.swift | 19 + 9 files changed, 1418 insertions(+), 1 deletion(-) create mode 100644 Sources/NIOCore/AsyncChannel/AsyncChannel.swift create mode 100644 Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift create mode 100644 Sources/NIOCore/AsyncChannel/AsyncChannelInboundStreamChannelHandler.swift create mode 100644 Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift create mode 100644 Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift create mode 100644 Sources/NIOCore/AsyncChannel/CloseRatchet.swift create mode 100644 Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift diff --git a/Package.swift b/Package.swift index 37b8e5d2..d010723a 100644 --- a/Package.swift +++ b/Package.swift @@ -100,7 +100,7 @@ var targets: [PackageDescription.Target] = [ .executableTarget(name: "NIOAsyncAwaitDemo", dependencies: ["NIOPosix", "NIOCore", "NIOHTTP1"]), .testTarget(name: "NIOCoreTests", - dependencies: ["NIOCore", "NIOEmbedded", "NIOFoundationCompat"]), + dependencies: ["NIOCore", "NIOEmbedded", "NIOFoundationCompat", swiftAtomics]), .testTarget(name: "NIOEmbeddedTests", dependencies: ["NIOConcurrencyHelpers", "NIOCore", "NIOEmbedded"]), .testTarget(name: "NIOPosixTests", diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannel.swift b/Sources/NIOCore/AsyncChannel/AsyncChannel.swift new file mode 100644 index 00000000..9bac4d9f --- /dev/null +++ b/Sources/NIOCore/AsyncChannel/AsyncChannel.swift @@ -0,0 +1,133 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2022-2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +#if swift(>=5.6) +/// Wraps a NIO ``Channel`` object into a form suitable for use in Swift Concurrency. +/// +/// ``NIOAsyncChannel`` abstracts the notion of a NIO ``Channel`` into something that +/// can safely be used in a structured concurrency context. In particular, this exposes +/// the following functionality: +/// +/// - reads are presented as an `AsyncSequence` +/// - writes can be written to with async functions on a writer, providing backpressure +/// - channels can be closed seamlessly +/// +/// This type does not replace the full complexity of NIO's ``Channel``. In particular, it +/// does not expose the following functionality: +/// +/// - user events +/// - traditional NIO backpressure such as writability signals and the ``Channel/read()`` call +/// +/// Users are encouraged to separate their ``ChannelHandler``s into those that implement +/// protocol-specific logic (such as parsers and encoders) and those that implement business +/// logic. Protocol-specific logic should be implemented as a ``ChannelHandler``, while business +/// logic should use ``NIOAsyncChannel`` to consume and produce data to the network. +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +@_spi(AsyncChannel) +public final class NIOAsyncChannel: Sendable { + /// The underlying channel being wrapped by this ``NIOAsyncChannel``. + @_spi(AsyncChannel) + public let channel: Channel + /// The stream of inbound messages. + @_spi(AsyncChannel) + public let inboundStream: NIOAsyncChannelInboundStream + /// The writer for writing outbound messages. + @_spi(AsyncChannel) + public let outboundWriter: NIOAsyncChannelOutboundWriter + + /// Initializes a new ``NIOAsyncChannel`` wrapping a ``Channel``. + /// + /// - Important: This **must** be called on the channel's event loop otherwise this init will crash. This is necessary because + /// we must install the handlers before any other event in the pipeline happens otherwise we might drop reads. + /// + /// - Parameters: + /// - channel: The ``Channel`` to wrap. + /// - backpressureStrategy: The backpressure strategy of the ``NIOAsyncChannel/inboundStream``. + /// - isOutboundHalfClosureEnabled: If outbound half closure should be enabled. Outbound half closure is triggered once + /// the ``NIOAsyncChannelWriter`` is either finished or deinitialized. + /// - inboundType: The ``NIOAsyncChannel/inboundStream`` message's type. + /// - outboundType: The ``NIOAsyncChannel/outboundWriter`` message's type. + @inlinable + @_spi(AsyncChannel) + public init( + synchronouslyWrapping channel: Channel, + backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, + isOutboundHalfClosureEnabled: Bool = true, + inboundType: Inbound.Type = Inbound.self, + outboundType: Outbound.Type = Outbound.self + ) throws { + channel.eventLoop.preconditionInEventLoop() + self.channel = channel + (self.inboundStream, self.outboundWriter) = try channel._syncAddAsyncHandlers( + backpressureStrategy: backpressureStrategy, + isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled + ) + } + + /// Initializes a new ``NIOAsyncChannel`` wrapping a ``Channel`` where the outbound type is `Never`. + /// + /// This initializer will finish the ``NIOAsyncChannel/outboundWriter`` immediately. + /// + /// - Important: This **must** be called on the channel's event loop otherwise this init will crash. This is necessary because + /// we must install the handlers before any other event in the pipeline happens otherwise we might drop reads. + /// + /// - Parameters: + /// - channel: The ``Channel`` to wrap. + /// - backpressureStrategy: The backpressure strategy of the ``NIOAsyncChannel/inboundStream``. + /// - isOutboundHalfClosureEnabled: If outbound half closure should be enabled. Outbound half closure is triggered once + /// the ``NIOAsyncChannelWriter`` is either finished or deinitialized. + /// - inboundType: The ``NIOAsyncChannel/inboundStream`` message's type. + @inlinable + @_spi(AsyncChannel) + public init( + synchronouslyWrapping channel: Channel, + backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, + isOutboundHalfClosureEnabled: Bool = true, + inboundType: Inbound.Type = Inbound.self + ) throws where Outbound == Never { + channel.eventLoop.preconditionInEventLoop() + self.channel = channel + (self.inboundStream, self.outboundWriter) = try channel._syncAddAsyncHandlers( + backpressureStrategy: backpressureStrategy, + isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled + ) + + self.outboundWriter.finish() + } +} + +extension Channel { + // TODO: We need to remove the public and spi here once we make the AsyncChannel methods public + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + @inlinable + @_spi(AsyncChannel) + public func _syncAddAsyncHandlers( + backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, + isOutboundHalfClosureEnabled: Bool + ) throws -> (NIOAsyncChannelInboundStream, NIOAsyncChannelOutboundWriter) { + self.eventLoop.assertInEventLoop() + + let closeRatchet = CloseRatchet(isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled) + let inboundStream = try NIOAsyncChannelInboundStream( + channel: self, + backpressureStrategy: backpressureStrategy, + closeRatchet: closeRatchet + ) + let writer = try NIOAsyncChannelOutboundWriter( + channel: self, + closeRatchet: closeRatchet + ) + return (inboundStream, writer) + } +} +#endif diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift new file mode 100644 index 00000000..6d25b910 --- /dev/null +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift @@ -0,0 +1,90 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2022-2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +#if swift(>=5.6) +/// The inbound message asynchronous sequence of a ``NIOAsyncChannel``. +/// +/// This is a unicast async sequence that allows a single iterator to be created. +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +@_spi(AsyncChannel) +public struct NIOAsyncChannelInboundStream: Sendable { + @usableFromInline + typealias Producer = NIOThrowingAsyncSequenceProducer.Delegate> + + /// The underlying async sequence. + @usableFromInline let _producer: Producer + + @inlinable + init( + channel: Channel, + backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, + closeRatchet: CloseRatchet + ) throws { + channel.eventLoop.preconditionInEventLoop() + let handler = NIOAsyncChannelInboundStreamChannelHandler( + eventLoop: channel.eventLoop, + closeRatchet: closeRatchet + ) + let strategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark + + if let userProvided = backpressureStrategy { + strategy = userProvided + } else { + // Default strategy. These numbers are fairly arbitrary, but they line up with the default value of + // maxMessagesPerRead. + strategy = .init(lowWatermark: 2, highWatermark: 10) + } + + let sequence = Producer.makeSequence( + backPressureStrategy: strategy, + delegate: NIOAsyncChannelInboundStreamChannelHandler.Delegate(handler: handler) + ) + handler.source = sequence.source + try channel.pipeline.syncOperations.addHandler(handler) + self._producer = sequence.sequence + } +} + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension NIOAsyncChannelInboundStream: AsyncSequence { + @_spi(AsyncChannel) + public typealias Element = Inbound + + @_spi(AsyncChannel) + public struct AsyncIterator: AsyncIteratorProtocol { + @usableFromInline var _iterator: Producer.AsyncIterator + + @inlinable + init(_ iterator: Producer.AsyncIterator) { + self._iterator = iterator + } + + @inlinable @_spi(AsyncChannel) + public mutating func next() async throws -> Element? { + return try await self._iterator.next() + } + } + + @inlinable + @_spi(AsyncChannel) + public func makeAsyncIterator() -> AsyncIterator { + return AsyncIterator(self._producer.makeAsyncIterator()) + } +} + +/// The ``NIOAsyncChannelInboundStream/AsyncIterator`` MUST NOT be shared across `Task`s. With marking this as +/// unavailable we are explicitly declaring this. +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +@available(*, unavailable) +extension NIOAsyncChannelInboundStream.AsyncIterator: Sendable {} +#endif diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStreamChannelHandler.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStreamChannelHandler.swift new file mode 100644 index 00000000..c581b8c1 --- /dev/null +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStreamChannelHandler.swift @@ -0,0 +1,252 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2022-2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +#if swift(>=5.6) +/// A ``ChannelHandler`` that is used to transform the inbound portion of a NIO +/// ``Channel`` into an asynchronous sequence that supports back-pressure. +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +@usableFromInline +internal final class NIOAsyncChannelInboundStreamChannelHandler: ChannelDuplexHandler { + @usableFromInline + enum _ProducingState { + // Not .stopProducing + case keepProducing + + // .stopProducing but not read() + case producingPaused + + // .stopProducing and read() + case producingPausedWithOutstandingRead + } + + @usableFromInline + typealias OutboundIn = Any + + @usableFromInline + typealias OutboundOut = Any + + @usableFromInline + typealias Source = NIOThrowingAsyncSequenceProducer< + InboundIn, + Error, + NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark, + NIOAsyncChannelInboundStreamChannelHandler.Delegate + >.Source + + /// The source of the asynchronous sequence. + @usableFromInline + var source: Source? + + /// The channel handler's context. + @usableFromInline + var context: ChannelHandlerContext? + + /// An array of reads which will be yielded to the source with the next channel read complete. + @usableFromInline + var buffer: [InboundIn] = [] + + /// The current producing state. + @usableFromInline + var producingState: _ProducingState = .keepProducing + + /// The event loop. + @usableFromInline + let eventLoop: EventLoop + + /// The shared `CloseRatchet` between this handler and the writer handler. + @usableFromInline + let closeRatchet: CloseRatchet + + @inlinable + init(eventLoop: EventLoop, closeRatchet: CloseRatchet) { + self.eventLoop = eventLoop + self.closeRatchet = closeRatchet + } + + @inlinable + func handlerAdded(context: ChannelHandlerContext) { + self.context = context + } + + @inlinable + func handlerRemoved(context: ChannelHandlerContext) { + self._finishSource(context: context) + self.context = nil + } + + @inlinable + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + self.buffer.append(self.unwrapInboundIn(data)) + + // We forward on reads here to enable better channel composition. + context.fireChannelRead(data) + } + + @inlinable + func channelReadComplete(context: ChannelHandlerContext) { + self._deliverReads(context: context) + context.fireChannelReadComplete() + } + + @inlinable + func channelInactive(context: ChannelHandlerContext) { + self._finishSource(context: context) + context.fireChannelInactive() + } + + @inlinable + func errorCaught(context: ChannelHandlerContext, error: Error) { + self._finishSource(with: error, context: context) + context.fireErrorCaught(error) + } + + @inlinable + func read(context: ChannelHandlerContext) { + switch self.producingState { + case .keepProducing: + context.read() + case .producingPaused: + self.producingState = .producingPausedWithOutstandingRead + case .producingPausedWithOutstandingRead: + break + } + } + + @inlinable + func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { + switch event { + case ChannelEvent.inputClosed: + self._finishSource(context: context) + default: + break + } + + context.fireUserInboundEventTriggered(event) + } + + @inlinable + func _finishSource(with error: Error? = nil, context: ChannelHandlerContext) { + guard let source = self.source else { + return + } + + // We need to deliver the reads first to buffer them in the source. + self._deliverReads(context: context) + + if let error = error { + source.finish(error) + } else { + source.finish() + } + + // We can nil the source here, as we're no longer going to use it. + self.source = nil + } + + @inlinable + func _deliverReads(context: ChannelHandlerContext) { + if self.buffer.isEmpty { + return + } + + guard let source = self.source else { + self.buffer.removeAll() + return + } + + let result = source.yield(contentsOf: self.buffer) + switch result { + case .produceMore, .dropped: + break + case .stopProducing: + if self.producingState != .producingPausedWithOutstandingRead { + self.producingState = .producingPaused + } + } + self.buffer.removeAll(keepingCapacity: true) + } +} + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension NIOAsyncChannelInboundStreamChannelHandler { + @inlinable + func _didTerminate() { + self.eventLoop.preconditionInEventLoop() + self.source = nil + + // Wedges the read open forever, we'll never read again. + self.producingState = .producingPausedWithOutstandingRead + + switch self.closeRatchet.closeRead() { + case .nothing: + break + + case .close: + self.context?.close(promise: nil) + } + } + + @inlinable + func _produceMore() { + self.eventLoop.preconditionInEventLoop() + + switch self.producingState { + case .producingPaused: + self.producingState = .keepProducing + + case .producingPausedWithOutstandingRead: + self.producingState = .keepProducing + self.context?.read() + + case .keepProducing: + break + } + } +} + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension NIOAsyncChannelInboundStreamChannelHandler { + @usableFromInline + struct Delegate: @unchecked Sendable, NIOAsyncSequenceProducerDelegate { + @usableFromInline + let eventLoop: EventLoop + + @usableFromInline + let handler: NIOAsyncChannelInboundStreamChannelHandler + + @inlinable + init(handler: NIOAsyncChannelInboundStreamChannelHandler) { + self.eventLoop = handler.eventLoop + self.handler = handler + } + + @inlinable + func didTerminate() { + self.eventLoop.execute { + self.handler._didTerminate() + } + } + + @inlinable + func produceMore() { + self.eventLoop.execute { + self.handler._produceMore() + } + } + } +} + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +@available(*, unavailable) +extension NIOAsyncChannelInboundStreamChannelHandler: Sendable {} +#endif diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift new file mode 100644 index 00000000..0831687b --- /dev/null +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift @@ -0,0 +1,93 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2022-2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if swift(>=5.6) +/// A ``NIOAsyncChannelWriter`` is used to write and flush new outbound messages in a channel. +/// +/// The writer acts as a bridge between the Concurrency and NIO world. It allows to write and flush messages into the +/// underlying ``Channel``. Furthermore, it respects back-pressure of the channel by suspending the calls to write until +/// the channel becomes writable again. +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +@_spi(AsyncChannel) +public struct NIOAsyncChannelOutboundWriter: Sendable { + @usableFromInline + typealias _Writer = NIOAsyncChannelOutboundWriterHandler.Writer + + @usableFromInline + let _outboundWriter: _Writer + + @inlinable + init( + channel: Channel, + closeRatchet: CloseRatchet + ) throws { + let handler = NIOAsyncChannelOutboundWriterHandler( + eventLoop: channel.eventLoop, + closeRatchet: closeRatchet + ) + let writer = _Writer.makeWriter( + elementType: OutboundOut.self, + isWritable: true, + delegate: .init(handler: handler) + ) + handler.sink = writer.sink + + try channel.pipeline.syncOperations.addHandler(handler) + + self._outboundWriter = writer.writer + } + + @inlinable + init(outboundWriter: NIOAsyncChannelOutboundWriterHandler.Writer) { + self._outboundWriter = outboundWriter + } + + /// Send a write into the ``ChannelPipeline`` and flush it right away. + /// + /// This method suspends if the underlying channel is not writable and will resume once the it becomes writable again. + @inlinable + @_spi(AsyncChannel) + public func write(_ data: OutboundOut) async throws { + try await self._outboundWriter.yield(data) + } + + /// Send a sequence of writes into the ``ChannelPipeline`` and flush them right away. + /// + /// This method suspends if the underlying channel is not writable and will resume once the it becomes writable again. + @inlinable + @_spi(AsyncChannel) + public func write(contentsOf sequence: Writes) async throws where Writes.Element == OutboundOut { + try await self._outboundWriter.yield(contentsOf: sequence) + } + + /// Send a sequence of writes into the ``ChannelPipeline`` and flush them right away. + /// + /// This method suspends if the underlying channel is not writable and will resume once the it becomes writable again. + @inlinable + @_spi(AsyncChannel) + public func write(contentsOf sequence: Writes) async throws where Writes.Element == OutboundOut { + for try await data in sequence { + try await self._outboundWriter.yield(data) + } + } + + /// Finishes the writer. + /// + /// This might trigger a half closure if the ``NIOAsyncChannel`` was configured to support it. + @_spi(AsyncChannel) + public func finish() { + self._outboundWriter.finish() + } +} +#endif diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift new file mode 100644 index 00000000..bbbf8dec --- /dev/null +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift @@ -0,0 +1,175 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2022-2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +#if swift(>=5.6) +import DequeModule + +/// A ``ChannelHandler`` that is used to write the outbound portion of a NIO +/// ``Channel`` from Swift Concurrency with back-pressure support. +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +@usableFromInline +internal final class NIOAsyncChannelOutboundWriterHandler: ChannelDuplexHandler { + @usableFromInline typealias InboundIn = Any + @usableFromInline typealias InboundOut = Any + @usableFromInline typealias OutboundIn = Any + @usableFromInline typealias OutboundOut = OutboundOut + + @usableFromInline + typealias Writer = NIOAsyncWriter< + OutboundOut, + NIOAsyncChannelOutboundWriterHandler.Delegate + > + + @usableFromInline + typealias Sink = Writer.Sink + + /// The sink of the ``NIOAsyncWriter``. + @usableFromInline + var sink: Sink? + + /// The channel handler context. + @usableFromInline + var context: ChannelHandlerContext? + + /// The event loop. + @usableFromInline + let eventLoop: EventLoop + + /// The shared `CloseRatchet` between this handler and the inbound stream handler. + @usableFromInline + let closeRatchet: CloseRatchet + + @inlinable + init( + eventLoop: EventLoop, + closeRatchet: CloseRatchet + ) { + self.eventLoop = eventLoop + self.closeRatchet = closeRatchet + } + + @inlinable + func _didYield(sequence: Deque) { + // This is always called from an async context, so we must loop-hop. + // Because we always loop-hop, we're always at the top of a stack frame. As this + // is the only source of writes for us, and as this channel handler doesn't implement + // func write(), we cannot possibly re-entrantly write. That means we can skip many of the + // awkward re-entrancy protections NIO usually requires, and can safely just do an iterative + // write. + self.eventLoop.preconditionInEventLoop() + guard let context = self.context else { + // Already removed from the channel by now, we can stop. + return + } + + self._doOutboundWrites(context: context, writes: sequence) + } + + @inlinable + func _didTerminate(error: Error?) { + self.eventLoop.preconditionInEventLoop() + + switch self.closeRatchet.closeWrite() { + case .nothing: + break + + case .closeOutput: + self.context?.close(mode: .output, promise: nil) + + case .close: + self.context?.close(promise: nil) + } + + self.sink = nil + } + + @inlinable + func _doOutboundWrites(context: ChannelHandlerContext, writes: Deque) { + for write in writes { + context.write(self.wrapOutboundOut(write), promise: nil) + } + + context.flush() + } + + @inlinable + func handlerAdded(context: ChannelHandlerContext) { + self.context = context + } + + @inlinable + func handlerRemoved(context: ChannelHandlerContext) { + self.context = nil + self.sink = nil + } + + @inlinable + func errorCaught(context: ChannelHandlerContext, error: Error) { + self.sink?.finish(error: error) + context.fireErrorCaught(error) + } + + @inlinable + func channelInactive(context: ChannelHandlerContext) { + self.sink?.finish() + context.fireChannelInactive() + } + + @inlinable + func channelWritabilityChanged(context: ChannelHandlerContext) { + self.sink?.setWritability(to: context.channel.isWritable) + context.fireChannelWritabilityChanged() + } +} + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension NIOAsyncChannelOutboundWriterHandler { + @usableFromInline + struct Delegate: @unchecked Sendable, NIOAsyncWriterSinkDelegate { + @usableFromInline + typealias Element = OutboundOut + + @usableFromInline + let eventLoop: EventLoop + + @usableFromInline + let handler: NIOAsyncChannelOutboundWriterHandler + + @inlinable + init(handler: NIOAsyncChannelOutboundWriterHandler) { + self.eventLoop = handler.eventLoop + self.handler = handler + } + + @inlinable + func didYield(contentsOf sequence: Deque) { + // This always called from an async context, so we must loop-hop. + self.eventLoop.execute { + self.handler._didYield(sequence: sequence) + } + } + + @inlinable + func didTerminate(error: Error?) { + // This always called from an async context, so we must loop-hop. + self.eventLoop.execute { + self.handler._didTerminate(error: error) + } + } + } +} + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +@available(*, unavailable) +extension NIOAsyncChannelOutboundWriterHandler: Sendable {} +#endif diff --git a/Sources/NIOCore/AsyncChannel/CloseRatchet.swift b/Sources/NIOCore/AsyncChannel/CloseRatchet.swift new file mode 100644 index 00000000..11845f28 --- /dev/null +++ b/Sources/NIOCore/AsyncChannel/CloseRatchet.swift @@ -0,0 +1,94 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2022-2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +#if swift(>=5.6) +/// A helper type that lets ``NIOAsyncChannelAdapterHandler`` and ``NIOAsyncChannelWriterHandler`` collude +/// to ensure that the ``Channel`` they share is closed appropriately. +/// +/// The strategy of this type is that it keeps track of which side has closed, so that the handlers can work out +/// which of them was "last", in order to arrange closure. +@usableFromInline +final class CloseRatchet { + @usableFromInline + enum State { + case notClosed(isOutboundHalfClosureEnabled: Bool) + case readClosed + case writeClosed + case bothClosed + + @inlinable + mutating func closeRead() -> CloseReadAction { + switch self { + case .notClosed: + self = .readClosed + return .nothing + case .writeClosed: + self = .bothClosed + return .close + case .readClosed, .bothClosed: + preconditionFailure("Duplicate read closure") + } + } + + @inlinable + mutating func closeWrite() -> CloseWriteAction { + switch self { + case .notClosed(let isOutboundHalfClosureEnabled): + self = .writeClosed + + if isOutboundHalfClosureEnabled { + return .closeOutput + } else { + return .nothing + } + case .readClosed: + self = .bothClosed + return .close + case .writeClosed, .bothClosed: + preconditionFailure("Duplicate write closure") + } + } + } + + @usableFromInline + var _state: State + + @inlinable + init(isOutboundHalfClosureEnabled: Bool) { + self._state = .notClosed(isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled) + } + + @usableFromInline + enum CloseReadAction { + case nothing + case close + } + + @inlinable + func closeRead() -> CloseReadAction { + return self._state.closeRead() + } + + @usableFromInline + enum CloseWriteAction { + case nothing + case close + case closeOutput + } + + @inlinable + func closeWrite() -> CloseWriteAction { + return self._state.closeWrite() + } +} +#endif diff --git a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift new file mode 100644 index 00000000..8a06c296 --- /dev/null +++ b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift @@ -0,0 +1,561 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +import Atomics +import NIOConcurrencyHelpers +@_spi(AsyncChannel) @testable import NIOCore +import NIOEmbedded +import XCTest + +final class AsyncChannelTests: XCTestCase { + func testAsyncChannelBasicFunctionality() { + #if swift(>=5.6) + guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } + XCTAsyncTest(timeout: 5) { + let channel = NIOAsyncTestingChannel() + let wrapped = try await channel.testingEventLoop.executeInContext { + try NIOAsyncChannel(synchronouslyWrapping: channel, inboundType: String.self, outboundType: Never.self) + } + + var iterator = wrapped.inboundStream.makeAsyncIterator() + try await channel.writeInbound("hello") + let firstRead = try await iterator.next() + XCTAssertEqual(firstRead, "hello") + + try await channel.writeInbound("world") + let secondRead = try await iterator.next() + XCTAssertEqual(secondRead, "world") + + try await channel.testingEventLoop.executeInContext { + channel.pipeline.fireUserInboundEventTriggered(ChannelEvent.inputClosed) + } + + let thirdRead = try await iterator.next() + XCTAssertNil(thirdRead) + + try await channel.close() + } + #endif + } + + func testAsyncChannelBasicWrites() { + #if swift(>=5.6) + guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } + XCTAsyncTest(timeout: 5) { + let channel = NIOAsyncTestingChannel() + let wrapped = try await channel.testingEventLoop.executeInContext { + try NIOAsyncChannel(synchronouslyWrapping: channel, inboundType: Never.self, outboundType: String.self) + } + + try await wrapped.outboundWriter.write("hello") + try await wrapped.outboundWriter.write("world") + + let firstRead = try await channel.waitForOutboundWrite(as: String.self) + let secondRead = try await channel.waitForOutboundWrite(as: String.self) + + XCTAssertEqual(firstRead, "hello") + XCTAssertEqual(secondRead, "world") + + try await channel.close() + } + #endif + } + + func testDroppingTheWriterClosesTheWriteSideOfTheChannel() { + #if swift(>=5.6) + guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } + XCTAsyncTest(timeout: 5) { + let channel = NIOAsyncTestingChannel() + let closeRecorder = CloseRecorder() + try await channel.pipeline.addHandler(closeRecorder) + + let inboundReader: NIOAsyncChannelInboundStream + + do { + let wrapped = try await channel.testingEventLoop.executeInContext { + try NIOAsyncChannel(synchronouslyWrapping: channel, inboundType: Never.self, outboundType: Never.self) + } + inboundReader = wrapped.inboundStream + + try await channel.testingEventLoop.executeInContext { + XCTAssertEqual(0, closeRecorder.outboundCloses) + } + } + + try await channel.testingEventLoop.executeInContext { + XCTAssertEqual(1, closeRecorder.outboundCloses) + } + + // Just use this to keep the inbound reader alive. + withExtendedLifetime(inboundReader) {} + channel.close(promise: nil) + } + #endif + } + + func testDroppingTheWriterDoesntCloseTheWriteSideOfTheChannelIfHalfClosureIsDisabled() { + #if swift(>=5.6) + guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } + XCTAsyncTest(timeout: 5) { + let channel = NIOAsyncTestingChannel() + let closeRecorder = CloseRecorder() + try await channel.pipeline.addHandler(closeRecorder) + + let inboundReader: NIOAsyncChannelInboundStream + + do { + let wrapped = try await channel.testingEventLoop.executeInContext { + try NIOAsyncChannel(synchronouslyWrapping: channel, isOutboundHalfClosureEnabled: false, inboundType: Never.self, outboundType: Never.self) + } + inboundReader = wrapped.inboundStream + + try await channel.testingEventLoop.executeInContext { + XCTAssertEqual(0, closeRecorder.outboundCloses) + } + } + + try await channel.testingEventLoop.executeInContext { + XCTAssertEqual(0, closeRecorder.outboundCloses) + } + + // Just use this to keep the inbound reader alive. + withExtendedLifetime(inboundReader) {} + channel.close(promise: nil) + } + #endif + } + + func testDroppingTheWriterFirstLeadsToChannelClosureWhenReaderIsAlsoDropped() { + #if swift(>=5.6) + guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } + XCTAsyncTest(timeout: 5) { + let channel = NIOAsyncTestingChannel() + let closeRecorder = CloseRecorder() + try await channel.pipeline.addHandler(CloseSuppressor()) + try await channel.pipeline.addHandler(closeRecorder) + + do { + let inboundReader: NIOAsyncChannelInboundStream + + do { + let wrapped = try await channel.testingEventLoop.executeInContext { + try NIOAsyncChannel(synchronouslyWrapping: channel, inboundType: Never.self, outboundType: Never.self) + } + inboundReader = wrapped.inboundStream + + try await channel.testingEventLoop.executeInContext { + XCTAssertEqual(0, closeRecorder.allCloses) + } + } + + // First we see half-closure. + try await channel.testingEventLoop.executeInContext { + XCTAssertEqual(1, closeRecorder.allCloses) + } + + // Just use this to keep the inbound reader alive. + withExtendedLifetime(inboundReader) {} + } + + // Now the inbound reader is dead, we see full closure. + try await channel.testingEventLoop.executeInContext { + XCTAssertEqual(2, closeRecorder.allCloses) + } + + try await channel.closeIgnoringSuppression() + } + #endif + } + + func testDroppingEverythingClosesTheChannel() { + #if swift(>=5.6) + guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } + XCTAsyncTest(timeout: 5) { + let channel = NIOAsyncTestingChannel() + let closeRecorder = CloseRecorder() + try await channel.pipeline.addHandler(CloseSuppressor()) + try await channel.pipeline.addHandler(closeRecorder) + + do { + let wrapped = try await channel.testingEventLoop.executeInContext { + try NIOAsyncChannel(synchronouslyWrapping: channel, isOutboundHalfClosureEnabled: false, inboundType: Never.self, outboundType: Never.self) + } + + try await channel.testingEventLoop.executeInContext { + XCTAssertEqual(0, closeRecorder.allCloses) + } + + // Just use this to keep the wrapper alive until here. + withExtendedLifetime(wrapped) {} + } + + // Now that everything is dead, we see full closure. + try await channel.testingEventLoop.executeInContext { + XCTAssertEqual(1, closeRecorder.allCloses) + } + + try await channel.closeIgnoringSuppression() + } + #endif + } + + func testReadsArePropagated() { + #if swift(>=5.6) + guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } + XCTAsyncTest(timeout: 5) { + let channel = NIOAsyncTestingChannel() + let wrapped = try await channel.testingEventLoop.executeInContext { + try NIOAsyncChannel(synchronouslyWrapping: channel, inboundType: String.self, outboundType: Never.self) + } + + try await channel.writeInbound("hello") + let propagated = try await channel.readInbound(as: String.self) + XCTAssertEqual(propagated, "hello") + + try await channel.close().get() + + let reads = try await Array(wrapped.inboundStream) + XCTAssertEqual(reads, ["hello"]) + } + #endif + } + + func testErrorsArePropagatedButAfterReads() { + #if swift(>=5.6) + guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } + XCTAsyncTest(timeout: 5) { + let channel = NIOAsyncTestingChannel() + let wrapped = try await channel.testingEventLoop.executeInContext { + try NIOAsyncChannel(synchronouslyWrapping: channel, inboundType: String.self, outboundType: Never.self) + } + + try await channel.writeInbound("hello") + try await channel.testingEventLoop.executeInContext { + channel.pipeline.fireErrorCaught(TestError.bang) + } + + var iterator = wrapped.inboundStream.makeAsyncIterator() + let first = try await iterator.next() + XCTAssertEqual(first, "hello") + + try await XCTAssertThrowsError(await iterator.next()) { error in + XCTAssertEqual(error as? TestError, .bang) + } + } + #endif + } + + func testErrorsArePropagatedToWriters() { + #if swift(>=5.6) + guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } + XCTAsyncTest(timeout: 5) { + let channel = NIOAsyncTestingChannel() + let wrapped = try await channel.testingEventLoop.executeInContext { + try NIOAsyncChannel(synchronouslyWrapping: channel, inboundType: Never.self, outboundType: String.self) + } + + try await channel.testingEventLoop.executeInContext { + channel.pipeline.fireErrorCaught(TestError.bang) + } + + try await XCTAssertThrowsError(await wrapped.outboundWriter.write("hello")) { error in + XCTAssertEqual(error as? TestError, .bang) + } + } + #endif + } + + func testChannelBecomingNonWritableDelaysWriters() { + #if swift(>=5.6) + guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } + XCTAsyncTest(timeout: 5) { + let channel = NIOAsyncTestingChannel() + let wrapped = try await channel.testingEventLoop.executeInContext { + try NIOAsyncChannel(synchronouslyWrapping: channel, inboundType: Never.self, outboundType: String.self) + } + + try await channel.testingEventLoop.executeInContext { + channel.isWritable = false + channel.pipeline.fireChannelWritabilityChanged() + } + + let lock = NIOLockedValueBox(false) + + await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await wrapped.outboundWriter.write("hello") + lock.withLockedValue { + XCTAssertTrue($0) + } + } + + group.addTask { + // 10ms sleep before we wake the thing up + try await Task.sleep(nanoseconds: 10_000_000) + + try await channel.testingEventLoop.executeInContext { + channel.isWritable = true + lock.withLockedValue { $0 = true } + channel.pipeline.fireChannelWritabilityChanged() + } + } + } + + try await channel.close().get() + } + #endif + } + + func testBufferDropsReadsIfTheReaderIsGone() { + #if swift(>=5.6) + guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } + XCTAsyncTest(timeout: 5) { + let channel = NIOAsyncTestingChannel() + try await channel.pipeline.addHandler(CloseSuppressor()).get() + do { + // Create the NIOAsyncChannel, then drop it. The handler will still be in the pipeline. + _ = try await channel.testingEventLoop.executeInContext { + _ = try NIOAsyncChannel(synchronouslyWrapping: channel, inboundType: Sentinel.self, outboundType: Never.self) + } + } + + weak var sentinel: Sentinel? + do { + let strongSentinel: Sentinel? = Sentinel() + sentinel = strongSentinel! + try await XCTAsyncAssertNotNil(await channel.pipeline.handler(type: NIOAsyncChannelInboundStreamChannelHandler.self).get()) + try await channel.writeInbound(strongSentinel!) + _ = try await channel.readInbound(as: Sentinel.self) + } + + XCTAssertNil(sentinel) + + try await channel.closeIgnoringSuppression() + } + #endif + } + + func testManagingBackpressure() { + #if swift(>=5.6) + guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } + XCTAsyncTest(timeout: 5) { + let channel = NIOAsyncTestingChannel() + let readCounter = ReadCounter() + try await channel.pipeline.addHandler(readCounter) + let wrapped = try await channel.testingEventLoop.executeInContext { + try NIOAsyncChannel(synchronouslyWrapping: channel, backpressureStrategy: .init(lowWatermark: 2, highWatermark: 4), inboundType: Void.self, outboundType: Never.self) + } + + // Attempt to read. This should succeed an arbitrary number of times. + XCTAssertEqual(readCounter.readCount, 0) + try await channel.testingEventLoop.executeInContext { + channel.pipeline.read() + channel.pipeline.read() + channel.pipeline.read() + } + XCTAssertEqual(readCounter.readCount, 3) + + // Push 3 elements into the buffer. Reads continue to work. + try await channel.testingEventLoop.executeInContext { + channel.pipeline.fireChannelRead(NIOAny(())) + channel.pipeline.fireChannelRead(NIOAny(())) + channel.pipeline.fireChannelRead(NIOAny(())) + channel.pipeline.fireChannelReadComplete() + + channel.pipeline.read() + channel.pipeline.read() + channel.pipeline.read() + } + XCTAssertEqual(readCounter.readCount, 6) + + // Add one more element into the buffer. This should flip our backpressure mode, and the reads should now be delayed. + try await channel.testingEventLoop.executeInContext { + channel.pipeline.fireChannelRead(NIOAny(())) + channel.pipeline.fireChannelReadComplete() + + channel.pipeline.read() + channel.pipeline.read() + channel.pipeline.read() + } + XCTAssertEqual(readCounter.readCount, 6) + + // More elements don't help. + try await channel.testingEventLoop.executeInContext { + channel.pipeline.fireChannelRead(NIOAny(())) + channel.pipeline.fireChannelReadComplete() + + channel.pipeline.read() + channel.pipeline.read() + channel.pipeline.read() + } + XCTAssertEqual(readCounter.readCount, 6) + + // Now consume three elements from the pipeline. This should not unbuffer the read, as 3 elements remain. + var reader = wrapped.inboundStream.makeAsyncIterator() + for _ in 0..<3 { + try await XCTAsyncAssertNotNil(await reader.next()) + } + await channel.testingEventLoop.run() + XCTAssertEqual(readCounter.readCount, 6) + + // Removing the next element should trigger an automatic read. + try await XCTAsyncAssertNotNil(await reader.next()) + await channel.testingEventLoop.run() + XCTAssertEqual(readCounter.readCount, 7) + + // Reads now work again, even if more data arrives. + try await channel.testingEventLoop.executeInContext { + channel.pipeline.read() + channel.pipeline.read() + channel.pipeline.read() + + channel.pipeline.fireChannelRead(NIOAny(())) + channel.pipeline.fireChannelReadComplete() + + channel.pipeline.read() + channel.pipeline.read() + channel.pipeline.read() + } + XCTAssertEqual(readCounter.readCount, 13) + + // The next reads arriving pushes us past the limit again. + // This time we won't read. + try await channel.testingEventLoop.executeInContext { + channel.pipeline.fireChannelRead(NIOAny(())) + channel.pipeline.fireChannelRead(NIOAny(())) + channel.pipeline.fireChannelReadComplete() + } + XCTAssertEqual(readCounter.readCount, 13) + + // This time we'll consume 4 more elements, and we won't find a read at all. + for _ in 0..<4 { + try await XCTAsyncAssertNotNil(await reader.next()) + } + await channel.testingEventLoop.run() + XCTAssertEqual(readCounter.readCount, 13) + + // But the next reads work fine. + try await channel.testingEventLoop.executeInContext { + channel.pipeline.read() + channel.pipeline.read() + channel.pipeline.read() + } + XCTAssertEqual(readCounter.readCount, 16) + } + #endif + } + + func testCanWrapAChannelSynchronously() throws { + #if swift(>=5.6) + guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } + XCTAsyncTest(timeout: 5) { + let channel = NIOAsyncTestingChannel() + let wrapped = try await channel.testingEventLoop.executeInContext { + try NIOAsyncChannel(synchronouslyWrapping: channel, inboundType: String.self, outboundType: String.self) + } + + var iterator = wrapped.inboundStream.makeAsyncIterator() + try await channel.writeInbound("hello") + let firstRead = try await iterator.next() + XCTAssertEqual(firstRead, "hello") + + try await wrapped.outboundWriter.write("world") + let write = try await channel.waitForOutboundWrite(as: String.self) + XCTAssertEqual(write, "world") + + try await channel.testingEventLoop.executeInContext { + channel.pipeline.fireUserInboundEventTriggered(ChannelEvent.inputClosed) + } + + let secondRead = try await iterator.next() + XCTAssertNil(secondRead) + + try await channel.close() + } + #endif + } +} + +// This is unchecked Sendable since we only call this in the testing eventloop +private final class CloseRecorder: ChannelOutboundHandler, @unchecked Sendable { + typealias OutboundIn = Any + typealias outbound = Any + + var outboundCloses = 0 + + var allCloses = 0 + + init() {} + + func close(context: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise?) { + self.allCloses += 1 + + if case .output = mode { + self.outboundCloses += 1 + } + + context.close(mode: mode, promise: promise) + } +} + +private final class CloseSuppressor: ChannelOutboundHandler, RemovableChannelHandler { + typealias OutboundIn = Any + typealias outbound = Any + + func close(context: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise?) { + // We drop the close here. + promise?.fail(TestError.bang) + } +} + +extension NIOAsyncTestingChannel { + fileprivate func closeIgnoringSuppression() async throws { + try await self.pipeline.context(handlerType: CloseSuppressor.self).flatMap { + self.pipeline.removeHandler(context: $0) + }.flatMap { + self.close() + }.get() + } +} + +private final class ReadCounter: ChannelOutboundHandler, @unchecked Sendable { + typealias OutboundIn = Any + typealias outbound = Any + + private let _readCount = ManagedAtomic(0) + + var readCount: Int { + self._readCount.load(ordering: .acquiring) + } + + func read(context: ChannelHandlerContext) { + self._readCount.wrappingIncrement(ordering: .releasing) + context.read() + } +} + +private enum TestError: Error { + case bang +} + +extension Array { + fileprivate init(_ sequence: AS) async throws where AS.Element == Self.Element { + self = [] + + for try await nextElement in sequence { + self.append(nextElement) + } + } +} + +private final class Sentinel: Sendable {} diff --git a/Tests/NIOCoreTests/XCTest+AsyncAwait.swift b/Tests/NIOCoreTests/XCTest+AsyncAwait.swift index 4f790411..222952fa 100644 --- a/Tests/NIOCoreTests/XCTest+AsyncAwait.swift +++ b/Tests/NIOCoreTests/XCTest+AsyncAwait.swift @@ -116,3 +116,22 @@ internal func XCTAssertNoThrowWithResult( return nil } +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +internal func XCTAsyncAssertNotNil( + _ expression: @autoclosure () async throws -> Any?, + file: StaticString = #filePath, + line: UInt = #line +) async rethrows { + let result = try await expression() + XCTAssertNotNil(result, file: file, line: line) +} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +internal func XCTAsyncAssertNil( + _ expression: @autoclosure () async throws -> Any?, + file: StaticString = #filePath, + line: UInt = #line +) async rethrows { + let result = try await expression() + XCTAssertNil(result, file: file, line: line) +}