Add NIOAsyncChannel based connect methods to ClientBootstrap (#2437)

* Add NIOAsyncChannel based connect methods to ClientBootstrap

# Motivation
In my previous PR, I added new `bind` methods to `ServerBootstrap` that vend `NIOAsyncChannel` or support an async protocol negotiation. This PR focuses on adding new `connect` methods to `ClientBootstrap` which offer the same functionality.

# Modification
This PR adds new `connect` methods that either vend a `NIOAsyncChannel` or an asynchronous protocol negotiation result. To make this work I had to change the `HappyEyeballs` resolver so that it can return a generic value on resolving. Lastly, I adapted the bootstrap tests to use the new `ClientBootstrap` capabilities which now demonstrate a client/server protocol negotiation dance.

# Result
We can now bootstrap TCP clients with `NIOAsyncChannel`s

* Reduce code duplication

* Create a new set of APIs to tunnel an arbitrary Sendable payload through the inits

* Pass EL to closure

* Fix documentation
This commit is contained in:
Franz Busch 2023-06-06 11:36:53 +01:00 committed by GitHub
parent 6213ba7a06
commit 46c0538253
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 895 additions and 171 deletions

View File

@ -142,7 +142,25 @@ Afterwards, we handle each inbound connection in separate child tasks and echo t
Normal task groups will result in a memory leak since they do not reap their child tasks automatically.
#### ClientBootstrap
> Important: Support for `ClientBootstrap` with `NIOAsyncChannel` hasn't landed yet.
The client bootstrap is used to create a new TCP based client. Let's take a look at the new
`NIOAsyncChannel` based connect methods.
```swift
let clientChannel = try await ClientBootstrap(group: eventLoopGroup)
.connect(
host: "127.0.0.1",
port: 0,
channelInboundType: ByteBuffer.self,
channelOutboundType: ByteBuffer.self
)
clientChannel.outboundWriter.write(ByteBuffer(string: "hello"))
for try await inboundData in clientChannel.inboundStream {
print(inboundData)
}
```
#### DatagramBootstrap
> Important: Support for `DatagramBootstrap` with `NIOAsyncChannel` hasn't landed yet.
@ -158,7 +176,7 @@ To solve the problem of protocol negotiation, NIO introduced a new ``ChannelHand
that is completed once the handler is finished with protocol negotiation. In the successful case,
the future can either indicate that protocol negotiation is fully done by returning `NIOProtocolNegotiationResult/finished(_:)` or
indicate that further protocol negotiation needs to be done by returning `NIOProtocolNegotiationResult/deferredResult(_:)`.
Additionally, the various bootstraps provide another set of `bind()` methods that handle protocol negotiation.
Additionally, the various bootstraps provide another set of `bind()`/`connect()` methods that handle protocol negotiation.
Let's walk through how to setup a `ServerBootstrap` with protocol negotiation.
First, we have to define our negotiation result. For this example, we are negotiating between a

View File

@ -892,7 +892,7 @@ extension ServerBootstrap {
.flatMap { handler -> EventLoopFuture<NIOProtocolNegotiationResult<Handler.NegotiationResult>> in
handler.protocolNegotiationResult
}.flatMap { result in
ServerBootstrap.waitForFinalResult(result, eventLoop: eventLoop)
result.resolve(on: eventLoop)
}.flatMapErrorThrowing { error in
channel.pipeline.fireErrorCaught(error)
channel.close(promise: nil)
@ -913,22 +913,6 @@ extension ServerBootstrap {
$0
}.get()
}
/// This method recursively waits for the final result of protocol negotiation
static func waitForFinalResult<NegotiationResult>(
_ result: NIOProtocolNegotiationResult<NegotiationResult>,
eventLoop: EventLoop
) -> EventLoopFuture<NegotiationResult> {
switch result {
case .finished(let negotiationResult):
return eventLoop.makeSucceededFuture(negotiationResult)
case .deferredResult(let future):
return future.flatMap { result in
return waitForFinalResult(result, eventLoop: eventLoop)
}
}
}
}
@available(*, unavailable)
@ -1352,6 +1336,522 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol {
}
}
// MARK: Async connect methods with NIOAsyncChannel
extension ClientBootstrap {
/// Specify the `host` and `port` to connect to for the TCP `Channel` that will be established.
///
/// - Parameters:
/// - host: The host to connect to.
/// - port: The port to connect to.
/// - backpressureStrategy: The back pressure strategy used by the channel.
/// - isOutboundHalfClosureEnabled: Indicates if half closure is enabled on the channel. If half closure is enabled
/// then finishing the `NIOAsyncChannelWriter` will lead to half closure.
/// - inboundType: The channel's inbound type.
/// - outboundType: The channel's outbound type.
/// - Returns: A `NIOAsyncChannel` for the established connection.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@_spi(AsyncChannel)
public func connect<Inbound: Sendable, Outbound: Sendable>(
host: String,
port: Int,
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
isOutboundHalfClosureEnabled: Bool = false,
inboundType: Inbound.Type = Inbound.self,
outboundType: Outbound.Type = Outbound.self
) async throws -> NIOAsyncChannel<Inbound, Outbound> {
return try await self.connect(
host: host,
port: port
) { channel in
channel.eventLoop.makeCompletedFuture {
return try NIOAsyncChannel(
synchronouslyWrapping: channel,
backpressureStrategy: backpressureStrategy,
isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled,
inboundType: inboundType,
outboundType: outboundType
)
}
}
}
/// Specify the `address` to connect to for the TCP `Channel` that will be established.
///
/// - Parameters:
/// - address: The address to connect to.
/// - backpressureStrategy: The back pressure strategy used by the channel.
/// - isOutboundHalfClosureEnabled: Indicates if half closure is enabled on the channel. If half closure is enabled
/// then finishing the `NIOAsyncChannelWriter` will lead to half closure.
/// - inboundType: The channel's inbound type.
/// - outboundType: The channel's outbound type.
/// - Returns: A `NIOAsyncChannel` for the established connection.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@_spi(AsyncChannel)
public func connect<Inbound: Sendable, Outbound: Sendable>(
to address: SocketAddress,
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
isOutboundHalfClosureEnabled: Bool = false,
inboundType: Inbound.Type = Inbound.self,
outboundType: Outbound.Type = Outbound.self
) async throws -> NIOAsyncChannel<Inbound, Outbound> {
return try await self.connect(
to: address
) { channel in
channel.eventLoop.makeCompletedFuture {
return try NIOAsyncChannel(
synchronouslyWrapping: channel,
backpressureStrategy: backpressureStrategy,
isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled,
inboundType: inboundType,
outboundType: outboundType
)
}
}
}
/// Specify the `unixDomainSocket` path to connect to for the UDS `Channel` that will be established.
///
/// - Parameters:
/// - unixDomainSocketPath: The _Unix domain socket_ path to connect to.
/// - backpressureStrategy: The back pressure strategy used by the channel.
/// - isOutboundHalfClosureEnabled: Indicates if half closure is enabled on the channel. If half closure is enabled
/// then finishing the `NIOAsyncChannelWriter` will lead to half closure.
/// - inboundType: The channel's inbound type.
/// - outboundType: The channel's outbound type.
/// - Returns: A `NIOAsyncChannel` for the established connection.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@_spi(AsyncChannel)
public func connect<Inbound: Sendable, Outbound: Sendable>(
unixDomainSocketPath: String,
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
isOutboundHalfClosureEnabled: Bool = false,
inboundType: Inbound.Type = Inbound.self,
outboundType: Outbound.Type = Outbound.self
) async throws -> NIOAsyncChannel<Inbound, Outbound> {
return try await self.connect(
unixDomainSocketPath: unixDomainSocketPath
) { channel in
channel.eventLoop.makeCompletedFuture {
return try NIOAsyncChannel(
synchronouslyWrapping: channel,
backpressureStrategy: backpressureStrategy,
isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled,
inboundType: inboundType,
outboundType: outboundType
)
}
}
}
/// Use the existing connected socket file descriptor.
///
/// - Parameters:
/// - descriptor: The _Unix file descriptor_ representing the connected stream socket.
/// - backpressureStrategy: The back pressure strategy used by the channel.
/// - isOutboundHalfClosureEnabled: Indicates if half closure is enabled on the channel. If half closure is enabled
/// then finishing the `NIOAsyncChannelWriter` will lead to half closure.
/// - inboundType: The channel's inbound type.
/// - outboundType: The channel's outbound type.
/// - Returns: A `NIOAsyncChannel` for the established connection.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@_spi(AsyncChannel)
public func connect<Inbound: Sendable, Outbound: Sendable>(
_ socket: NIOBSDSocket.Handle,
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
isOutboundHalfClosureEnabled: Bool = false,
inboundType: Inbound.Type = Inbound.self,
outboundType: Outbound.Type = Outbound.self
) async throws -> NIOAsyncChannel<Inbound, Outbound> {
return try await self.withConnectedSocket(
socket
) { channel in
channel.eventLoop.makeCompletedFuture {
return try NIOAsyncChannel(
synchronouslyWrapping: channel,
backpressureStrategy: backpressureStrategy,
isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled,
inboundType: inboundType,
outboundType: outboundType
)
}
}
}
}
// MARK: Async connect methods with protocol negotiation
extension ClientBootstrap {
/// Specify the `host` and `port` to connect to for the TCP `Channel` that will be established.
///
/// - Parameters:
/// - host: The host to connect to.
/// - port: The port to connect to.
/// - channelInitializer: A closure to initialize the channel which must return the handler that is used for negotiating
/// the protocol.
/// - Returns: The protocol negotiation result.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@_spi(AsyncChannel)
public func connect<Handler: NIOProtocolNegotiationHandler>(
host: String,
port: Int,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Handler>
) async throws -> Handler.NegotiationResult {
let eventLoop = self.group.next()
return try await self.connect(
host: host,
port: port,
eventLoop: eventLoop,
channelInitializer: channelInitializer,
postRegisterTransformation: { handler, eventLoop in
handler.protocolNegotiationResult.flatMap { result in
result.resolve(on: eventLoop)
}
}
)
}
/// Specify the `address` to connect to for the TCP `Channel` that will be established.
///
/// - Parameters:
/// - address: The address to connect to.
/// - channelInitializer: A closure to initialize the channel which must return the handler that is used for negotiating
/// the protocol.
/// - Returns: The protocol negotiation result.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@_spi(AsyncChannel)
public func connect<Handler: NIOProtocolNegotiationHandler>(
to address: SocketAddress,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Handler>
) async throws -> Handler.NegotiationResult {
return try await self.initializeAndRegisterNewChannel(
eventLoop: self.group.next(),
protocolFamily: address.protocol,
channelInitializer: channelInitializer
) { channel in
return self.connect(freshChannel: channel, address: address)
}.get().1
}
/// Specify the `unixDomainSocket` path to connect to for the UDS `Channel` that will be established.
///
/// - Parameters:
/// - unixDomainSocketPath: The _Unix domain socket_ path to connect to.
/// - channelInitializer: A closure to initialize the channel which must return the handler that is used for negotiating
/// the protocol.
/// - Returns: The protocol negotiation result.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@_spi(AsyncChannel)
public func connect<Handler: NIOProtocolNegotiationHandler>(
unixDomainSocketPath: String,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Handler>
) async throws -> Handler.NegotiationResult {
let address = try SocketAddress(unixDomainSocketPath: unixDomainSocketPath)
return try await self.connect(
to: address,
channelInitializer: channelInitializer
)
}
/// Use the existing connected socket file descriptor.
///
/// - Parameters:
/// - descriptor: The _Unix file descriptor_ representing the connected stream socket.
/// - channelInitializer: A closure to initialize the channel which must return the handler that is used for negotiating
/// the protocol.
/// - Returns: The protocol negotiation result.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@_spi(AsyncChannel)
public func withConnectedSocket<Handler: NIOProtocolNegotiationHandler>(
_ socket: NIOBSDSocket.Handle,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Handler>
) async throws -> Handler.NegotiationResult {
let eventLoop = group.next()
return try await self.withConnectedSocket(
eventLoop: eventLoop,
socket: socket,
channelInitializer: channelInitializer,
postRegisterTransformation: { handler, eventLoop in
handler.protocolNegotiationResult.flatMap { result in
result.resolve(on: eventLoop)
}
}
)
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
private func initializeAndRegisterNewChannel<Handler: NIOProtocolNegotiationHandler>(
eventLoop: EventLoop,
protocolFamily: NIOBSDSocket.ProtocolFamily,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Handler>,
_ body: @escaping (Channel) -> EventLoopFuture<Void>
) -> EventLoopFuture<(Channel, Handler.NegotiationResult)> {
self.initializeAndRegisterNewChannel(
eventLoop: eventLoop,
protocolFamily: protocolFamily,
channelInitializer: channelInitializer,
postRegisterTransformation: { handler, eventLoop in
handler.protocolNegotiationResult.flatMap { result in
result.resolve(on: eventLoop)
}
},
body
)
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
private func initializeAndRegisterChannel<Handler: NIOProtocolNegotiationHandler>(
channel: SocketChannel,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Handler>,
_ body: @escaping (Channel) -> EventLoopFuture<Void>
) -> EventLoopFuture<Handler.NegotiationResult> {
self.initializeAndRegisterChannel(
channel: channel,
channelInitializer: channelInitializer,
registration: { channel in
channel.registerAndDoSynchronously(body)
},
postRegisterTransformation: { handler, eventLoop in
handler.protocolNegotiationResult.flatMap { result in
result.resolve(on: channel.eventLoop)
}
}
)
}
}
// MARK: Async connect methods with arbitrary payload
extension ClientBootstrap {
/// Specify the `host` and `port` to connect to for the TCP `Channel` that will be established.
///
/// - Parameters:
/// - host: The host to connect to.
/// - port: The port to connect to.
/// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect`
/// method.
/// - Returns: The result of the channel initializer.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@_spi(AsyncChannel)
public func connect<Output: Sendable>(
host: String,
port: Int,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output>
) async throws -> Output {
let eventLoop = self.group.next()
return try await self.connect(
host: host,
port: port,
eventLoop: eventLoop,
channelInitializer: channelInitializer,
postRegisterTransformation: { output, eventLoop in
eventLoop.makeSucceededFuture(output)
}
)
}
/// Specify the `address` to connect to for the TCP `Channel` that will be established.
///
/// - Parameters:
/// - address: The address to connect to.
/// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect`
/// method.
/// - Returns: The result of the channel initializer.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@_spi(AsyncChannel)
public func connect<Output: Sendable>(
to address: SocketAddress,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output>
) async throws -> Output {
let eventLoop = self.group.next()
return try await self.initializeAndRegisterNewChannel(
eventLoop: eventLoop,
protocolFamily: address.protocol,
channelInitializer: channelInitializer,
postRegisterTransformation: { output, eventLoop in
eventLoop.makeSucceededFuture(output)
}, { channel in
return self.connect(freshChannel: channel, address: address)
}).get().1
}
/// Specify the `unixDomainSocket` path to connect to for the UDS `Channel` that will be established.
///
/// - Parameters:
/// - unixDomainSocketPath: The _Unix domain socket_ path to connect to.
/// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect`
/// method.
/// - Returns: The result of the channel initializer.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@_spi(AsyncChannel)
public func connect<Output: Sendable>(
unixDomainSocketPath: String,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output>
) async throws -> Output {
let address = try SocketAddress(unixDomainSocketPath: unixDomainSocketPath)
return try await self.connect(
to: address,
channelInitializer: channelInitializer
)
}
/// Use the existing connected socket file descriptor.
///
/// - Parameters:
/// - descriptor: The _Unix file descriptor_ representing the connected stream socket.
/// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect`
/// method.
/// - Returns: The result of the channel initializer.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@_spi(AsyncChannel)
public func withConnectedSocket<Output: Sendable>(
_ socket: NIOBSDSocket.Handle,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output>
) async throws -> Output {
let eventLoop = group.next()
return try await self.withConnectedSocket(
eventLoop: eventLoop,
socket: socket,
channelInitializer: channelInitializer,
postRegisterTransformation: { output, eventLoop in
eventLoop.makeSucceededFuture(output)
}
)
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
func connect<ChannelInitializerResult, PostRegistrationTransformationResult>(
host: String,
port: Int,
eventLoop: EventLoop,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<ChannelInitializerResult>,
postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture<PostRegistrationTransformationResult>
) async throws -> PostRegistrationTransformationResult {
let resolver = self.resolver ?? GetaddrinfoResolver(
loop: eventLoop,
aiSocktype: .stream,
aiProtocol: .tcp
)
let connector = HappyEyeballsConnector<PostRegistrationTransformationResult>(
resolver: resolver,
loop: eventLoop,
host: host,
port: port,
connectTimeout: self.connectTimeout
) { eventLoop, protocolFamily in
return self.initializeAndRegisterNewChannel(
eventLoop: eventLoop,
protocolFamily: protocolFamily,
channelInitializer: channelInitializer,
postRegisterTransformation: postRegisterTransformation
) {
$0.eventLoop.makeSucceededFuture(())
}
}
return try await connector.resolveAndConnect().map { $0.1 }.get()
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
private func withConnectedSocket<ChannelInitializerResult, PostRegistrationTransformationResult>(
eventLoop: EventLoop,
socket: NIOBSDSocket.Handle,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<ChannelInitializerResult>,
postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture<PostRegistrationTransformationResult>
) async throws -> PostRegistrationTransformationResult {
let channel = try SocketChannel(eventLoop: eventLoop as! SelectableEventLoop, socket: socket)
return try await self.initializeAndRegisterChannel(
channel: channel,
channelInitializer: channelInitializer,
registration: { channel in
let promise = eventLoop.makePromise(of: Void.self)
channel.registerAlreadyConfigured0(promise: promise)
return promise.futureResult
},
postRegisterTransformation: postRegisterTransformation
).get()
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
private func initializeAndRegisterNewChannel<ChannelInitializerResult, PostRegistrationTransformationResult>(
eventLoop: EventLoop,
protocolFamily: NIOBSDSocket.ProtocolFamily,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<ChannelInitializerResult>,
postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture<PostRegistrationTransformationResult>,
_ body: @escaping (Channel) -> EventLoopFuture<Void>
) -> EventLoopFuture<(Channel, PostRegistrationTransformationResult)> {
let channel: SocketChannel
do {
channel = try self.makeSocketChannel(eventLoop: eventLoop, protocolFamily: protocolFamily)
} catch {
return eventLoop.makeFailedFuture(error)
}
return self.initializeAndRegisterChannel(
channel: channel,
channelInitializer: channelInitializer,
registration: { channel in
channel.registerAndDoSynchronously(body)
},
postRegisterTransformation: postRegisterTransformation
).map { (channel, $0) }
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
private func initializeAndRegisterChannel<ChannelInitializerResult, PostRegistrationTransformationResult>(
channel: SocketChannel,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<ChannelInitializerResult>,
registration: @escaping @Sendable (Channel) -> EventLoopFuture<Void>,
postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture<PostRegistrationTransformationResult>
) -> EventLoopFuture<PostRegistrationTransformationResult> {
let channelInitializer = { channel in
return self.channelInitializer(channel)
.flatMap { channelInitializer(channel) }
}
let channelOptions = self._channelOptions
let eventLoop = channel.eventLoop
let bindTarget = self.bindTarget
@inline(__always)
@Sendable
func setupChannel() -> EventLoopFuture<PostRegistrationTransformationResult> {
eventLoop.assertInEventLoop()
return channelOptions
.applyAllChannelOptions(to: channel)
.flatMap {
if let bindTarget = bindTarget {
return channel
.bind(to: bindTarget)
.flatMap {
channelInitializer(channel)
}
} else {
return channelInitializer(channel)
}
}.flatMap { (result: ChannelInitializerResult) in
eventLoop.assertInEventLoop()
return registration(channel).map {
result
}
}.flatMap { (result: ChannelInitializerResult) -> EventLoopFuture<PostRegistrationTransformationResult> in
postRegisterTransformation(result, eventLoop)
}.flatMapError { error in
eventLoop.assertInEventLoop()
channel.close0(error: error, mode: .all, promise: nil)
return channel.eventLoop.makeFailedFuture(error)
}
}
if eventLoop.inEventLoop {
return setupChannel()
} else {
return eventLoop.flatSubmit {
setupChannel()
}
}
}
}
@available(*, unavailable)
extension ClientBootstrap: Sendable {}
@ -1839,5 +2339,23 @@ public final class NIOPipeBootstrap {
}
}
extension NIOProtocolNegotiationResult {
func resolve(on eventLoop: EventLoop) -> EventLoopFuture<NegotiationResult> {
Self.resolve(on: eventLoop, result: self)
}
static func resolve(on eventLoop: EventLoop, result: Self) -> EventLoopFuture<NegotiationResult> {
switch result {
case .finished(let negotiationResult):
return eventLoop.makeSucceededFuture(negotiationResult)
case .deferredResult(let future):
return future.flatMap { result in
return resolve(on: eventLoop, result: result)
}
}
}
}
@available(*, unavailable)
extension NIOPipeBootstrap: Sendable {}

View File

@ -141,7 +141,10 @@ private struct TargetIterator: IteratorProtocol {
///
/// This class's private API is *not* thread-safe, and expects to be called from the
/// event loop thread of the `loop` it is passed.
internal class HappyEyeballsConnector {
///
/// The `ChannelBuilderResult` generic type can used to tunnel an arbitrary type
/// from the `channelBuilderCallback` to the `resolve` methods return value.
internal final class HappyEyeballsConnector<ChannelBuilderResult> {
/// An enum for keeping track of connection state.
private enum ConnectionState {
/// Initial state. No work outstanding.
@ -223,7 +226,7 @@ internal class HappyEyeballsConnector {
/// than intended.
///
/// The channel builder callback takes an event loop and a protocol family as arguments.
private let channelBuilderCallback: (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture<Channel>
private let channelBuilderCallback: (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture<(Channel, ChannelBuilderResult)>
/// The amount of time to wait for an AAAA response to come in after a A response is
/// received. By default this is 50ms.
@ -250,7 +253,7 @@ internal class HappyEyeballsConnector {
private var timeoutTask: Optional<Scheduled<Void>>
/// The promise that will hold the final connected channel.
private let resolutionPromise: EventLoopPromise<Channel>
private let resolutionPromise: EventLoopPromise<(Channel, ChannelBuilderResult)>
/// Our state machine state.
private var state: ConnectionState
@ -263,7 +266,7 @@ internal class HappyEyeballsConnector {
///
/// This is kept to ensure that we can clean up after ourselves once a connection succeeds,
/// and throw away all pending connection attempts that are no longer needed.
private var pendingConnections: [EventLoopFuture<Channel>] = []
private var pendingConnections: [EventLoopFuture<(Channel, ChannelBuilderResult)>] = []
/// The number of DNS resolutions that have returned.
///
@ -274,6 +277,7 @@ internal class HappyEyeballsConnector {
/// An object that holds any errors we encountered.
private var error: NIOConnectionError
@inlinable
init(resolver: Resolver,
loop: EventLoop,
host: String,
@ -281,7 +285,7 @@ internal class HappyEyeballsConnector {
connectTimeout: TimeAmount,
resolutionDelay: TimeAmount = .milliseconds(50),
connectionDelay: TimeAmount = .milliseconds(250),
channelBuilderCallback: @escaping (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture<Channel>) {
channelBuilderCallback: @escaping (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture<(Channel, ChannelBuilderResult)>) {
self.resolver = resolver
self.loop = loop
self.host = host
@ -303,10 +307,34 @@ internal class HappyEyeballsConnector {
self.connectionDelay = connectionDelay
}
@inlinable
convenience init(
resolver: Resolver,
loop: EventLoop,
host: String,
port: Int,
connectTimeout: TimeAmount,
resolutionDelay: TimeAmount = .milliseconds(50),
connectionDelay: TimeAmount = .milliseconds(250),
channelBuilderCallback: @escaping (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture<Channel>
) where ChannelBuilderResult == Void {
self.init(
resolver: resolver,
loop: loop,
host: host,
port: port,
connectTimeout: connectTimeout,
resolutionDelay: resolutionDelay,
connectionDelay: connectionDelay) { loop, protocolFamily in
channelBuilderCallback(loop, protocolFamily).map { ($0, ()) }
}
}
/// Initiate a DNS resolution attempt using Happy Eyeballs 2.
///
/// returns: An `EventLoopFuture` that fires with a connected `Channel`.
public func resolveAndConnect() -> EventLoopFuture<Channel> {
@inlinable
func resolveAndConnect() -> EventLoopFuture<(Channel, ChannelBuilderResult)> {
// We dispatch ourselves onto the event loop, rather than do all the rest of our processing from outside it.
self.loop.execute {
self.timeoutTask = self.loop.scheduleTask(in: self.connectTimeout) { self.processInput(.connectTimeoutElapsed) }
@ -315,6 +343,14 @@ internal class HappyEyeballsConnector {
return resolutionPromise.futureResult
}
/// Initiate a DNS resolution attempt using Happy Eyeballs 2.
///
/// returns: An `EventLoopFuture` that fires with a connected `Channel`.
@inlinable
func resolveAndConnect() -> EventLoopFuture<Channel> where ChannelBuilderResult == Void {
self.resolveAndConnect().map { $0.0 }
}
/// Spin the state machine.
///
/// - parameters:
@ -540,11 +576,11 @@ internal class HappyEyeballsConnector {
let channelFuture = channelBuilderCallback(self.loop, target.protocol)
pendingConnections.append(channelFuture)
channelFuture.whenSuccess { channel in
channelFuture.whenSuccess { (channel, result) in
// If we are in the complete state then we want to abandon this channel. Otherwise, begin
// connecting.
if case .complete = self.state {
self.pendingConnections.remove(element: channelFuture)
self.pendingConnections.removeAll { $0 === channelFuture }
channel.close(promise: nil)
} else {
channel.connect(to: target).map {
@ -552,13 +588,13 @@ internal class HappyEyeballsConnector {
// Otherwise, fire the channel connected event. Either way we don't want the channel future to
// be in our list of pending connections, so we don't either double close or close the connection
// we want to use.
self.pendingConnections.remove(element: channelFuture)
self.pendingConnections.removeAll { $0 === channelFuture }
if case .complete = self.state {
channel.close(promise: nil)
} else {
self.processInput(.connectSuccess)
self.resolutionPromise.succeed(channel)
self.resolutionPromise.succeed((channel, result))
}
}.whenFailure { err in
// The connection attempt failed. If we're in the complete state then there's nothing
@ -567,7 +603,7 @@ internal class HappyEyeballsConnector {
assert(self.pendingConnections.firstIndex { $0 === channelFuture } == nil, "failed but was still in pending connections")
} else {
self.error.connectionErrors.append(SingleConnectionFailure(target: target, error: err))
self.pendingConnections.remove(element: channelFuture)
self.pendingConnections.removeAll { $0 === channelFuture }
self.processInput(.connectFailed)
}
}
@ -575,7 +611,7 @@ internal class HappyEyeballsConnector {
}
channelFuture.whenFailure { error in
self.error.connectionErrors.append(SingleConnectionFailure(target: target, error: error))
self.pendingConnections.remove(element: channelFuture)
self.pendingConnections.removeAll { $0 === channelFuture }
self.processInput(.connectFailed)
}
}
@ -607,7 +643,7 @@ internal class HappyEyeballsConnector {
let connections = self.pendingConnections
self.pendingConnections = []
for connection in connections {
connection.whenSuccess { channel in channel.close(promise: nil) }
connection.whenSuccess { (channel, _) in channel.close(promise: nil) }
}
}

View File

@ -18,7 +18,7 @@ import NIOConcurrencyHelpers
import XCTest
@_spi(AsyncChannel) import NIOTLS
private final class LineDelimiterDecoder: ByteToMessageDecoder {
private final class LineDelimiterCoder: ByteToMessageDecoder, MessageToByteEncoder {
private let newLine = "\n".utf8.first!
typealias InboundIn = ByteBuffer
@ -33,43 +33,100 @@ private final class LineDelimiterDecoder: ByteToMessageDecoder {
}
return .needMoreData
}
func encode(data: ByteBuffer, out: inout ByteBuffer) throws {
out.writeImmutableBuffer(data)
out.writeString("\n")
}
}
private final class TLSUserEventHandler: ChannelInboundHandler {
private final class TLSUserEventHandler: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = ByteBuffer
typealias InboundOut = ByteBuffer
enum ALPN: String {
case string
case byte
case unknown
}
private var proposedALPN: ALPN?
init(
proposedALPN: ALPN? = nil
) {
self.proposedALPN = proposedALPN
}
func handlerAdded(context: ChannelHandlerContext) {
guard context.channel.isActive else {
return
}
if let proposedALPN = self.proposedALPN {
self.proposedALPN = nil
context.writeAndFlush(.init(ByteBuffer(string: "negotiate-alpn:\(proposedALPN.rawValue)")), promise: nil)
}
context.fireChannelActive()
}
func channelActive(context: ChannelHandlerContext) {
if let proposedALPN = self.proposedALPN {
context.writeAndFlush(.init(ByteBuffer(string: "negotiate-alpn:\(proposedALPN.rawValue)")), promise: nil)
}
context.fireChannelActive()
}
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let buffer = self.unwrapInboundIn(data)
let alpn = String(buffer: buffer)
let string = String(buffer: buffer)
if alpn.hasPrefix("alpn:") {
context.fireUserInboundEventTriggered(TLSUserEvent.handshakeCompleted(negotiatedProtocol: String(alpn.dropFirst(5))))
if string.hasPrefix("negotiate-alpn:") {
let alpn = String(string.dropFirst(15))
context.writeAndFlush(.init(ByteBuffer(string: "alpn:\(alpn)")), promise: nil)
context.fireUserInboundEventTriggered(TLSUserEvent.handshakeCompleted(negotiatedProtocol: alpn))
context.pipeline.removeHandler(self, promise: nil)
} else if string.hasPrefix("alpn:") {
context.fireUserInboundEventTriggered(TLSUserEvent.handshakeCompleted(negotiatedProtocol: String(string.dropFirst(5))))
context.pipeline.removeHandler(self, promise: nil)
} else {
context.fireChannelRead(data)
}
}
}
private final class ByteBufferToStringHandler: ChannelInboundHandler {
private final class ByteBufferToStringHandler: ChannelDuplexHandler {
typealias InboundIn = ByteBuffer
typealias InboundOut = String
typealias OutboundIn = String
typealias OutboundOut = ByteBuffer
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let buffer = self.unwrapInboundIn(data)
context.fireChannelRead(self.wrapInboundOut(String(buffer: buffer)))
}
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
let buffer = ByteBuffer(string: self.unwrapOutboundIn(data))
context.write(.init(buffer), promise: promise)
}
}
private final class ByteBufferToByteHandler: ChannelInboundHandler {
private final class ByteBufferToByteHandler: ChannelDuplexHandler {
typealias InboundIn = ByteBuffer
typealias InboundOut = UInt8
typealias OutboundIn = UInt8
typealias OutboundOut = ByteBuffer
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
var buffer = self.unwrapInboundIn(data)
let byte = buffer.readInteger(as: UInt8.self)!
context.fireChannelRead(self.wrapInboundOut(byte))
}
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
let buffer = ByteBuffer(integer: self.unwrapOutboundIn(data))
context.write(.init(buffer), promise: promise)
}
}
final class AsyncChannelBootstrapTests: XCTestCase {
@ -94,7 +151,8 @@ final class AsyncChannelBootstrapTests: XCTestCase {
.childChannelOption(ChannelOptions.autoRead, value: true)
.childChannelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterDecoder()))
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder()))
try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder()))
try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler())
}
}
@ -120,7 +178,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
}
let stringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!)
stringChannel.writeAndFlush(.init(ByteBuffer(string: "hello\n")), promise: nil)
try await stringChannel.outboundWriter.write("hello")
await XCTAsyncAssertEqual(await iterator.next(), .string("hello"))
@ -138,7 +196,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
.childChannelOption(ChannelOptions.autoRead, value: true)
.childChannelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
try self.makeProtocolNegotiationChildChannel(channel: channel)
try self.configureProtocolNegotiationHandlers(channel: channel)
}
}
.bind(
@ -149,7 +207,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
try await withThrowingTaskGroup(of: Void.self) { group in
let (stream, continuation) = AsyncStream<StringOrByte>.makeStream()
var iterator = stream.makeAsyncIterator()
var serverIterator = stream.makeAsyncIterator()
group.addTask {
try await withThrowingTaskGroup(of: Void.self) { group in
@ -170,26 +228,34 @@ final class AsyncChannelBootstrapTests: XCTestCase {
}
}
let stringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!)
let stringNegotiationResult = try await self.makeClientChannelWithProtocolNegotiation(
eventLoopGroup: eventLoopGroup,
port: channel.channel.localAddress!.port!,
proposedALPN: .string
)
switch stringNegotiationResult {
case .string(let stringChannel):
// This is the actual content
try await stringChannel.outboundWriter.write("hello")
await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello"))
case .byte:
preconditionFailure()
}
// This is for negotiating the protocol
stringChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:string\n")), promise: nil)
// This is the actual content
stringChannel.writeAndFlush(.init(ByteBuffer(string: "hello\n")), promise: nil)
await XCTAsyncAssertEqual(await iterator.next(), .string("hello"))
let byteChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!)
// This is for negotiating the protocol
byteChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:byte\n")), promise: nil)
// This is the actual content
byteChannel.write(.init(ByteBuffer(integer: UInt8(8))), promise: nil)
byteChannel.writeAndFlush(.init(ByteBuffer(string: "\n")), promise: nil)
await XCTAsyncAssertEqual(await iterator.next(), .byte(8))
let byteNegotiationResult = try await self.makeClientChannelWithProtocolNegotiation(
eventLoopGroup: eventLoopGroup,
port: channel.channel.localAddress!.port!,
proposedALPN: .byte
)
switch byteNegotiationResult {
case .string:
preconditionFailure()
case .byte(let byteChannel):
// This is the actual content
try await byteChannel.outboundWriter.write(UInt8(8))
await XCTAsyncAssertEqual(await serverIterator.next(), .byte(8))
}
group.cancelAll()
}
@ -205,7 +271,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
.childChannelOption(ChannelOptions.autoRead, value: true)
.childChannelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
try self.makeNestedProtocolNegotiationChildChannel(channel: channel)
try self.configureNestedProtocolNegotiationHandlers(channel: channel)
}
}
.bind(
@ -216,7 +282,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
try await withThrowingTaskGroup(of: Void.self) { group in
let (stream, continuation) = AsyncStream<StringOrByte>.makeStream()
var iterator = stream.makeAsyncIterator()
var serverIterator = stream.makeAsyncIterator()
group.addTask {
try await withThrowingTaskGroup(of: Void.self) { group in
@ -237,59 +303,65 @@ final class AsyncChannelBootstrapTests: XCTestCase {
}
}
let stringStringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!)
let stringStringNegotiationResult = try await self.makeClientChannelWithNestedProtocolNegotiation(
eventLoopGroup: eventLoopGroup,
port: channel.channel.localAddress!.port!,
proposedOuterALPN: .string,
proposedInnerALPN: .string
)
switch stringStringNegotiationResult {
case .string(let stringChannel):
// This is the actual content
try await stringChannel.outboundWriter.write("hello")
await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello"))
case .byte:
preconditionFailure()
}
// This is for negotiating the protocol
stringStringChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:string\n")), promise: nil)
let byteStringNegotiationResult = try await self.makeClientChannelWithNestedProtocolNegotiation(
eventLoopGroup: eventLoopGroup,
port: channel.channel.localAddress!.port!,
proposedOuterALPN: .byte,
proposedInnerALPN: .string
)
switch byteStringNegotiationResult {
case .string(let stringChannel):
// This is the actual content
try await stringChannel.outboundWriter.write("hello")
await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello"))
case .byte:
preconditionFailure()
}
// This is for negotiating the nested protocol
stringStringChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:string\n")), promise: nil)
let byteByteNegotiationResult = try await self.makeClientChannelWithNestedProtocolNegotiation(
eventLoopGroup: eventLoopGroup,
port: channel.channel.localAddress!.port!,
proposedOuterALPN: .byte,
proposedInnerALPN: .byte
)
switch byteByteNegotiationResult {
case .string:
preconditionFailure()
case .byte(let byteChannel):
// This is the actual content
try await byteChannel.outboundWriter.write(UInt8(8))
await XCTAsyncAssertEqual(await serverIterator.next(), .byte(8))
}
// This is the actual content
stringStringChannel.writeAndFlush(.init(ByteBuffer(string: "hello\n")), promise: nil)
await XCTAsyncAssertEqual(await iterator.next(), .string("hello"))
let byteByteChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!)
// This is for negotiating the protocol
byteByteChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:byte\n")), promise: nil)
// This is for negotiating the nested protocol
byteByteChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:byte\n")), promise: nil)
// This is the actual content
byteByteChannel.write(.init(ByteBuffer(integer: UInt8(8))), promise: nil)
byteByteChannel.writeAndFlush(.init(ByteBuffer(string: "\n")), promise: nil)
await XCTAsyncAssertEqual(await iterator.next(), .byte(8))
let stringByteChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!)
// This is for negotiating the protocol
stringByteChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:string\n")), promise: nil)
// This is for negotiating the nested protocol
stringByteChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:byte\n")), promise: nil)
// This is the actual content
stringByteChannel.write(.init(ByteBuffer(integer: UInt8(8))), promise: nil)
stringByteChannel.writeAndFlush(.init(ByteBuffer(string: "\n")), promise: nil)
await XCTAsyncAssertEqual(await iterator.next(), .byte(8))
let byteStringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!)
// This is for negotiating the protocol
byteStringChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:byte\n")), promise: nil)
// This is for negotiating the nested protocol
byteStringChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:string\n")), promise: nil)
// This is the actual content
byteStringChannel.writeAndFlush(.init(ByteBuffer(string: "hello\n")), promise: nil)
await XCTAsyncAssertEqual(await iterator.next(), .string("hello"))
let stringByteNegotiationResult = try await self.makeClientChannelWithNestedProtocolNegotiation(
eventLoopGroup: eventLoopGroup,
port: channel.channel.localAddress!.port!,
proposedOuterALPN: .string,
proposedInnerALPN: .byte
)
switch stringByteNegotiationResult {
case .string:
preconditionFailure()
case .byte(let byteChannel):
// This is the actual content
try await byteChannel.outboundWriter.write(UInt8(8))
await XCTAsyncAssertEqual(await serverIterator.next(), .byte(8))
}
group.cancelAll()
}
@ -328,7 +400,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
.childChannelOption(ChannelOptions.autoRead, value: true)
.childChannelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
try self.makeProtocolNegotiationChildChannel(channel: channel)
try self.configureProtocolNegotiationHandlers(channel: channel)
}
}
.bind(
@ -339,7 +411,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
try await withThrowingTaskGroup(of: Void.self) { group in
let (stream, continuation) = AsyncStream<StringOrByte>.makeStream()
var iterator = stream.makeAsyncIterator()
var serverIterator = stream.makeAsyncIterator()
group.addTask {
try await withThrowingTaskGroup(of: Void.self) { group in
@ -360,21 +432,30 @@ final class AsyncChannelBootstrapTests: XCTestCase {
}
}
let unknownChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!)
await XCTAssertThrowsError(
try await self.makeClientChannelWithProtocolNegotiation(
eventLoopGroup: eventLoopGroup,
port: channel.channel.localAddress!.port!,
proposedALPN: .unknown
)
) { error in
XCTAssertTrue(error is ProtocolNegotiationError)
}
// This is for negotiating the protocol
unknownChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:unknown\n")), promise: nil)
// Checking that we can still create new connections afterwards
let stringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!)
// This is for negotiating the protocol
stringChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:string\n")), promise: nil)
// This is the actual content
stringChannel.writeAndFlush(.init(ByteBuffer(string: "hello\n")), promise: nil)
await XCTAsyncAssertEqual(await iterator.next(), .string("hello"))
// Let's check that we can still open a new connection
let stringNegotiationResult = try await self.makeClientChannelWithProtocolNegotiation(
eventLoopGroup: eventLoopGroup,
port: channel.channel.localAddress!.port!,
proposedALPN: .string
)
switch stringNegotiationResult {
case .string(let stringChannel):
// This is the actual content
try await stringChannel.outboundWriter.write("hello")
await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello"))
case .byte:
preconditionFailure()
}
let failedInboundChannel = channels.withLockedValue { channels -> Channel in
XCTAssertEqual(channels.count, 2)
@ -391,55 +472,108 @@ final class AsyncChannelBootstrapTests: XCTestCase {
// MARK: - Test Helpers
private func makeClientChannel(eventLoopGroup: EventLoopGroup, port: Int) async throws -> Channel {
private func makeClientChannel(eventLoopGroup: EventLoopGroup, port: Int) async throws -> NIOAsyncChannel<String, String> {
return try await ClientBootstrap(group: eventLoopGroup)
.channelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterDecoder()))
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder()))
try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder()))
try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler())
}
}
.connect(to: .init(ipAddress: "127.0.0.1", port: port))
.get()
.connect(
to: .init(ipAddress: "127.0.0.1", port: port),
inboundType: String.self,
outboundType: String.self
)
}
private func makeProtocolNegotiationChildChannel(channel: Channel) throws {
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterDecoder()))
try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler())
try self.addTypedApplicationProtocolNegotiationHandler(to: channel)
}
private func makeNestedProtocolNegotiationChildChannel(channel: Channel) throws {
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterDecoder()))
try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler())
try channel.pipeline.syncOperations.addHandler(
NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: channel.eventLoop) { alpnResult, channel in
switch alpnResult {
case .negotiated(let alpn):
switch alpn {
case "string":
return channel.eventLoop.makeCompletedFuture {
let negotiationFuture = try self.addTypedApplicationProtocolNegotiationHandler(to: channel)
return NIOProtocolNegotiationResult.deferredResult(negotiationFuture)
}
case "byte":
return channel.eventLoop.makeCompletedFuture {
let negotiationFuture = try self.addTypedApplicationProtocolNegotiationHandler(to: channel)
return NIOProtocolNegotiationResult.deferredResult(negotiationFuture)
}
default:
return channel.eventLoop.makeFailedFuture(ProtocolNegotiationError())
}
case .fallback:
return channel.eventLoop.makeFailedFuture(ProtocolNegotiationError())
private func makeClientChannelWithProtocolNegotiation(
eventLoopGroup: EventLoopGroup,
port: Int,
proposedALPN: TLSUserEventHandler.ALPN
) async throws -> NegotiationResult {
return try await ClientBootstrap(group: eventLoopGroup)
.connect(
to: .init(ipAddress: "127.0.0.1", port: port)
) { channel in
return channel.eventLoop.makeCompletedFuture {
return try self.configureProtocolNegotiationHandlers(channel: channel, proposedALPN: proposedALPN)
}
}
}
private func makeClientChannelWithNestedProtocolNegotiation(
eventLoopGroup: EventLoopGroup,
port: Int,
proposedOuterALPN: TLSUserEventHandler.ALPN,
proposedInnerALPN: TLSUserEventHandler.ALPN
) async throws -> NegotiationResult {
return try await ClientBootstrap(group: eventLoopGroup)
.connect(
to: .init(ipAddress: "127.0.0.1", port: port)
) { channel in
return channel.eventLoop.makeCompletedFuture {
try self.configureNestedProtocolNegotiationHandlers(
channel: channel,
proposedOuterALPN: proposedOuterALPN,
proposedInnerALPN: proposedInnerALPN
)
}
}
)
}
@discardableResult
private func addTypedApplicationProtocolNegotiationHandler(to channel: Channel) throws -> EventLoopFuture<NIOProtocolNegotiationResult<NegotiationResult>> {
private func configureProtocolNegotiationHandlers(
channel: Channel,
proposedALPN: TLSUserEventHandler.ALPN? = nil
) throws -> NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult> {
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder()))
try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder()))
try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedALPN))
return try self.addTypedApplicationProtocolNegotiationHandler(to: channel)
}
@discardableResult
private func configureNestedProtocolNegotiationHandlers(
channel: Channel,
proposedOuterALPN: TLSUserEventHandler.ALPN? = nil,
proposedInnerALPN: TLSUserEventHandler.ALPN? = nil
) throws -> NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult> {
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder()))
try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder()))
try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedOuterALPN))
let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: channel.eventLoop) { alpnResult, channel in
switch alpnResult {
case .negotiated(let alpn):
switch alpn {
case "string":
return channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedInnerALPN))
let negotiationFuture = try self.addTypedApplicationProtocolNegotiationHandler(to: channel)
return NIOProtocolNegotiationResult.deferredResult(negotiationFuture.protocolNegotiationResult)
}
case "byte":
return channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedInnerALPN))
let negotiationHandler = try self.addTypedApplicationProtocolNegotiationHandler(to: channel)
return NIOProtocolNegotiationResult.deferredResult(negotiationHandler.protocolNegotiationResult)
}
default:
return channel.eventLoop.makeFailedFuture(ProtocolNegotiationError())
}
case .fallback:
return channel.eventLoop.makeFailedFuture(ProtocolNegotiationError())
}
}
try channel.pipeline.syncOperations.addHandler(negotiationHandler)
return negotiationHandler
}
@discardableResult
private func addTypedApplicationProtocolNegotiationHandler(to channel: Channel) throws -> NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult> {
let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: channel.eventLoop) { alpnResult, channel in
switch alpnResult {
case .negotiated(let alpn):
@ -478,7 +612,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
}
try channel.pipeline.syncOperations.addHandler(negotiationHandler)
return negotiationHandler.protocolNegotiationResult
return negotiationHandler
}
}
@ -499,3 +633,19 @@ private func XCTAsyncAssertEqual<Element: Equatable>(_ lhs: @autoclosure () asyn
let rhsResult = try await rhs()
XCTAssertEqual(lhsResult, rhsResult, file: file, line: line)
}
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
private func XCTAsyncAssertThrowsError<T>(
_ expression: @autoclosure () async throws -> T,
_ message: @autoclosure () -> String = "",
file: StaticString = #filePath,
line: UInt = #line,
_ errorHandler: (_ error: Error) -> Void = { _ in }
) async {
do {
_ = try await expression()
XCTFail(message(), file: file, line: line)
} catch {
errorHandler(error)
}
}

View File

@ -226,10 +226,12 @@ private func defaultChannelBuilder(loop: EventLoop, family: NIOBSDSocket.Protoco
return loop.makeSucceededFuture(channel)
}
private func buildEyeballer(host: String,
port: Int,
connectTimeout: TimeAmount = .seconds(10),
channelBuilderCallback: @escaping (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture<Channel> = defaultChannelBuilder) -> (eyeballer: HappyEyeballsConnector, resolver: DummyResolver, loop: EmbeddedEventLoop) {
private func buildEyeballer(
host: String,
port: Int,
connectTimeout: TimeAmount = .seconds(10),
channelBuilderCallback: @escaping (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture<Channel> = defaultChannelBuilder
) -> (eyeballer: HappyEyeballsConnector<Void>, resolver: DummyResolver, loop: EmbeddedEventLoop) {
let loop = EmbeddedEventLoop()
let resolver = DummyResolver(loop: loop)
let eyeballer = HappyEyeballsConnector(resolver: resolver,