diff --git a/Sources/NIO/Channel.swift b/Sources/NIO/Channel.swift index da414047..3b706c7d 100644 --- a/Sources/NIO/Channel.swift +++ b/Sources/NIO/Channel.swift @@ -531,6 +531,9 @@ private extension FileRegion { /// - note: All operations on `SocketChannel` are thread-safe. final class SocketChannel: BaseSocketChannel { + private var connectTimeout = TimeAmount.seconds(10) + private var connectTimeoutScheduled: Scheduled? + init(eventLoop: SelectableEventLoop, protocolFamily: Int32) throws { let socket = try Socket(protocolFamily: protocolFamily) do { @@ -542,6 +545,26 @@ final class SocketChannel: BaseSocketChannel { try super.init(socket: socket, eventLoop: eventLoop) } + override fileprivate func setOption0(option: T, value: T.OptionType) throws { + assert(eventLoop.inEventLoop) + switch option { + case _ as ConnectTimeoutOption: + connectTimeout = value as! TimeAmount + default: + try super.setOption0(option: option, value: value) + } + } + + override fileprivate func getOption0(option: T) throws -> T.OptionType { + assert(eventLoop.inEventLoop) + switch option { + case _ as ConnectTimeoutOption: + return connectTimeout as! T.OptionType + default: + return try super.getOption0(option: option) + } + } + public override func registrationFor(interested: IOEvent) -> NIORegistration { return .socketChannel(self, interested) } @@ -632,13 +655,36 @@ final class SocketChannel: BaseSocketChannel { } override fileprivate func connectSocket(to address: SocketAddress) throws -> Bool { - return try self.socket.connect(to: address) + if try self.socket.connect(to: address) { + return true + } + let timeout = connectTimeout + connectTimeoutScheduled = eventLoop.scheduleTask(in: timeout) { () -> (Void) in + if self.pendingConnect != nil { + // The connection was still not established, close the Channel which will also fail the pending promise. + self.close0(error: ChannelError.connectTimeout(timeout), promise: nil) + } + } + return false } override fileprivate func finishConnectSocket() throws { + if let scheduled = connectTimeoutScheduled { + // Connection established so cancel the previous scheduled timeout. + connectTimeoutScheduled = nil + scheduled.cancel() + } try self.socket.finishConnect() becomeActive0() } + + override func close0(error: Error, promise: EventLoopPromise?) { + if let timeout = connectTimeoutScheduled { + connectTimeoutScheduled = nil + timeout.cancel() + } + super.close0(error: error, promise: promise) + } } /// A `Channel` for a server socket. @@ -912,7 +958,7 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { private let pendingWrites: PendingWritesManager fileprivate var readPending = false private var neverRegistered = true - private var pendingConnect: EventLoopPromise? + fileprivate var pendingConnect: EventLoopPromise? private let closePromise: EventLoopPromise private var active: Atomic = Atomic(value: false) public var isActive: Bool { @@ -1152,7 +1198,7 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { } } - public final func close0(error: Error, promise: EventLoopPromise?) { + public func close0(error: Error, promise: EventLoopPromise?) { assert(eventLoop.inEventLoop) if closed { @@ -1421,7 +1467,10 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { public enum ChannelError: Error { /// Tried to connect on a `Channel` that is already connecting. case connectPending - + + /// Connect operation timed out + case connectTimeout(TimeAmount) + /// Unsupported operation triggered on a `Channel`. For example `connect` on a `ServerSocketChannel`. case operationUnsupported @@ -1435,3 +1484,25 @@ public enum ChannelError: Error { /// open locally. case eof } + +extension ChannelError: Equatable { + public static func ==(lhs: ChannelError, rhs: ChannelError) -> Bool { + switch (lhs, rhs) { + case (.connectPending, .connectPending): + return true + case (.connectTimeout(_), .connectTimeout(_)): + return true + case (.operationUnsupported, .operationUnsupported): + return true + case (.ioOnClosedChannel, .ioOnClosedChannel): + return true + case (.alreadyClosed, .alreadyClosed): + return true + case (.eof, .eof): + return true + default: + return false + } + } +} + diff --git a/Sources/NIO/ChannelOption.swift b/Sources/NIO/ChannelOption.swift index a96ab402..d5b63d50 100644 --- a/Sources/NIO/ChannelOption.swift +++ b/Sources/NIO/ChannelOption.swift @@ -139,6 +139,15 @@ public enum WriteBufferWaterMarkOption: ChannelOption { case const(()) } +/// `ConnectTimeoutOption` allows to configure the `TimeAmount` after which a connect will fail if it was not established in the meantime. +public enum ConnectTimeoutOption: ChannelOption { + public typealias AssociatedValueType = () + public typealias OptionType = TimeAmount + + case const(()) +} + + /// Provides `ChannelOption`s to be used with a `Channel`, `Bootstrap` or `ServerBootstrap`. public struct ChannelOptions { /// - seealso: `SocketOption`. @@ -164,4 +173,7 @@ public struct ChannelOptions { /// - seealso: `WriteBufferWaterMarkOption`. public static let writeBufferWaterMark = WriteBufferWaterMarkOption.const(()) + + /// - seealso: `ConnectTimeoutOption`. + public static let connectTimeout = ConnectTimeoutOption.const(()) } diff --git a/Tests/NIOTests/ChannelTests+XCTest.swift b/Tests/NIOTests/ChannelTests+XCTest.swift index a9686a53..4cd7cdf1 100644 --- a/Tests/NIOTests/ChannelTests+XCTest.swift +++ b/Tests/NIOTests/ChannelTests+XCTest.swift @@ -46,6 +46,7 @@ extension ChannelTests { ("testPendingWritesCloseDuringVectorWrite", testPendingWritesCloseDuringVectorWrite), ("testPendingWritesMoreThanWritevIOVectorLimit", testPendingWritesMoreThanWritevIOVectorLimit), ("testPendingWritesIsHappyWhenSendfileReturnsWouldBlockButWroteFully", testPendingWritesIsHappyWhenSendfileReturnsWouldBlockButWroteFully), + ("testConnectTimeout", testConnectTimeout), ] } } diff --git a/Tests/NIOTests/ChannelTests.swift b/Tests/NIOTests/ChannelTests.swift index 18deb265..3033efae 100644 --- a/Tests/NIOTests/ChannelTests.swift +++ b/Tests/NIOTests/ChannelTests.swift @@ -1010,4 +1010,25 @@ public class ChannelTests: XCTestCase { XCTAssertEqual(.writtenCompletely, result) } } + + func testConnectTimeout() throws { + let group = MultiThreadedEventLoopGroup(numThreads: 1) + defer { + try! group.syncShutdownGracefully() + } + + do { + // This must throw as 198.51.100.254 is reserved for documentation only + _ = try ClientBootstrap(group: group) + .channelOption(option: ChannelOptions.connectTimeout, value: .milliseconds(10)) + .connect(to: SocketAddress.newAddressResolving(host: "198.51.100.254", port: 65535)).wait() + XCTFail() + } catch let err as ChannelError { + if case .connectTimeout(_) = err { + // expected, sadly there is no "if not case" + } else { + XCTFail() + } + } + } }