Allow to specify a connect timeout

This commit is contained in:
Norman Maurer 2017-12-20 15:26:39 +01:00
parent 9015d08310
commit 8b9a31536e
4 changed files with 109 additions and 4 deletions

View File

@ -531,6 +531,9 @@ private extension FileRegion {
/// - note: All operations on `SocketChannel` are thread-safe.
final class SocketChannel: BaseSocketChannel<Socket> {
private var connectTimeout = TimeAmount.seconds(10)
private var connectTimeoutScheduled: Scheduled<Void>?
init(eventLoop: SelectableEventLoop, protocolFamily: Int32) throws {
let socket = try Socket(protocolFamily: protocolFamily)
do {
@ -542,6 +545,26 @@ final class SocketChannel: BaseSocketChannel<Socket> {
try super.init(socket: socket, eventLoop: eventLoop)
}
override fileprivate func setOption0<T: ChannelOption>(option: T, value: T.OptionType) throws {
assert(eventLoop.inEventLoop)
switch option {
case _ as ConnectTimeoutOption:
connectTimeout = value as! TimeAmount
default:
try super.setOption0(option: option, value: value)
}
}
override fileprivate func getOption0<T: ChannelOption>(option: T) throws -> T.OptionType {
assert(eventLoop.inEventLoop)
switch option {
case _ as ConnectTimeoutOption:
return connectTimeout as! T.OptionType
default:
return try super.getOption0(option: option)
}
}
public override func registrationFor(interested: IOEvent) -> NIORegistration {
return .socketChannel(self, interested)
}
@ -632,13 +655,36 @@ final class SocketChannel: BaseSocketChannel<Socket> {
}
override fileprivate func connectSocket(to address: SocketAddress) throws -> Bool {
return try self.socket.connect(to: address)
if try self.socket.connect(to: address) {
return true
}
let timeout = connectTimeout
connectTimeoutScheduled = eventLoop.scheduleTask(in: timeout) { () -> (Void) in
if self.pendingConnect != nil {
// The connection was still not established, close the Channel which will also fail the pending promise.
self.close0(error: ChannelError.connectTimeout(timeout), promise: nil)
}
}
return false
}
override fileprivate func finishConnectSocket() throws {
if let scheduled = connectTimeoutScheduled {
// Connection established so cancel the previous scheduled timeout.
connectTimeoutScheduled = nil
scheduled.cancel()
}
try self.socket.finishConnect()
becomeActive0()
}
override func close0(error: Error, promise: EventLoopPromise<Void>?) {
if let timeout = connectTimeoutScheduled {
connectTimeoutScheduled = nil
timeout.cancel()
}
super.close0(error: error, promise: promise)
}
}
/// A `Channel` for a server socket.
@ -912,7 +958,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
private let pendingWrites: PendingWritesManager
fileprivate var readPending = false
private var neverRegistered = true
private var pendingConnect: EventLoopPromise<Void>?
fileprivate var pendingConnect: EventLoopPromise<Void>?
private let closePromise: EventLoopPromise<Void>
private var active: Atomic<Bool> = Atomic(value: false)
public var isActive: Bool {
@ -1152,7 +1198,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
}
}
public final func close0(error: Error, promise: EventLoopPromise<Void>?) {
public func close0(error: Error, promise: EventLoopPromise<Void>?) {
assert(eventLoop.inEventLoop)
if closed {
@ -1421,7 +1467,10 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
public enum ChannelError: Error {
/// Tried to connect on a `Channel` that is already connecting.
case connectPending
/// Connect operation timed out
case connectTimeout(TimeAmount)
/// Unsupported operation triggered on a `Channel`. For example `connect` on a `ServerSocketChannel`.
case operationUnsupported
@ -1435,3 +1484,25 @@ public enum ChannelError: Error {
/// open locally.
case eof
}
extension ChannelError: Equatable {
public static func ==(lhs: ChannelError, rhs: ChannelError) -> Bool {
switch (lhs, rhs) {
case (.connectPending, .connectPending):
return true
case (.connectTimeout(_), .connectTimeout(_)):
return true
case (.operationUnsupported, .operationUnsupported):
return true
case (.ioOnClosedChannel, .ioOnClosedChannel):
return true
case (.alreadyClosed, .alreadyClosed):
return true
case (.eof, .eof):
return true
default:
return false
}
}
}

View File

@ -139,6 +139,15 @@ public enum WriteBufferWaterMarkOption: ChannelOption {
case const(())
}
/// `ConnectTimeoutOption` allows to configure the `TimeAmount` after which a connect will fail if it was not established in the meantime.
public enum ConnectTimeoutOption: ChannelOption {
public typealias AssociatedValueType = ()
public typealias OptionType = TimeAmount
case const(())
}
/// Provides `ChannelOption`s to be used with a `Channel`, `Bootstrap` or `ServerBootstrap`.
public struct ChannelOptions {
/// - seealso: `SocketOption`.
@ -164,4 +173,7 @@ public struct ChannelOptions {
/// - seealso: `WriteBufferWaterMarkOption`.
public static let writeBufferWaterMark = WriteBufferWaterMarkOption.const(())
/// - seealso: `ConnectTimeoutOption`.
public static let connectTimeout = ConnectTimeoutOption.const(())
}

View File

@ -46,6 +46,7 @@ extension ChannelTests {
("testPendingWritesCloseDuringVectorWrite", testPendingWritesCloseDuringVectorWrite),
("testPendingWritesMoreThanWritevIOVectorLimit", testPendingWritesMoreThanWritevIOVectorLimit),
("testPendingWritesIsHappyWhenSendfileReturnsWouldBlockButWroteFully", testPendingWritesIsHappyWhenSendfileReturnsWouldBlockButWroteFully),
("testConnectTimeout", testConnectTimeout),
]
}
}

View File

@ -1010,4 +1010,25 @@ public class ChannelTests: XCTestCase {
XCTAssertEqual(.writtenCompletely, result)
}
}
func testConnectTimeout() throws {
let group = MultiThreadedEventLoopGroup(numThreads: 1)
defer {
try! group.syncShutdownGracefully()
}
do {
// This must throw as 198.51.100.254 is reserved for documentation only
_ = try ClientBootstrap(group: group)
.channelOption(option: ChannelOptions.connectTimeout, value: .milliseconds(10))
.connect(to: SocketAddress.newAddressResolving(host: "198.51.100.254", port: 65535)).wait()
XCTFail()
} catch let err as ChannelError {
if case .connectTimeout(_) = err {
// expected, sadly there is no "if not case"
} else {
XCTFail()
}
}
}
}