Add NIOAsyncChannel based connect methods to ClientBootstrap (#2437)
* Add NIOAsyncChannel based connect methods to ClientBootstrap # Motivation In my previous PR, I added new `bind` methods to `ServerBootstrap` that vend `NIOAsyncChannel` or support an async protocol negotiation. This PR focuses on adding new `connect` methods to `ClientBootstrap` which offer the same functionality. # Modification This PR adds new `connect` methods that either vend a `NIOAsyncChannel` or an asynchronous protocol negotiation result. To make this work I had to change the `HappyEyeballs` resolver so that it can return a generic value on resolving. Lastly, I adapted the bootstrap tests to use the new `ClientBootstrap` capabilities which now demonstrate a client/server protocol negotiation dance. # Result We can now bootstrap TCP clients with `NIOAsyncChannel`s * Reduce code duplication * Create a new set of APIs to tunnel an arbitrary Sendable payload through the inits * Pass EL to closure * Fix documentation
This commit is contained in:
parent
6213ba7a06
commit
46c0538253
|
@ -142,7 +142,25 @@ Afterwards, we handle each inbound connection in separate child tasks and echo t
|
|||
Normal task groups will result in a memory leak since they do not reap their child tasks automatically.
|
||||
|
||||
#### ClientBootstrap
|
||||
> Important: Support for `ClientBootstrap` with `NIOAsyncChannel` hasn't landed yet.
|
||||
|
||||
The client bootstrap is used to create a new TCP based client. Let's take a look at the new
|
||||
`NIOAsyncChannel` based connect methods.
|
||||
|
||||
```swift
|
||||
let clientChannel = try await ClientBootstrap(group: eventLoopGroup)
|
||||
.connect(
|
||||
host: "127.0.0.1",
|
||||
port: 0,
|
||||
channelInboundType: ByteBuffer.self,
|
||||
channelOutboundType: ByteBuffer.self
|
||||
)
|
||||
|
||||
clientChannel.outboundWriter.write(ByteBuffer(string: "hello"))
|
||||
|
||||
for try await inboundData in clientChannel.inboundStream {
|
||||
print(inboundData)
|
||||
}
|
||||
```
|
||||
|
||||
#### DatagramBootstrap
|
||||
> Important: Support for `DatagramBootstrap` with `NIOAsyncChannel` hasn't landed yet.
|
||||
|
@ -158,7 +176,7 @@ To solve the problem of protocol negotiation, NIO introduced a new ``ChannelHand
|
|||
that is completed once the handler is finished with protocol negotiation. In the successful case,
|
||||
the future can either indicate that protocol negotiation is fully done by returning `NIOProtocolNegotiationResult/finished(_:)` or
|
||||
indicate that further protocol negotiation needs to be done by returning `NIOProtocolNegotiationResult/deferredResult(_:)`.
|
||||
Additionally, the various bootstraps provide another set of `bind()` methods that handle protocol negotiation.
|
||||
Additionally, the various bootstraps provide another set of `bind()`/`connect()` methods that handle protocol negotiation.
|
||||
Let's walk through how to setup a `ServerBootstrap` with protocol negotiation.
|
||||
|
||||
First, we have to define our negotiation result. For this example, we are negotiating between a
|
||||
|
|
|
@ -892,7 +892,7 @@ extension ServerBootstrap {
|
|||
.flatMap { handler -> EventLoopFuture<NIOProtocolNegotiationResult<Handler.NegotiationResult>> in
|
||||
handler.protocolNegotiationResult
|
||||
}.flatMap { result in
|
||||
ServerBootstrap.waitForFinalResult(result, eventLoop: eventLoop)
|
||||
result.resolve(on: eventLoop)
|
||||
}.flatMapErrorThrowing { error in
|
||||
channel.pipeline.fireErrorCaught(error)
|
||||
channel.close(promise: nil)
|
||||
|
@ -913,22 +913,6 @@ extension ServerBootstrap {
|
|||
$0
|
||||
}.get()
|
||||
}
|
||||
|
||||
/// This method recursively waits for the final result of protocol negotiation
|
||||
static func waitForFinalResult<NegotiationResult>(
|
||||
_ result: NIOProtocolNegotiationResult<NegotiationResult>,
|
||||
eventLoop: EventLoop
|
||||
) -> EventLoopFuture<NegotiationResult> {
|
||||
switch result {
|
||||
case .finished(let negotiationResult):
|
||||
return eventLoop.makeSucceededFuture(negotiationResult)
|
||||
|
||||
case .deferredResult(let future):
|
||||
return future.flatMap { result in
|
||||
return waitForFinalResult(result, eventLoop: eventLoop)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@available(*, unavailable)
|
||||
|
@ -1352,6 +1336,522 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol {
|
|||
}
|
||||
}
|
||||
|
||||
// MARK: Async connect methods with NIOAsyncChannel
|
||||
|
||||
extension ClientBootstrap {
|
||||
/// Specify the `host` and `port` to connect to for the TCP `Channel` that will be established.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - host: The host to connect to.
|
||||
/// - port: The port to connect to.
|
||||
/// - backpressureStrategy: The back pressure strategy used by the channel.
|
||||
/// - isOutboundHalfClosureEnabled: Indicates if half closure is enabled on the channel. If half closure is enabled
|
||||
/// then finishing the `NIOAsyncChannelWriter` will lead to half closure.
|
||||
/// - inboundType: The channel's inbound type.
|
||||
/// - outboundType: The channel's outbound type.
|
||||
/// - Returns: A `NIOAsyncChannel` for the established connection.
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
@_spi(AsyncChannel)
|
||||
public func connect<Inbound: Sendable, Outbound: Sendable>(
|
||||
host: String,
|
||||
port: Int,
|
||||
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
|
||||
isOutboundHalfClosureEnabled: Bool = false,
|
||||
inboundType: Inbound.Type = Inbound.self,
|
||||
outboundType: Outbound.Type = Outbound.self
|
||||
) async throws -> NIOAsyncChannel<Inbound, Outbound> {
|
||||
return try await self.connect(
|
||||
host: host,
|
||||
port: port
|
||||
) { channel in
|
||||
channel.eventLoop.makeCompletedFuture {
|
||||
return try NIOAsyncChannel(
|
||||
synchronouslyWrapping: channel,
|
||||
backpressureStrategy: backpressureStrategy,
|
||||
isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled,
|
||||
inboundType: inboundType,
|
||||
outboundType: outboundType
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Specify the `address` to connect to for the TCP `Channel` that will be established.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - address: The address to connect to.
|
||||
/// - backpressureStrategy: The back pressure strategy used by the channel.
|
||||
/// - isOutboundHalfClosureEnabled: Indicates if half closure is enabled on the channel. If half closure is enabled
|
||||
/// then finishing the `NIOAsyncChannelWriter` will lead to half closure.
|
||||
/// - inboundType: The channel's inbound type.
|
||||
/// - outboundType: The channel's outbound type.
|
||||
/// - Returns: A `NIOAsyncChannel` for the established connection.
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
@_spi(AsyncChannel)
|
||||
public func connect<Inbound: Sendable, Outbound: Sendable>(
|
||||
to address: SocketAddress,
|
||||
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
|
||||
isOutboundHalfClosureEnabled: Bool = false,
|
||||
inboundType: Inbound.Type = Inbound.self,
|
||||
outboundType: Outbound.Type = Outbound.self
|
||||
) async throws -> NIOAsyncChannel<Inbound, Outbound> {
|
||||
return try await self.connect(
|
||||
to: address
|
||||
) { channel in
|
||||
channel.eventLoop.makeCompletedFuture {
|
||||
return try NIOAsyncChannel(
|
||||
synchronouslyWrapping: channel,
|
||||
backpressureStrategy: backpressureStrategy,
|
||||
isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled,
|
||||
inboundType: inboundType,
|
||||
outboundType: outboundType
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Specify the `unixDomainSocket` path to connect to for the UDS `Channel` that will be established.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - unixDomainSocketPath: The _Unix domain socket_ path to connect to.
|
||||
/// - backpressureStrategy: The back pressure strategy used by the channel.
|
||||
/// - isOutboundHalfClosureEnabled: Indicates if half closure is enabled on the channel. If half closure is enabled
|
||||
/// then finishing the `NIOAsyncChannelWriter` will lead to half closure.
|
||||
/// - inboundType: The channel's inbound type.
|
||||
/// - outboundType: The channel's outbound type.
|
||||
/// - Returns: A `NIOAsyncChannel` for the established connection.
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
@_spi(AsyncChannel)
|
||||
public func connect<Inbound: Sendable, Outbound: Sendable>(
|
||||
unixDomainSocketPath: String,
|
||||
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
|
||||
isOutboundHalfClosureEnabled: Bool = false,
|
||||
inboundType: Inbound.Type = Inbound.self,
|
||||
outboundType: Outbound.Type = Outbound.self
|
||||
) async throws -> NIOAsyncChannel<Inbound, Outbound> {
|
||||
return try await self.connect(
|
||||
unixDomainSocketPath: unixDomainSocketPath
|
||||
) { channel in
|
||||
channel.eventLoop.makeCompletedFuture {
|
||||
return try NIOAsyncChannel(
|
||||
synchronouslyWrapping: channel,
|
||||
backpressureStrategy: backpressureStrategy,
|
||||
isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled,
|
||||
inboundType: inboundType,
|
||||
outboundType: outboundType
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Use the existing connected socket file descriptor.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - descriptor: The _Unix file descriptor_ representing the connected stream socket.
|
||||
/// - backpressureStrategy: The back pressure strategy used by the channel.
|
||||
/// - isOutboundHalfClosureEnabled: Indicates if half closure is enabled on the channel. If half closure is enabled
|
||||
/// then finishing the `NIOAsyncChannelWriter` will lead to half closure.
|
||||
/// - inboundType: The channel's inbound type.
|
||||
/// - outboundType: The channel's outbound type.
|
||||
/// - Returns: A `NIOAsyncChannel` for the established connection.
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
@_spi(AsyncChannel)
|
||||
public func connect<Inbound: Sendable, Outbound: Sendable>(
|
||||
_ socket: NIOBSDSocket.Handle,
|
||||
backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
|
||||
isOutboundHalfClosureEnabled: Bool = false,
|
||||
inboundType: Inbound.Type = Inbound.self,
|
||||
outboundType: Outbound.Type = Outbound.self
|
||||
) async throws -> NIOAsyncChannel<Inbound, Outbound> {
|
||||
return try await self.withConnectedSocket(
|
||||
socket
|
||||
) { channel in
|
||||
channel.eventLoop.makeCompletedFuture {
|
||||
return try NIOAsyncChannel(
|
||||
synchronouslyWrapping: channel,
|
||||
backpressureStrategy: backpressureStrategy,
|
||||
isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled,
|
||||
inboundType: inboundType,
|
||||
outboundType: outboundType
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: Async connect methods with protocol negotiation
|
||||
extension ClientBootstrap {
|
||||
/// Specify the `host` and `port` to connect to for the TCP `Channel` that will be established.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - host: The host to connect to.
|
||||
/// - port: The port to connect to.
|
||||
/// - channelInitializer: A closure to initialize the channel which must return the handler that is used for negotiating
|
||||
/// the protocol.
|
||||
/// - Returns: The protocol negotiation result.
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
@_spi(AsyncChannel)
|
||||
public func connect<Handler: NIOProtocolNegotiationHandler>(
|
||||
host: String,
|
||||
port: Int,
|
||||
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Handler>
|
||||
) async throws -> Handler.NegotiationResult {
|
||||
let eventLoop = self.group.next()
|
||||
return try await self.connect(
|
||||
host: host,
|
||||
port: port,
|
||||
eventLoop: eventLoop,
|
||||
channelInitializer: channelInitializer,
|
||||
postRegisterTransformation: { handler, eventLoop in
|
||||
handler.protocolNegotiationResult.flatMap { result in
|
||||
result.resolve(on: eventLoop)
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
/// Specify the `address` to connect to for the TCP `Channel` that will be established.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - address: The address to connect to.
|
||||
/// - channelInitializer: A closure to initialize the channel which must return the handler that is used for negotiating
|
||||
/// the protocol.
|
||||
/// - Returns: The protocol negotiation result.
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
@_spi(AsyncChannel)
|
||||
public func connect<Handler: NIOProtocolNegotiationHandler>(
|
||||
to address: SocketAddress,
|
||||
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Handler>
|
||||
) async throws -> Handler.NegotiationResult {
|
||||
return try await self.initializeAndRegisterNewChannel(
|
||||
eventLoop: self.group.next(),
|
||||
protocolFamily: address.protocol,
|
||||
channelInitializer: channelInitializer
|
||||
) { channel in
|
||||
return self.connect(freshChannel: channel, address: address)
|
||||
}.get().1
|
||||
}
|
||||
|
||||
/// Specify the `unixDomainSocket` path to connect to for the UDS `Channel` that will be established.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - unixDomainSocketPath: The _Unix domain socket_ path to connect to.
|
||||
/// - channelInitializer: A closure to initialize the channel which must return the handler that is used for negotiating
|
||||
/// the protocol.
|
||||
/// - Returns: The protocol negotiation result.
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
@_spi(AsyncChannel)
|
||||
public func connect<Handler: NIOProtocolNegotiationHandler>(
|
||||
unixDomainSocketPath: String,
|
||||
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Handler>
|
||||
) async throws -> Handler.NegotiationResult {
|
||||
let address = try SocketAddress(unixDomainSocketPath: unixDomainSocketPath)
|
||||
return try await self.connect(
|
||||
to: address,
|
||||
channelInitializer: channelInitializer
|
||||
)
|
||||
}
|
||||
|
||||
/// Use the existing connected socket file descriptor.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - descriptor: The _Unix file descriptor_ representing the connected stream socket.
|
||||
/// - channelInitializer: A closure to initialize the channel which must return the handler that is used for negotiating
|
||||
/// the protocol.
|
||||
/// - Returns: The protocol negotiation result.
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
@_spi(AsyncChannel)
|
||||
public func withConnectedSocket<Handler: NIOProtocolNegotiationHandler>(
|
||||
_ socket: NIOBSDSocket.Handle,
|
||||
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Handler>
|
||||
) async throws -> Handler.NegotiationResult {
|
||||
let eventLoop = group.next()
|
||||
return try await self.withConnectedSocket(
|
||||
eventLoop: eventLoop,
|
||||
socket: socket,
|
||||
channelInitializer: channelInitializer,
|
||||
postRegisterTransformation: { handler, eventLoop in
|
||||
handler.protocolNegotiationResult.flatMap { result in
|
||||
result.resolve(on: eventLoop)
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
private func initializeAndRegisterNewChannel<Handler: NIOProtocolNegotiationHandler>(
|
||||
eventLoop: EventLoop,
|
||||
protocolFamily: NIOBSDSocket.ProtocolFamily,
|
||||
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Handler>,
|
||||
_ body: @escaping (Channel) -> EventLoopFuture<Void>
|
||||
) -> EventLoopFuture<(Channel, Handler.NegotiationResult)> {
|
||||
self.initializeAndRegisterNewChannel(
|
||||
eventLoop: eventLoop,
|
||||
protocolFamily: protocolFamily,
|
||||
channelInitializer: channelInitializer,
|
||||
postRegisterTransformation: { handler, eventLoop in
|
||||
handler.protocolNegotiationResult.flatMap { result in
|
||||
result.resolve(on: eventLoop)
|
||||
}
|
||||
},
|
||||
body
|
||||
)
|
||||
}
|
||||
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
private func initializeAndRegisterChannel<Handler: NIOProtocolNegotiationHandler>(
|
||||
channel: SocketChannel,
|
||||
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Handler>,
|
||||
_ body: @escaping (Channel) -> EventLoopFuture<Void>
|
||||
) -> EventLoopFuture<Handler.NegotiationResult> {
|
||||
self.initializeAndRegisterChannel(
|
||||
channel: channel,
|
||||
channelInitializer: channelInitializer,
|
||||
registration: { channel in
|
||||
channel.registerAndDoSynchronously(body)
|
||||
},
|
||||
postRegisterTransformation: { handler, eventLoop in
|
||||
handler.protocolNegotiationResult.flatMap { result in
|
||||
result.resolve(on: channel.eventLoop)
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// MARK: Async connect methods with arbitrary payload
|
||||
|
||||
extension ClientBootstrap {
|
||||
/// Specify the `host` and `port` to connect to for the TCP `Channel` that will be established.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - host: The host to connect to.
|
||||
/// - port: The port to connect to.
|
||||
/// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect`
|
||||
/// method.
|
||||
/// - Returns: The result of the channel initializer.
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
@_spi(AsyncChannel)
|
||||
public func connect<Output: Sendable>(
|
||||
host: String,
|
||||
port: Int,
|
||||
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output>
|
||||
) async throws -> Output {
|
||||
let eventLoop = self.group.next()
|
||||
return try await self.connect(
|
||||
host: host,
|
||||
port: port,
|
||||
eventLoop: eventLoop,
|
||||
channelInitializer: channelInitializer,
|
||||
postRegisterTransformation: { output, eventLoop in
|
||||
eventLoop.makeSucceededFuture(output)
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
/// Specify the `address` to connect to for the TCP `Channel` that will be established.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - address: The address to connect to.
|
||||
/// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect`
|
||||
/// method.
|
||||
/// - Returns: The result of the channel initializer.
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
@_spi(AsyncChannel)
|
||||
public func connect<Output: Sendable>(
|
||||
to address: SocketAddress,
|
||||
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output>
|
||||
) async throws -> Output {
|
||||
let eventLoop = self.group.next()
|
||||
return try await self.initializeAndRegisterNewChannel(
|
||||
eventLoop: eventLoop,
|
||||
protocolFamily: address.protocol,
|
||||
channelInitializer: channelInitializer,
|
||||
postRegisterTransformation: { output, eventLoop in
|
||||
eventLoop.makeSucceededFuture(output)
|
||||
}, { channel in
|
||||
return self.connect(freshChannel: channel, address: address)
|
||||
}).get().1
|
||||
}
|
||||
|
||||
/// Specify the `unixDomainSocket` path to connect to for the UDS `Channel` that will be established.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - unixDomainSocketPath: The _Unix domain socket_ path to connect to.
|
||||
/// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect`
|
||||
/// method.
|
||||
/// - Returns: The result of the channel initializer.
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
@_spi(AsyncChannel)
|
||||
public func connect<Output: Sendable>(
|
||||
unixDomainSocketPath: String,
|
||||
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output>
|
||||
) async throws -> Output {
|
||||
let address = try SocketAddress(unixDomainSocketPath: unixDomainSocketPath)
|
||||
return try await self.connect(
|
||||
to: address,
|
||||
channelInitializer: channelInitializer
|
||||
)
|
||||
}
|
||||
|
||||
/// Use the existing connected socket file descriptor.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - descriptor: The _Unix file descriptor_ representing the connected stream socket.
|
||||
/// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect`
|
||||
/// method.
|
||||
/// - Returns: The result of the channel initializer.
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
@_spi(AsyncChannel)
|
||||
public func withConnectedSocket<Output: Sendable>(
|
||||
_ socket: NIOBSDSocket.Handle,
|
||||
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output>
|
||||
) async throws -> Output {
|
||||
let eventLoop = group.next()
|
||||
return try await self.withConnectedSocket(
|
||||
eventLoop: eventLoop,
|
||||
socket: socket,
|
||||
channelInitializer: channelInitializer,
|
||||
postRegisterTransformation: { output, eventLoop in
|
||||
eventLoop.makeSucceededFuture(output)
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
func connect<ChannelInitializerResult, PostRegistrationTransformationResult>(
|
||||
host: String,
|
||||
port: Int,
|
||||
eventLoop: EventLoop,
|
||||
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<ChannelInitializerResult>,
|
||||
postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture<PostRegistrationTransformationResult>
|
||||
) async throws -> PostRegistrationTransformationResult {
|
||||
let resolver = self.resolver ?? GetaddrinfoResolver(
|
||||
loop: eventLoop,
|
||||
aiSocktype: .stream,
|
||||
aiProtocol: .tcp
|
||||
)
|
||||
|
||||
let connector = HappyEyeballsConnector<PostRegistrationTransformationResult>(
|
||||
resolver: resolver,
|
||||
loop: eventLoop,
|
||||
host: host,
|
||||
port: port,
|
||||
connectTimeout: self.connectTimeout
|
||||
) { eventLoop, protocolFamily in
|
||||
return self.initializeAndRegisterNewChannel(
|
||||
eventLoop: eventLoop,
|
||||
protocolFamily: protocolFamily,
|
||||
channelInitializer: channelInitializer,
|
||||
postRegisterTransformation: postRegisterTransformation
|
||||
) {
|
||||
$0.eventLoop.makeSucceededFuture(())
|
||||
}
|
||||
}
|
||||
return try await connector.resolveAndConnect().map { $0.1 }.get()
|
||||
}
|
||||
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
private func withConnectedSocket<ChannelInitializerResult, PostRegistrationTransformationResult>(
|
||||
eventLoop: EventLoop,
|
||||
socket: NIOBSDSocket.Handle,
|
||||
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<ChannelInitializerResult>,
|
||||
postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture<PostRegistrationTransformationResult>
|
||||
) async throws -> PostRegistrationTransformationResult {
|
||||
let channel = try SocketChannel(eventLoop: eventLoop as! SelectableEventLoop, socket: socket)
|
||||
|
||||
return try await self.initializeAndRegisterChannel(
|
||||
channel: channel,
|
||||
channelInitializer: channelInitializer,
|
||||
registration: { channel in
|
||||
let promise = eventLoop.makePromise(of: Void.self)
|
||||
channel.registerAlreadyConfigured0(promise: promise)
|
||||
return promise.futureResult
|
||||
},
|
||||
postRegisterTransformation: postRegisterTransformation
|
||||
).get()
|
||||
}
|
||||
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
private func initializeAndRegisterNewChannel<ChannelInitializerResult, PostRegistrationTransformationResult>(
|
||||
eventLoop: EventLoop,
|
||||
protocolFamily: NIOBSDSocket.ProtocolFamily,
|
||||
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<ChannelInitializerResult>,
|
||||
postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture<PostRegistrationTransformationResult>,
|
||||
_ body: @escaping (Channel) -> EventLoopFuture<Void>
|
||||
) -> EventLoopFuture<(Channel, PostRegistrationTransformationResult)> {
|
||||
let channel: SocketChannel
|
||||
do {
|
||||
channel = try self.makeSocketChannel(eventLoop: eventLoop, protocolFamily: protocolFamily)
|
||||
} catch {
|
||||
return eventLoop.makeFailedFuture(error)
|
||||
}
|
||||
return self.initializeAndRegisterChannel(
|
||||
channel: channel,
|
||||
channelInitializer: channelInitializer,
|
||||
registration: { channel in
|
||||
channel.registerAndDoSynchronously(body)
|
||||
},
|
||||
postRegisterTransformation: postRegisterTransformation
|
||||
).map { (channel, $0) }
|
||||
}
|
||||
|
||||
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
|
||||
private func initializeAndRegisterChannel<ChannelInitializerResult, PostRegistrationTransformationResult>(
|
||||
channel: SocketChannel,
|
||||
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<ChannelInitializerResult>,
|
||||
registration: @escaping @Sendable (Channel) -> EventLoopFuture<Void>,
|
||||
postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture<PostRegistrationTransformationResult>
|
||||
) -> EventLoopFuture<PostRegistrationTransformationResult> {
|
||||
let channelInitializer = { channel in
|
||||
return self.channelInitializer(channel)
|
||||
.flatMap { channelInitializer(channel) }
|
||||
}
|
||||
let channelOptions = self._channelOptions
|
||||
let eventLoop = channel.eventLoop
|
||||
let bindTarget = self.bindTarget
|
||||
|
||||
@inline(__always)
|
||||
@Sendable
|
||||
func setupChannel() -> EventLoopFuture<PostRegistrationTransformationResult> {
|
||||
eventLoop.assertInEventLoop()
|
||||
return channelOptions
|
||||
.applyAllChannelOptions(to: channel)
|
||||
.flatMap {
|
||||
if let bindTarget = bindTarget {
|
||||
return channel
|
||||
.bind(to: bindTarget)
|
||||
.flatMap {
|
||||
channelInitializer(channel)
|
||||
}
|
||||
} else {
|
||||
return channelInitializer(channel)
|
||||
}
|
||||
}.flatMap { (result: ChannelInitializerResult) in
|
||||
eventLoop.assertInEventLoop()
|
||||
return registration(channel).map {
|
||||
result
|
||||
}
|
||||
}.flatMap { (result: ChannelInitializerResult) -> EventLoopFuture<PostRegistrationTransformationResult> in
|
||||
postRegisterTransformation(result, eventLoop)
|
||||
}.flatMapError { error in
|
||||
eventLoop.assertInEventLoop()
|
||||
channel.close0(error: error, mode: .all, promise: nil)
|
||||
return channel.eventLoop.makeFailedFuture(error)
|
||||
}
|
||||
}
|
||||
|
||||
if eventLoop.inEventLoop {
|
||||
return setupChannel()
|
||||
} else {
|
||||
return eventLoop.flatSubmit {
|
||||
setupChannel()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@available(*, unavailable)
|
||||
extension ClientBootstrap: Sendable {}
|
||||
|
||||
|
@ -1839,5 +2339,23 @@ public final class NIOPipeBootstrap {
|
|||
}
|
||||
}
|
||||
|
||||
extension NIOProtocolNegotiationResult {
|
||||
func resolve(on eventLoop: EventLoop) -> EventLoopFuture<NegotiationResult> {
|
||||
Self.resolve(on: eventLoop, result: self)
|
||||
}
|
||||
|
||||
static func resolve(on eventLoop: EventLoop, result: Self) -> EventLoopFuture<NegotiationResult> {
|
||||
switch result {
|
||||
case .finished(let negotiationResult):
|
||||
return eventLoop.makeSucceededFuture(negotiationResult)
|
||||
|
||||
case .deferredResult(let future):
|
||||
return future.flatMap { result in
|
||||
return resolve(on: eventLoop, result: result)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@available(*, unavailable)
|
||||
extension NIOPipeBootstrap: Sendable {}
|
||||
|
|
|
@ -141,7 +141,10 @@ private struct TargetIterator: IteratorProtocol {
|
|||
///
|
||||
/// This class's private API is *not* thread-safe, and expects to be called from the
|
||||
/// event loop thread of the `loop` it is passed.
|
||||
internal class HappyEyeballsConnector {
|
||||
///
|
||||
/// The `ChannelBuilderResult` generic type can used to tunnel an arbitrary type
|
||||
/// from the `channelBuilderCallback` to the `resolve` methods return value.
|
||||
internal final class HappyEyeballsConnector<ChannelBuilderResult> {
|
||||
/// An enum for keeping track of connection state.
|
||||
private enum ConnectionState {
|
||||
/// Initial state. No work outstanding.
|
||||
|
@ -223,7 +226,7 @@ internal class HappyEyeballsConnector {
|
|||
/// than intended.
|
||||
///
|
||||
/// The channel builder callback takes an event loop and a protocol family as arguments.
|
||||
private let channelBuilderCallback: (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture<Channel>
|
||||
private let channelBuilderCallback: (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture<(Channel, ChannelBuilderResult)>
|
||||
|
||||
/// The amount of time to wait for an AAAA response to come in after a A response is
|
||||
/// received. By default this is 50ms.
|
||||
|
@ -250,7 +253,7 @@ internal class HappyEyeballsConnector {
|
|||
private var timeoutTask: Optional<Scheduled<Void>>
|
||||
|
||||
/// The promise that will hold the final connected channel.
|
||||
private let resolutionPromise: EventLoopPromise<Channel>
|
||||
private let resolutionPromise: EventLoopPromise<(Channel, ChannelBuilderResult)>
|
||||
|
||||
/// Our state machine state.
|
||||
private var state: ConnectionState
|
||||
|
@ -263,7 +266,7 @@ internal class HappyEyeballsConnector {
|
|||
///
|
||||
/// This is kept to ensure that we can clean up after ourselves once a connection succeeds,
|
||||
/// and throw away all pending connection attempts that are no longer needed.
|
||||
private var pendingConnections: [EventLoopFuture<Channel>] = []
|
||||
private var pendingConnections: [EventLoopFuture<(Channel, ChannelBuilderResult)>] = []
|
||||
|
||||
/// The number of DNS resolutions that have returned.
|
||||
///
|
||||
|
@ -274,6 +277,7 @@ internal class HappyEyeballsConnector {
|
|||
/// An object that holds any errors we encountered.
|
||||
private var error: NIOConnectionError
|
||||
|
||||
@inlinable
|
||||
init(resolver: Resolver,
|
||||
loop: EventLoop,
|
||||
host: String,
|
||||
|
@ -281,7 +285,7 @@ internal class HappyEyeballsConnector {
|
|||
connectTimeout: TimeAmount,
|
||||
resolutionDelay: TimeAmount = .milliseconds(50),
|
||||
connectionDelay: TimeAmount = .milliseconds(250),
|
||||
channelBuilderCallback: @escaping (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture<Channel>) {
|
||||
channelBuilderCallback: @escaping (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture<(Channel, ChannelBuilderResult)>) {
|
||||
self.resolver = resolver
|
||||
self.loop = loop
|
||||
self.host = host
|
||||
|
@ -303,10 +307,34 @@ internal class HappyEyeballsConnector {
|
|||
self.connectionDelay = connectionDelay
|
||||
}
|
||||
|
||||
@inlinable
|
||||
convenience init(
|
||||
resolver: Resolver,
|
||||
loop: EventLoop,
|
||||
host: String,
|
||||
port: Int,
|
||||
connectTimeout: TimeAmount,
|
||||
resolutionDelay: TimeAmount = .milliseconds(50),
|
||||
connectionDelay: TimeAmount = .milliseconds(250),
|
||||
channelBuilderCallback: @escaping (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture<Channel>
|
||||
) where ChannelBuilderResult == Void {
|
||||
self.init(
|
||||
resolver: resolver,
|
||||
loop: loop,
|
||||
host: host,
|
||||
port: port,
|
||||
connectTimeout: connectTimeout,
|
||||
resolutionDelay: resolutionDelay,
|
||||
connectionDelay: connectionDelay) { loop, protocolFamily in
|
||||
channelBuilderCallback(loop, protocolFamily).map { ($0, ()) }
|
||||
}
|
||||
}
|
||||
|
||||
/// Initiate a DNS resolution attempt using Happy Eyeballs 2.
|
||||
///
|
||||
/// returns: An `EventLoopFuture` that fires with a connected `Channel`.
|
||||
public func resolveAndConnect() -> EventLoopFuture<Channel> {
|
||||
@inlinable
|
||||
func resolveAndConnect() -> EventLoopFuture<(Channel, ChannelBuilderResult)> {
|
||||
// We dispatch ourselves onto the event loop, rather than do all the rest of our processing from outside it.
|
||||
self.loop.execute {
|
||||
self.timeoutTask = self.loop.scheduleTask(in: self.connectTimeout) { self.processInput(.connectTimeoutElapsed) }
|
||||
|
@ -315,6 +343,14 @@ internal class HappyEyeballsConnector {
|
|||
return resolutionPromise.futureResult
|
||||
}
|
||||
|
||||
/// Initiate a DNS resolution attempt using Happy Eyeballs 2.
|
||||
///
|
||||
/// returns: An `EventLoopFuture` that fires with a connected `Channel`.
|
||||
@inlinable
|
||||
func resolveAndConnect() -> EventLoopFuture<Channel> where ChannelBuilderResult == Void {
|
||||
self.resolveAndConnect().map { $0.0 }
|
||||
}
|
||||
|
||||
/// Spin the state machine.
|
||||
///
|
||||
/// - parameters:
|
||||
|
@ -540,11 +576,11 @@ internal class HappyEyeballsConnector {
|
|||
let channelFuture = channelBuilderCallback(self.loop, target.protocol)
|
||||
pendingConnections.append(channelFuture)
|
||||
|
||||
channelFuture.whenSuccess { channel in
|
||||
channelFuture.whenSuccess { (channel, result) in
|
||||
// If we are in the complete state then we want to abandon this channel. Otherwise, begin
|
||||
// connecting.
|
||||
if case .complete = self.state {
|
||||
self.pendingConnections.remove(element: channelFuture)
|
||||
self.pendingConnections.removeAll { $0 === channelFuture }
|
||||
channel.close(promise: nil)
|
||||
} else {
|
||||
channel.connect(to: target).map {
|
||||
|
@ -552,13 +588,13 @@ internal class HappyEyeballsConnector {
|
|||
// Otherwise, fire the channel connected event. Either way we don't want the channel future to
|
||||
// be in our list of pending connections, so we don't either double close or close the connection
|
||||
// we want to use.
|
||||
self.pendingConnections.remove(element: channelFuture)
|
||||
self.pendingConnections.removeAll { $0 === channelFuture }
|
||||
|
||||
if case .complete = self.state {
|
||||
channel.close(promise: nil)
|
||||
} else {
|
||||
self.processInput(.connectSuccess)
|
||||
self.resolutionPromise.succeed(channel)
|
||||
self.resolutionPromise.succeed((channel, result))
|
||||
}
|
||||
}.whenFailure { err in
|
||||
// The connection attempt failed. If we're in the complete state then there's nothing
|
||||
|
@ -567,7 +603,7 @@ internal class HappyEyeballsConnector {
|
|||
assert(self.pendingConnections.firstIndex { $0 === channelFuture } == nil, "failed but was still in pending connections")
|
||||
} else {
|
||||
self.error.connectionErrors.append(SingleConnectionFailure(target: target, error: err))
|
||||
self.pendingConnections.remove(element: channelFuture)
|
||||
self.pendingConnections.removeAll { $0 === channelFuture }
|
||||
self.processInput(.connectFailed)
|
||||
}
|
||||
}
|
||||
|
@ -575,7 +611,7 @@ internal class HappyEyeballsConnector {
|
|||
}
|
||||
channelFuture.whenFailure { error in
|
||||
self.error.connectionErrors.append(SingleConnectionFailure(target: target, error: error))
|
||||
self.pendingConnections.remove(element: channelFuture)
|
||||
self.pendingConnections.removeAll { $0 === channelFuture }
|
||||
self.processInput(.connectFailed)
|
||||
}
|
||||
}
|
||||
|
@ -607,7 +643,7 @@ internal class HappyEyeballsConnector {
|
|||
let connections = self.pendingConnections
|
||||
self.pendingConnections = []
|
||||
for connection in connections {
|
||||
connection.whenSuccess { channel in channel.close(promise: nil) }
|
||||
connection.whenSuccess { (channel, _) in channel.close(promise: nil) }
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ import NIOConcurrencyHelpers
|
|||
import XCTest
|
||||
@_spi(AsyncChannel) import NIOTLS
|
||||
|
||||
private final class LineDelimiterDecoder: ByteToMessageDecoder {
|
||||
private final class LineDelimiterCoder: ByteToMessageDecoder, MessageToByteEncoder {
|
||||
private let newLine = "\n".utf8.first!
|
||||
|
||||
typealias InboundIn = ByteBuffer
|
||||
|
@ -33,43 +33,100 @@ private final class LineDelimiterDecoder: ByteToMessageDecoder {
|
|||
}
|
||||
return .needMoreData
|
||||
}
|
||||
|
||||
func encode(data: ByteBuffer, out: inout ByteBuffer) throws {
|
||||
out.writeImmutableBuffer(data)
|
||||
out.writeString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
private final class TLSUserEventHandler: ChannelInboundHandler {
|
||||
private final class TLSUserEventHandler: ChannelInboundHandler, RemovableChannelHandler {
|
||||
typealias InboundIn = ByteBuffer
|
||||
typealias InboundOut = ByteBuffer
|
||||
enum ALPN: String {
|
||||
case string
|
||||
case byte
|
||||
case unknown
|
||||
}
|
||||
|
||||
private var proposedALPN: ALPN?
|
||||
|
||||
init(
|
||||
proposedALPN: ALPN? = nil
|
||||
) {
|
||||
self.proposedALPN = proposedALPN
|
||||
}
|
||||
|
||||
func handlerAdded(context: ChannelHandlerContext) {
|
||||
guard context.channel.isActive else {
|
||||
return
|
||||
}
|
||||
|
||||
if let proposedALPN = self.proposedALPN {
|
||||
self.proposedALPN = nil
|
||||
context.writeAndFlush(.init(ByteBuffer(string: "negotiate-alpn:\(proposedALPN.rawValue)")), promise: nil)
|
||||
}
|
||||
context.fireChannelActive()
|
||||
}
|
||||
|
||||
func channelActive(context: ChannelHandlerContext) {
|
||||
if let proposedALPN = self.proposedALPN {
|
||||
context.writeAndFlush(.init(ByteBuffer(string: "negotiate-alpn:\(proposedALPN.rawValue)")), promise: nil)
|
||||
}
|
||||
context.fireChannelActive()
|
||||
}
|
||||
|
||||
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||
let buffer = self.unwrapInboundIn(data)
|
||||
let alpn = String(buffer: buffer)
|
||||
let string = String(buffer: buffer)
|
||||
|
||||
if alpn.hasPrefix("alpn:") {
|
||||
context.fireUserInboundEventTriggered(TLSUserEvent.handshakeCompleted(negotiatedProtocol: String(alpn.dropFirst(5))))
|
||||
if string.hasPrefix("negotiate-alpn:") {
|
||||
let alpn = String(string.dropFirst(15))
|
||||
context.writeAndFlush(.init(ByteBuffer(string: "alpn:\(alpn)")), promise: nil)
|
||||
context.fireUserInboundEventTriggered(TLSUserEvent.handshakeCompleted(negotiatedProtocol: alpn))
|
||||
context.pipeline.removeHandler(self, promise: nil)
|
||||
} else if string.hasPrefix("alpn:") {
|
||||
context.fireUserInboundEventTriggered(TLSUserEvent.handshakeCompleted(negotiatedProtocol: String(string.dropFirst(5))))
|
||||
context.pipeline.removeHandler(self, promise: nil)
|
||||
} else {
|
||||
context.fireChannelRead(data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private final class ByteBufferToStringHandler: ChannelInboundHandler {
|
||||
private final class ByteBufferToStringHandler: ChannelDuplexHandler {
|
||||
typealias InboundIn = ByteBuffer
|
||||
typealias InboundOut = String
|
||||
typealias OutboundIn = String
|
||||
typealias OutboundOut = ByteBuffer
|
||||
|
||||
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||
let buffer = self.unwrapInboundIn(data)
|
||||
context.fireChannelRead(self.wrapInboundOut(String(buffer: buffer)))
|
||||
}
|
||||
|
||||
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
|
||||
let buffer = ByteBuffer(string: self.unwrapOutboundIn(data))
|
||||
context.write(.init(buffer), promise: promise)
|
||||
}
|
||||
}
|
||||
|
||||
private final class ByteBufferToByteHandler: ChannelInboundHandler {
|
||||
private final class ByteBufferToByteHandler: ChannelDuplexHandler {
|
||||
typealias InboundIn = ByteBuffer
|
||||
typealias InboundOut = UInt8
|
||||
typealias OutboundIn = UInt8
|
||||
typealias OutboundOut = ByteBuffer
|
||||
|
||||
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||
var buffer = self.unwrapInboundIn(data)
|
||||
let byte = buffer.readInteger(as: UInt8.self)!
|
||||
context.fireChannelRead(self.wrapInboundOut(byte))
|
||||
}
|
||||
|
||||
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
|
||||
let buffer = ByteBuffer(integer: self.unwrapOutboundIn(data))
|
||||
context.write(.init(buffer), promise: promise)
|
||||
}
|
||||
}
|
||||
|
||||
final class AsyncChannelBootstrapTests: XCTestCase {
|
||||
|
@ -94,7 +151,8 @@ final class AsyncChannelBootstrapTests: XCTestCase {
|
|||
.childChannelOption(ChannelOptions.autoRead, value: true)
|
||||
.childChannelInitializer { channel in
|
||||
channel.eventLoop.makeCompletedFuture {
|
||||
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterDecoder()))
|
||||
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder()))
|
||||
try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder()))
|
||||
try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler())
|
||||
}
|
||||
}
|
||||
|
@ -120,7 +178,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
|
|||
}
|
||||
|
||||
let stringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!)
|
||||
stringChannel.writeAndFlush(.init(ByteBuffer(string: "hello\n")), promise: nil)
|
||||
try await stringChannel.outboundWriter.write("hello")
|
||||
|
||||
await XCTAsyncAssertEqual(await iterator.next(), .string("hello"))
|
||||
|
||||
|
@ -138,7 +196,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
|
|||
.childChannelOption(ChannelOptions.autoRead, value: true)
|
||||
.childChannelInitializer { channel in
|
||||
channel.eventLoop.makeCompletedFuture {
|
||||
try self.makeProtocolNegotiationChildChannel(channel: channel)
|
||||
try self.configureProtocolNegotiationHandlers(channel: channel)
|
||||
}
|
||||
}
|
||||
.bind(
|
||||
|
@ -149,7 +207,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
|
|||
|
||||
try await withThrowingTaskGroup(of: Void.self) { group in
|
||||
let (stream, continuation) = AsyncStream<StringOrByte>.makeStream()
|
||||
var iterator = stream.makeAsyncIterator()
|
||||
var serverIterator = stream.makeAsyncIterator()
|
||||
|
||||
group.addTask {
|
||||
try await withThrowingTaskGroup(of: Void.self) { group in
|
||||
|
@ -170,26 +228,34 @@ final class AsyncChannelBootstrapTests: XCTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
let stringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!)
|
||||
let stringNegotiationResult = try await self.makeClientChannelWithProtocolNegotiation(
|
||||
eventLoopGroup: eventLoopGroup,
|
||||
port: channel.channel.localAddress!.port!,
|
||||
proposedALPN: .string
|
||||
)
|
||||
switch stringNegotiationResult {
|
||||
case .string(let stringChannel):
|
||||
// This is the actual content
|
||||
try await stringChannel.outboundWriter.write("hello")
|
||||
await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello"))
|
||||
case .byte:
|
||||
preconditionFailure()
|
||||
}
|
||||
|
||||
// This is for negotiating the protocol
|
||||
stringChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:string\n")), promise: nil)
|
||||
|
||||
// This is the actual content
|
||||
stringChannel.writeAndFlush(.init(ByteBuffer(string: "hello\n")), promise: nil)
|
||||
|
||||
await XCTAsyncAssertEqual(await iterator.next(), .string("hello"))
|
||||
|
||||
let byteChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!)
|
||||
|
||||
// This is for negotiating the protocol
|
||||
byteChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:byte\n")), promise: nil)
|
||||
|
||||
// This is the actual content
|
||||
byteChannel.write(.init(ByteBuffer(integer: UInt8(8))), promise: nil)
|
||||
byteChannel.writeAndFlush(.init(ByteBuffer(string: "\n")), promise: nil)
|
||||
|
||||
await XCTAsyncAssertEqual(await iterator.next(), .byte(8))
|
||||
let byteNegotiationResult = try await self.makeClientChannelWithProtocolNegotiation(
|
||||
eventLoopGroup: eventLoopGroup,
|
||||
port: channel.channel.localAddress!.port!,
|
||||
proposedALPN: .byte
|
||||
)
|
||||
switch byteNegotiationResult {
|
||||
case .string:
|
||||
preconditionFailure()
|
||||
case .byte(let byteChannel):
|
||||
// This is the actual content
|
||||
try await byteChannel.outboundWriter.write(UInt8(8))
|
||||
await XCTAsyncAssertEqual(await serverIterator.next(), .byte(8))
|
||||
}
|
||||
|
||||
group.cancelAll()
|
||||
}
|
||||
|
@ -205,7 +271,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
|
|||
.childChannelOption(ChannelOptions.autoRead, value: true)
|
||||
.childChannelInitializer { channel in
|
||||
channel.eventLoop.makeCompletedFuture {
|
||||
try self.makeNestedProtocolNegotiationChildChannel(channel: channel)
|
||||
try self.configureNestedProtocolNegotiationHandlers(channel: channel)
|
||||
}
|
||||
}
|
||||
.bind(
|
||||
|
@ -216,7 +282,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
|
|||
|
||||
try await withThrowingTaskGroup(of: Void.self) { group in
|
||||
let (stream, continuation) = AsyncStream<StringOrByte>.makeStream()
|
||||
var iterator = stream.makeAsyncIterator()
|
||||
var serverIterator = stream.makeAsyncIterator()
|
||||
|
||||
group.addTask {
|
||||
try await withThrowingTaskGroup(of: Void.self) { group in
|
||||
|
@ -237,59 +303,65 @@ final class AsyncChannelBootstrapTests: XCTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
let stringStringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!)
|
||||
let stringStringNegotiationResult = try await self.makeClientChannelWithNestedProtocolNegotiation(
|
||||
eventLoopGroup: eventLoopGroup,
|
||||
port: channel.channel.localAddress!.port!,
|
||||
proposedOuterALPN: .string,
|
||||
proposedInnerALPN: .string
|
||||
)
|
||||
switch stringStringNegotiationResult {
|
||||
case .string(let stringChannel):
|
||||
// This is the actual content
|
||||
try await stringChannel.outboundWriter.write("hello")
|
||||
await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello"))
|
||||
case .byte:
|
||||
preconditionFailure()
|
||||
}
|
||||
|
||||
// This is for negotiating the protocol
|
||||
stringStringChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:string\n")), promise: nil)
|
||||
let byteStringNegotiationResult = try await self.makeClientChannelWithNestedProtocolNegotiation(
|
||||
eventLoopGroup: eventLoopGroup,
|
||||
port: channel.channel.localAddress!.port!,
|
||||
proposedOuterALPN: .byte,
|
||||
proposedInnerALPN: .string
|
||||
)
|
||||
switch byteStringNegotiationResult {
|
||||
case .string(let stringChannel):
|
||||
// This is the actual content
|
||||
try await stringChannel.outboundWriter.write("hello")
|
||||
await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello"))
|
||||
case .byte:
|
||||
preconditionFailure()
|
||||
}
|
||||
|
||||
// This is for negotiating the nested protocol
|
||||
stringStringChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:string\n")), promise: nil)
|
||||
let byteByteNegotiationResult = try await self.makeClientChannelWithNestedProtocolNegotiation(
|
||||
eventLoopGroup: eventLoopGroup,
|
||||
port: channel.channel.localAddress!.port!,
|
||||
proposedOuterALPN: .byte,
|
||||
proposedInnerALPN: .byte
|
||||
)
|
||||
switch byteByteNegotiationResult {
|
||||
case .string:
|
||||
preconditionFailure()
|
||||
case .byte(let byteChannel):
|
||||
// This is the actual content
|
||||
try await byteChannel.outboundWriter.write(UInt8(8))
|
||||
await XCTAsyncAssertEqual(await serverIterator.next(), .byte(8))
|
||||
}
|
||||
|
||||
// This is the actual content
|
||||
stringStringChannel.writeAndFlush(.init(ByteBuffer(string: "hello\n")), promise: nil)
|
||||
|
||||
await XCTAsyncAssertEqual(await iterator.next(), .string("hello"))
|
||||
|
||||
let byteByteChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!)
|
||||
|
||||
// This is for negotiating the protocol
|
||||
byteByteChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:byte\n")), promise: nil)
|
||||
|
||||
// This is for negotiating the nested protocol
|
||||
byteByteChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:byte\n")), promise: nil)
|
||||
|
||||
// This is the actual content
|
||||
byteByteChannel.write(.init(ByteBuffer(integer: UInt8(8))), promise: nil)
|
||||
byteByteChannel.writeAndFlush(.init(ByteBuffer(string: "\n")), promise: nil)
|
||||
|
||||
await XCTAsyncAssertEqual(await iterator.next(), .byte(8))
|
||||
|
||||
let stringByteChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!)
|
||||
|
||||
// This is for negotiating the protocol
|
||||
stringByteChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:string\n")), promise: nil)
|
||||
|
||||
// This is for negotiating the nested protocol
|
||||
stringByteChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:byte\n")), promise: nil)
|
||||
|
||||
// This is the actual content
|
||||
stringByteChannel.write(.init(ByteBuffer(integer: UInt8(8))), promise: nil)
|
||||
stringByteChannel.writeAndFlush(.init(ByteBuffer(string: "\n")), promise: nil)
|
||||
|
||||
await XCTAsyncAssertEqual(await iterator.next(), .byte(8))
|
||||
|
||||
let byteStringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!)
|
||||
|
||||
// This is for negotiating the protocol
|
||||
byteStringChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:byte\n")), promise: nil)
|
||||
|
||||
// This is for negotiating the nested protocol
|
||||
byteStringChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:string\n")), promise: nil)
|
||||
|
||||
// This is the actual content
|
||||
byteStringChannel.writeAndFlush(.init(ByteBuffer(string: "hello\n")), promise: nil)
|
||||
|
||||
await XCTAsyncAssertEqual(await iterator.next(), .string("hello"))
|
||||
let stringByteNegotiationResult = try await self.makeClientChannelWithNestedProtocolNegotiation(
|
||||
eventLoopGroup: eventLoopGroup,
|
||||
port: channel.channel.localAddress!.port!,
|
||||
proposedOuterALPN: .string,
|
||||
proposedInnerALPN: .byte
|
||||
)
|
||||
switch stringByteNegotiationResult {
|
||||
case .string:
|
||||
preconditionFailure()
|
||||
case .byte(let byteChannel):
|
||||
// This is the actual content
|
||||
try await byteChannel.outboundWriter.write(UInt8(8))
|
||||
await XCTAsyncAssertEqual(await serverIterator.next(), .byte(8))
|
||||
}
|
||||
|
||||
group.cancelAll()
|
||||
}
|
||||
|
@ -328,7 +400,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
|
|||
.childChannelOption(ChannelOptions.autoRead, value: true)
|
||||
.childChannelInitializer { channel in
|
||||
channel.eventLoop.makeCompletedFuture {
|
||||
try self.makeProtocolNegotiationChildChannel(channel: channel)
|
||||
try self.configureProtocolNegotiationHandlers(channel: channel)
|
||||
}
|
||||
}
|
||||
.bind(
|
||||
|
@ -339,7 +411,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
|
|||
|
||||
try await withThrowingTaskGroup(of: Void.self) { group in
|
||||
let (stream, continuation) = AsyncStream<StringOrByte>.makeStream()
|
||||
var iterator = stream.makeAsyncIterator()
|
||||
var serverIterator = stream.makeAsyncIterator()
|
||||
|
||||
group.addTask {
|
||||
try await withThrowingTaskGroup(of: Void.self) { group in
|
||||
|
@ -360,21 +432,30 @@ final class AsyncChannelBootstrapTests: XCTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
let unknownChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!)
|
||||
await XCTAssertThrowsError(
|
||||
try await self.makeClientChannelWithProtocolNegotiation(
|
||||
eventLoopGroup: eventLoopGroup,
|
||||
port: channel.channel.localAddress!.port!,
|
||||
proposedALPN: .unknown
|
||||
)
|
||||
) { error in
|
||||
XCTAssertTrue(error is ProtocolNegotiationError)
|
||||
}
|
||||
|
||||
// This is for negotiating the protocol
|
||||
unknownChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:unknown\n")), promise: nil)
|
||||
|
||||
// Checking that we can still create new connections afterwards
|
||||
let stringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!)
|
||||
|
||||
// This is for negotiating the protocol
|
||||
stringChannel.writeAndFlush(.init(ByteBuffer(string: "alpn:string\n")), promise: nil)
|
||||
|
||||
// This is the actual content
|
||||
stringChannel.writeAndFlush(.init(ByteBuffer(string: "hello\n")), promise: nil)
|
||||
|
||||
await XCTAsyncAssertEqual(await iterator.next(), .string("hello"))
|
||||
// Let's check that we can still open a new connection
|
||||
let stringNegotiationResult = try await self.makeClientChannelWithProtocolNegotiation(
|
||||
eventLoopGroup: eventLoopGroup,
|
||||
port: channel.channel.localAddress!.port!,
|
||||
proposedALPN: .string
|
||||
)
|
||||
switch stringNegotiationResult {
|
||||
case .string(let stringChannel):
|
||||
// This is the actual content
|
||||
try await stringChannel.outboundWriter.write("hello")
|
||||
await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello"))
|
||||
case .byte:
|
||||
preconditionFailure()
|
||||
}
|
||||
|
||||
let failedInboundChannel = channels.withLockedValue { channels -> Channel in
|
||||
XCTAssertEqual(channels.count, 2)
|
||||
|
@ -391,55 +472,108 @@ final class AsyncChannelBootstrapTests: XCTestCase {
|
|||
|
||||
// MARK: - Test Helpers
|
||||
|
||||
private func makeClientChannel(eventLoopGroup: EventLoopGroup, port: Int) async throws -> Channel {
|
||||
private func makeClientChannel(eventLoopGroup: EventLoopGroup, port: Int) async throws -> NIOAsyncChannel<String, String> {
|
||||
return try await ClientBootstrap(group: eventLoopGroup)
|
||||
.channelInitializer { channel in
|
||||
channel.eventLoop.makeCompletedFuture {
|
||||
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterDecoder()))
|
||||
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder()))
|
||||
try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder()))
|
||||
try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler())
|
||||
}
|
||||
}
|
||||
.connect(to: .init(ipAddress: "127.0.0.1", port: port))
|
||||
.get()
|
||||
.connect(
|
||||
to: .init(ipAddress: "127.0.0.1", port: port),
|
||||
inboundType: String.self,
|
||||
outboundType: String.self
|
||||
)
|
||||
}
|
||||
|
||||
private func makeProtocolNegotiationChildChannel(channel: Channel) throws {
|
||||
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterDecoder()))
|
||||
try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler())
|
||||
try self.addTypedApplicationProtocolNegotiationHandler(to: channel)
|
||||
}
|
||||
|
||||
private func makeNestedProtocolNegotiationChildChannel(channel: Channel) throws {
|
||||
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterDecoder()))
|
||||
try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler())
|
||||
try channel.pipeline.syncOperations.addHandler(
|
||||
NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: channel.eventLoop) { alpnResult, channel in
|
||||
switch alpnResult {
|
||||
case .negotiated(let alpn):
|
||||
switch alpn {
|
||||
case "string":
|
||||
return channel.eventLoop.makeCompletedFuture {
|
||||
let negotiationFuture = try self.addTypedApplicationProtocolNegotiationHandler(to: channel)
|
||||
|
||||
return NIOProtocolNegotiationResult.deferredResult(negotiationFuture)
|
||||
}
|
||||
case "byte":
|
||||
return channel.eventLoop.makeCompletedFuture {
|
||||
let negotiationFuture = try self.addTypedApplicationProtocolNegotiationHandler(to: channel)
|
||||
|
||||
return NIOProtocolNegotiationResult.deferredResult(negotiationFuture)
|
||||
}
|
||||
default:
|
||||
return channel.eventLoop.makeFailedFuture(ProtocolNegotiationError())
|
||||
}
|
||||
case .fallback:
|
||||
return channel.eventLoop.makeFailedFuture(ProtocolNegotiationError())
|
||||
private func makeClientChannelWithProtocolNegotiation(
|
||||
eventLoopGroup: EventLoopGroup,
|
||||
port: Int,
|
||||
proposedALPN: TLSUserEventHandler.ALPN
|
||||
) async throws -> NegotiationResult {
|
||||
return try await ClientBootstrap(group: eventLoopGroup)
|
||||
.connect(
|
||||
to: .init(ipAddress: "127.0.0.1", port: port)
|
||||
) { channel in
|
||||
return channel.eventLoop.makeCompletedFuture {
|
||||
return try self.configureProtocolNegotiationHandlers(channel: channel, proposedALPN: proposedALPN)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func makeClientChannelWithNestedProtocolNegotiation(
|
||||
eventLoopGroup: EventLoopGroup,
|
||||
port: Int,
|
||||
proposedOuterALPN: TLSUserEventHandler.ALPN,
|
||||
proposedInnerALPN: TLSUserEventHandler.ALPN
|
||||
) async throws -> NegotiationResult {
|
||||
return try await ClientBootstrap(group: eventLoopGroup)
|
||||
.connect(
|
||||
to: .init(ipAddress: "127.0.0.1", port: port)
|
||||
) { channel in
|
||||
return channel.eventLoop.makeCompletedFuture {
|
||||
try self.configureNestedProtocolNegotiationHandlers(
|
||||
channel: channel,
|
||||
proposedOuterALPN: proposedOuterALPN,
|
||||
proposedInnerALPN: proposedInnerALPN
|
||||
)
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
@discardableResult
|
||||
private func addTypedApplicationProtocolNegotiationHandler(to channel: Channel) throws -> EventLoopFuture<NIOProtocolNegotiationResult<NegotiationResult>> {
|
||||
private func configureProtocolNegotiationHandlers(
|
||||
channel: Channel,
|
||||
proposedALPN: TLSUserEventHandler.ALPN? = nil
|
||||
) throws -> NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult> {
|
||||
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder()))
|
||||
try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder()))
|
||||
try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedALPN))
|
||||
return try self.addTypedApplicationProtocolNegotiationHandler(to: channel)
|
||||
}
|
||||
|
||||
@discardableResult
|
||||
private func configureNestedProtocolNegotiationHandlers(
|
||||
channel: Channel,
|
||||
proposedOuterALPN: TLSUserEventHandler.ALPN? = nil,
|
||||
proposedInnerALPN: TLSUserEventHandler.ALPN? = nil
|
||||
) throws -> NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult> {
|
||||
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder()))
|
||||
try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder()))
|
||||
try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedOuterALPN))
|
||||
let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: channel.eventLoop) { alpnResult, channel in
|
||||
switch alpnResult {
|
||||
case .negotiated(let alpn):
|
||||
switch alpn {
|
||||
case "string":
|
||||
return channel.eventLoop.makeCompletedFuture {
|
||||
try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedInnerALPN))
|
||||
let negotiationFuture = try self.addTypedApplicationProtocolNegotiationHandler(to: channel)
|
||||
|
||||
return NIOProtocolNegotiationResult.deferredResult(negotiationFuture.protocolNegotiationResult)
|
||||
}
|
||||
case "byte":
|
||||
return channel.eventLoop.makeCompletedFuture {
|
||||
try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedInnerALPN))
|
||||
let negotiationHandler = try self.addTypedApplicationProtocolNegotiationHandler(to: channel)
|
||||
|
||||
return NIOProtocolNegotiationResult.deferredResult(negotiationHandler.protocolNegotiationResult)
|
||||
}
|
||||
default:
|
||||
return channel.eventLoop.makeFailedFuture(ProtocolNegotiationError())
|
||||
}
|
||||
case .fallback:
|
||||
return channel.eventLoop.makeFailedFuture(ProtocolNegotiationError())
|
||||
}
|
||||
}
|
||||
try channel.pipeline.syncOperations.addHandler(negotiationHandler)
|
||||
return negotiationHandler
|
||||
}
|
||||
|
||||
@discardableResult
|
||||
private func addTypedApplicationProtocolNegotiationHandler(to channel: Channel) throws -> NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult> {
|
||||
let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult>(eventLoop: channel.eventLoop) { alpnResult, channel in
|
||||
switch alpnResult {
|
||||
case .negotiated(let alpn):
|
||||
|
@ -478,7 +612,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
|
|||
}
|
||||
|
||||
try channel.pipeline.syncOperations.addHandler(negotiationHandler)
|
||||
return negotiationHandler.protocolNegotiationResult
|
||||
return negotiationHandler
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -499,3 +633,19 @@ private func XCTAsyncAssertEqual<Element: Equatable>(_ lhs: @autoclosure () asyn
|
|||
let rhsResult = try await rhs()
|
||||
XCTAssertEqual(lhsResult, rhsResult, file: file, line: line)
|
||||
}
|
||||
|
||||
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
|
||||
private func XCTAsyncAssertThrowsError<T>(
|
||||
_ expression: @autoclosure () async throws -> T,
|
||||
_ message: @autoclosure () -> String = "",
|
||||
file: StaticString = #filePath,
|
||||
line: UInt = #line,
|
||||
_ errorHandler: (_ error: Error) -> Void = { _ in }
|
||||
) async {
|
||||
do {
|
||||
_ = try await expression()
|
||||
XCTFail(message(), file: file, line: line)
|
||||
} catch {
|
||||
errorHandler(error)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -226,10 +226,12 @@ private func defaultChannelBuilder(loop: EventLoop, family: NIOBSDSocket.Protoco
|
|||
return loop.makeSucceededFuture(channel)
|
||||
}
|
||||
|
||||
private func buildEyeballer(host: String,
|
||||
port: Int,
|
||||
connectTimeout: TimeAmount = .seconds(10),
|
||||
channelBuilderCallback: @escaping (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture<Channel> = defaultChannelBuilder) -> (eyeballer: HappyEyeballsConnector, resolver: DummyResolver, loop: EmbeddedEventLoop) {
|
||||
private func buildEyeballer(
|
||||
host: String,
|
||||
port: Int,
|
||||
connectTimeout: TimeAmount = .seconds(10),
|
||||
channelBuilderCallback: @escaping (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture<Channel> = defaultChannelBuilder
|
||||
) -> (eyeballer: HappyEyeballsConnector<Void>, resolver: DummyResolver, loop: EmbeddedEventLoop) {
|
||||
let loop = EmbeddedEventLoop()
|
||||
let resolver = DummyResolver(loop: loop)
|
||||
let eyeballer = HappyEyeballsConnector(resolver: resolver,
|
||||
|
|
Loading…
Reference in New Issue