diff --git a/Sources/NIO/BaseSocket.swift b/Sources/NIO/BaseSocket.swift index 6c8e00a2..786c5131 100644 --- a/Sources/NIO/BaseSocket.swift +++ b/Sources/NIO/BaseSocket.swift @@ -394,7 +394,7 @@ class BaseSocket: Selectable { /// After the socket was closed all other methods will throw an `IOError` when called. /// /// - throws: An `IOError` if the operation failed. - final func close() throws { + func close() throws { try withUnsafeFileDescriptor { fd in try Posix.close(descriptor: fd) } diff --git a/Sources/NIO/BaseSocketChannel.swift b/Sources/NIO/BaseSocketChannel.swift index 0e19230f..458e16f7 100644 --- a/Sources/NIO/BaseSocketChannel.swift +++ b/Sources/NIO/BaseSocketChannel.swift @@ -664,11 +664,18 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { return } + // === BEGIN: No user callouts === + + // this is to register all error callouts as all the callouts must happen after we transition out state + var errorCallouts: [(ChannelPipeline) -> Void] = [] + self.interestedEvent = .reset do { try selectableEventLoop.deregister(channel: self) } catch let err { - pipeline.fireErrorCaught0(error: err) + errorCallouts.append { pipeline in + pipeline.fireErrorCaught0(error: err) + } } let p: EventLoopPromise? @@ -676,21 +683,32 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { try socket.close() p = promise } catch { - promise?.fail(error: error) - // Set p to nil as we want to ensure we pass nil to becomeInactive0(...) so we not try to notify the promise again. + errorCallouts.append { (_: ChannelPipeline) in + promise?.fail(error: error) + // Set p to nil as we want to ensure we pass nil to becomeInactive0(...) so we not try to notify the promise again. + } p = nil } // Transition our internal state. let callouts = self.lifecycleManager.close(promise: p) + // === END: No user callouts (now that our state is reconciled, we can call out to user code.) === + + // this must be the first to call out as it transitions the PendingWritesManager into the closed state + // and we assert elsewhere that the PendingWritesManager has the same idea of 'open' as we have in here. + self.cancelWritesOnClose(error: error) + + // this should be a no-op as we shouldn't have any + errorCallouts.forEach { + $0(self.pipeline) + } + if let connectPromise = self.pendingConnect { self.pendingConnect = nil connectPromise.fail(error: error) } - // Now that our state is sensible, we can call out to user code. - self.cancelWritesOnClose(error: error) callouts(self.pipeline) eventLoop.execute { diff --git a/Tests/NIOTests/ChannelTests+XCTest.swift b/Tests/NIOTests/ChannelTests+XCTest.swift index beaab11a..707d3fa3 100644 --- a/Tests/NIOTests/ChannelTests+XCTest.swift +++ b/Tests/NIOTests/ChannelTests+XCTest.swift @@ -68,6 +68,7 @@ extension ChannelTests { ("testSocketFailingAsyncCorrectlyTearsTheChannelDownAndDoesntCrash", testSocketFailingAsyncCorrectlyTearsTheChannelDownAndDoesntCrash), ("testSocketErroringSynchronouslyCorrectlyTearsTheChannelDown", testSocketErroringSynchronouslyCorrectlyTearsTheChannelDown), ("testConnectWithECONNREFUSEDGetsTheRightError", testConnectWithECONNREFUSEDGetsTheRightError), + ("testCloseInUnregister", testCloseInUnregister), ] } } diff --git a/Tests/NIOTests/ChannelTests.swift b/Tests/NIOTests/ChannelTests.swift index 2400ee45..50471e57 100644 --- a/Tests/NIOTests/ChannelTests.swift +++ b/Tests/NIOTests/ChannelTests.swift @@ -2263,6 +2263,60 @@ public class ChannelTests: XCTestCase { } } + func testCloseInUnregister() throws { + enum DummyError: Error { case dummy } + class SocketFailingClose: Socket { + init() throws { + try super.init(protocolFamily: PF_INET, type: Posix.SOCK_STREAM, setNonBlocking: true) + } + + override func close() throws { + _ = try? super.close() + throw DummyError.dummy + } + } + + let group = MultiThreadedEventLoopGroup(numThreads: 2) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + let sc = try SocketChannel(socket: SocketFailingClose(), eventLoop: group.next() as! SelectableEventLoop) + + let serverChannel = try ServerBootstrap(group: group.next()) + .bind(host: "127.0.0.1", port: 0) + .wait() + defer { + XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed()) + } + + XCTAssertNoThrow(try sc.eventLoop.submit { + sc.register().then { + sc.connect(to: serverChannel.localAddress!) + } + }.wait().wait() as Void) + + do { + try sc.eventLoop.submit { () -> EventLoopFuture in + let p: EventLoopPromise = sc.eventLoop.newPromise() + // this callback must be attached before we call the close + let f = p.futureResult.map { + XCTFail("shouldn't be reached") + }.thenIfError { err in + XCTAssertNotNil(err as? DummyError) + return sc.close() + } + sc.close(promise: p) + return f + }.wait().wait() + XCTFail("shouldn't be reached") + } catch ChannelError.alreadyClosed { + // ok + } catch { + XCTFail("wrong error: \(error)") + } + + } + } fileprivate class VerifyConnectionFailureHandler: ChannelInboundHandler {