diff --git a/Sources/NIO/Channel.swift b/Sources/NIO/Channel.swift index 9881156f..d12a5378 100644 --- a/Sources/NIO/Channel.swift +++ b/Sources/NIO/Channel.swift @@ -1141,13 +1141,15 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { } catch let err { promise?.fail(error: err) } - if !neverRegistered { - pipeline.fireChannelUnregistered0() - } - pipeline.fireChannelInactive0() // Fail all pending writes and so ensure all pending promises are notified self.pendingWrites.failAll(error: error) + + if !neverRegistered { + pipeline.fireChannelUnregistered0() + } + + pipeline.fireChannelInactive0() eventLoop.execute { // ensure this is executed in a delayed fashion as the users code may still traverse the pipeline diff --git a/Sources/NIO/ChannelPipeline.swift b/Sources/NIO/ChannelPipeline.swift index e06698f1..066ae776 100644 --- a/Sources/NIO/ChannelPipeline.swift +++ b/Sources/NIO/ChannelPipeline.swift @@ -548,35 +548,67 @@ private final class HeadChannelHandler : _ChannelOutboundHandler { private init() { } func register(ctx: ChannelHandlerContext, promise: Promise?) { - ctx.channel!._unsafe.register0(promise: promise) + if let channel = ctx.channel { + channel._unsafe.register0(promise: promise) + } else { + promise?.fail(error: ChannelError.ioOnClosedChannel) + } } func bind(ctx: ChannelHandlerContext, to address: SocketAddress, promise: Promise?) { - ctx.channel!._unsafe.bind0(to: address, promise: promise) + if let channel = ctx.channel { + channel._unsafe.bind0(to: address, promise: promise) + } else { + promise?.fail(error: ChannelError.ioOnClosedChannel) + } } func connect(ctx: ChannelHandlerContext, to address: SocketAddress, promise: Promise?) { - ctx.channel!._unsafe.connect0(to: address, promise: promise) + if let channel = ctx.channel { + channel._unsafe.connect0(to: address, promise: promise) + } else { + promise?.fail(error: ChannelError.ioOnClosedChannel) + } } func write(ctx: ChannelHandlerContext, data: IOData, promise: Promise?) { - ctx.channel!._unsafe.write0(data: data, promise: promise) + if let channel = ctx.channel { + channel._unsafe.write0(data: data, promise: promise) + } else { + promise?.fail(error: ChannelError.ioOnClosedChannel) + } } func flush(ctx: ChannelHandlerContext, promise: Promise?) { - ctx.channel!._unsafe.flush0(promise: promise) + if let channel = ctx.channel { + channel._unsafe.flush0(promise: promise) + } else { + promise?.fail(error: ChannelError.ioOnClosedChannel) + } } func close(ctx: ChannelHandlerContext, promise: Promise?) { - ctx.channel!._unsafe.close0(error: ChannelError.alreadyClosed, promise: promise) + if let channel = ctx.channel { + channel._unsafe.close0(error: ChannelError.alreadyClosed, promise: promise) + } else { + promise?.fail(error: ChannelError.alreadyClosed) + } } func read(ctx: ChannelHandlerContext, promise: Promise?) { - ctx.channel!._unsafe.read0(promise: promise) + if let channel = ctx.channel { + channel._unsafe.read0(promise: promise) + } else { + promise?.fail(error: ChannelError.ioOnClosedChannel) + } } func triggerUserOutboundEvent(ctx: ChannelHandlerContext, event: Any, promise: Promise?) { - ctx.channel!._unsafe.triggerUserOutboundEvent0(event: event, promise: promise) + if let channel = ctx.channel { + channel._unsafe.triggerUserOutboundEvent0(event: event, promise: promise) + } else { + promise?.fail(error: ChannelError.ioOnClosedChannel) + } } } diff --git a/Tests/NIOTests/EchoServerClientTest.swift b/Tests/NIOTests/EchoServerClientTest.swift index 8b447f11..2b97a8bc 100644 --- a/Tests/NIOTests/EchoServerClientTest.swift +++ b/Tests/NIOTests/EchoServerClientTest.swift @@ -14,6 +14,7 @@ import Foundation import XCTest +import ConcurrencyHelpers @testable import NIO class EchoServerClientTest : XCTestCase { @@ -180,4 +181,98 @@ class EchoServerClientTest : XCTestCase { ctx.flush(promise: nil) } } + + private final class CloseInInActiveAndUnregisteredChannelHandler: ChannelInboundHandler { + typealias InboundIn = Never + let alreadyClosedInChannelInactive = Atomic(value: false) + let alreadyClosedInChannelUnregistered = Atomic(value: false) + let channelUnregisteredPromise: Promise<()> + let channelInactivePromise: Promise<()> + + public init(channelUnregisteredPromise: Promise<()>, + channelInactivePromise: Promise<()>) { + self.channelUnregisteredPromise = channelUnregisteredPromise + self.channelInactivePromise = channelInactivePromise + } + + public func channelActive(ctx: ChannelHandlerContext) { + ctx.close().whenComplete { val in + switch val { + case .success(()): + () + default: + XCTFail("bad, initial close failed") + } + } + } + + public func channelInactive(ctx: ChannelHandlerContext) { + if alreadyClosedInChannelInactive.compareAndExchange(expected: false, desired: true) { + ctx.close().whenComplete { val in + switch val { + case .failure(ChannelError.alreadyClosed): + () + case .success(()): + XCTFail("unexpected success") + case .failure(let e): + XCTFail("unexpected error: \(e)") + } + self.channelInactivePromise.succeed(result: ()) + } + } + } + + public func channelUnregistered(ctx: ChannelHandlerContext) { + if alreadyClosedInChannelUnregistered.compareAndExchange(expected: false, desired: true) { + ctx.close().whenComplete { val in + switch val { + case .failure(ChannelError.alreadyClosed): + () + case .success(()): + XCTFail("unexpected success") + case .failure(let e): + XCTFail("unexpected error: \(e)") + } + self.channelUnregisteredPromise.succeed(result: ()) + } + } + } + } + + func testCloseInInactive() throws { + + let group = try MultiThreadedEventLoopGroup(numThreads: 1) + defer { + _ = try? group.close() + } + + let inactivePromise = group.next().newPromise() as Promise<()> + let unregistredPromise = group.next().newPromise() as Promise<()> + let handler = CloseInInActiveAndUnregisteredChannelHandler(channelUnregisteredPromise: unregistredPromise, + channelInactivePromise: inactivePromise) + + let serverChannel = try ServerBootstrap(group: group) + .option(option: ChannelOptions.Socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + + // Set the handlers that are appled to the accepted Channels + .handler(childHandler: ChannelInitializer(initChannel: { channel in + // Ensure we not read faster then we can write by adding the BackPressureHandler into the pipeline. + return channel.pipeline.add(handler: handler) + })).bind(to: "127.0.0.1", on: 0).wait() + + defer { + _ = serverChannel.close() + } + + let clientChannel = try ClientBootstrap(group: group).connect(to: serverChannel.localAddress!).wait() + + defer { + _ = clientChannel.close() + } + + _ = try inactivePromise.futureResult.and(unregistredPromise.futureResult).wait() + + XCTAssertTrue(handler.alreadyClosedInChannelInactive.load()) + XCTAssertTrue(handler.alreadyClosedInChannelUnregistered.load()) + } }