From 4d8cf1ff22f4032260107298b46e346577b9fab2 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Thu, 28 Dec 2017 13:40:31 +0100 Subject: [PATCH] Add support for closing only the output / input side of the Channel. --- Sources/NIO/Channel.swift | 173 +++++++++++++++---- Sources/NIO/ChannelHandler.swift | 6 +- Sources/NIO/ChannelInvoker.swift | 25 ++- Sources/NIO/ChannelOption.swift | 14 ++ Sources/NIO/ChannelPipeline.swift | 37 +++-- Sources/NIO/Embedded.swift | 2 +- Sources/NIO/Socket.swift | 7 + Sources/NIO/System.swift | 30 +++- Sources/NIOOpenSSL/OpenSSLHandler.swift | 8 +- Tests/NIOTests/ChannelTests+XCTest.swift | 3 + Tests/NIOTests/ChannelTests.swift | 194 +++++++++++++++++++++- Tests/NIOTests/EchoServerClientTest.swift | 33 ---- Tests/NIOTests/FileRegionTest.swift | 34 ---- Tests/NIOTests/TestUtils.swift | 50 ++++++ 14 files changed, 493 insertions(+), 123 deletions(-) create mode 100644 Tests/NIOTests/TestUtils.swift diff --git a/Sources/NIO/Channel.swift b/Sources/NIO/Channel.swift index b9f1fd00..c047a99c 100644 --- a/Sources/NIO/Channel.swift +++ b/Sources/NIO/Channel.swift @@ -462,8 +462,11 @@ private struct PendingWritesState { } /// Fail all the outstanding writes. This is useful if for example the `Channel` is closed. - func failAll(error: Error) { - self.closed = true + func failAll(error: Error, close: Bool) { + if close { + assert(!self.closed) + self.closed = true + } self.state.failAll(error: error)() @@ -527,6 +530,9 @@ final class SocketChannel: BaseSocketChannel { private var connectTimeout = TimeAmount.seconds(10) private var connectTimeoutScheduled: Scheduled? + private var allowRemoteHalfClosure: Bool = false + private var inputShutdown: Bool = false + private var outputShutdown: Bool = false init(eventLoop: SelectableEventLoop, protocolFamily: Int32) throws { let socket = try Socket(protocolFamily: protocolFamily) @@ -544,6 +550,8 @@ final class SocketChannel: BaseSocketChannel { switch option { case _ as ConnectTimeoutOption: connectTimeout = value as! TimeAmount + case _ as AllowRemoteHalfClosureOption: + allowRemoteHalfClosure = value as! Bool default: try super.setOption0(option: option, value: value) } @@ -554,6 +562,8 @@ final class SocketChannel: BaseSocketChannel { switch option { case _ as ConnectTimeoutOption: return connectTimeout as! T.OptionType + case _ as AllowRemoteHalfClosureOption: + return allowRemoteHalfClosure as! T.OptionType default: return try super.getOption0(option: option) } @@ -578,7 +588,7 @@ final class SocketChannel: BaseSocketChannel { var buffer = recvAllocator.buffer(allocator: allocator) var result = ReadResult.none for i in 1...maxMessagesPerRead { - if closed { + if closed || inputShutdown { return result } switch try buffer.withMutableWritePointer(body: self.socket.read(pointer:size:)) { @@ -600,6 +610,12 @@ final class SocketChannel: BaseSocketChannel { } result = .some } else { + if inputShutdown { + // We received a EOF because we called shutdown on the fd by ourself, unregister from the Selector and return + readPending = false + unregisterForReadable() + return result + } // end-of-file throw ChannelError.eof } @@ -659,7 +675,7 @@ final class SocketChannel: BaseSocketChannel { connectTimeoutScheduled = eventLoop.scheduleTask(in: timeout) { () -> (Void) in if self.pendingConnect != nil { // The connection was still not established, close the Channel which will also fail the pending promise. - self.close0(error: ChannelError.connectTimeout(timeout), promise: nil) + self.close0(error: ChannelError.connectTimeout(timeout), mode: .all, promise: nil) } } return false @@ -675,13 +691,76 @@ final class SocketChannel: BaseSocketChannel { becomeActive0() } - override func close0(error: Error, promise: EventLoopPromise?) { - if let timeout = connectTimeoutScheduled { - connectTimeoutScheduled = nil - timeout.cancel() + override func close0(error: Error, mode: CloseMode, promise: EventLoopPromise?) { + do { + switch mode { + case .output: + if outputShutdown { + promise?.fail(error: ChannelError.outputClosed) + return + } + try socket.shutdown(how: .WR) + outputShutdown = true + // Fail all pending writes and so ensure all pending promises are notified + pendingWrites.failAll(error: error, close: false) + unregisterForWritable() + promise?.succeed(result: ()) + + pipeline.fireUserInboundEventTriggered(event: ChannelEvent.outputClosed) + + case .input: + if inputShutdown { + promise?.fail(error: ChannelError.inputClosed) + return + } + switch error { + case ChannelError.eof: + // No need to explicit call socket.shutdown(...) as we received an EOF and the call would only cause + // ENOTCON + break + default: + try socket.shutdown(how: .RD) + } + inputShutdown = true + unregisterForReadable() + promise?.succeed(result: ()) + + pipeline.fireUserInboundEventTriggered(event: ChannelEvent.inputClosed) + case .all: + if let timeout = connectTimeoutScheduled { + connectTimeoutScheduled = nil + timeout.cancel() + } + super.close0(error: error, mode: mode, promise: promise) + } + } catch let err { + promise?.fail(error: err) } - super.close0(error: error, promise: promise) } + + @discardableResult override func readIfNeeded0() -> Bool { + if inputShutdown { + return false + } + return super.readIfNeeded0() + } + + override public func read0(promise: EventLoopPromise?) { + if inputShutdown { + promise?.fail(error: ChannelError.inputClosed) + return + } + super.read0(promise: promise) + } + + override public func write0(data: IOData, promise: EventLoopPromise?) { + if outputShutdown { + promise?.fail(error: ChannelError.outputClosed) + return + } + super.write0(data: data, promise: promise) + } + } /// A `Channel` for a server socket. @@ -773,7 +852,7 @@ final class ServerSocketChannel : BaseSocketChannel { } override fileprivate func writeToSocket(pendingWrites: PendingWritesManager) throws -> WriteResult { - pendingWrites.failAll(error: ChannelError.operationUnsupported) + pendingWrites.failAll(error: ChannelError.operationUnsupported, close: false) return .writtenCompletely } @@ -805,7 +884,7 @@ public protocol ChannelCore : class { func write0(data: IOData, promise: EventLoopPromise?) func flush0(promise: EventLoopPromise?) func read0(promise: EventLoopPromise?) - func close0(error: Error, promise: EventLoopPromise?) + func close0(error: Error, mode: CloseMode, promise: EventLoopPromise?) func triggerUserOutboundEvent0(event: Any, promise: EventLoopPromise?) func channelRead0(data: NIOAny) func errorCaught0(error: Error) @@ -920,8 +999,8 @@ extension Channel { pipeline.read(promise: promise) } - public func close(promise: EventLoopPromise?) { - pipeline.close(promise: promise) + public func close(mode: CloseMode = .all, promise: EventLoopPromise?) { + pipeline.close(mode: mode, promise: promise) } public func register(promise: EventLoopPromise?) { @@ -953,12 +1032,13 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { let socket: T public var interestedEvent: IOEvent = .none + /// `true` if the whole `Channel` is closed and so no more IO operation can be done. public final var closed: Bool { assert(eventLoop.inEventLoop) return pendingWrites.closed } - private let pendingWrites: PendingWritesManager + fileprivate let pendingWrites: PendingWritesManager fileprivate var readPending = false private var neverRegistered = true fileprivate var pendingConnect: EventLoopPromise? @@ -1089,7 +1169,7 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { /// Triggers a `ChannelPipeline.read()` if `autoRead` is enabled.` /// /// - returns: `true` if `readPending` is `true`, `false` otherwise. - @discardableResult final func readIfNeeded0() -> Bool { + @discardableResult func readIfNeeded0() -> Bool { assert(eventLoop.inEventLoop) if !readPending && autoRead { @@ -1136,9 +1216,8 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { } } - private func unregisterForWritable() { + fileprivate func unregisterForWritable() { assert(eventLoop.inEventLoop) - switch interestedEvent { case .all: safeReregister(interested: .read) @@ -1165,7 +1244,7 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { } } - public final func read0(promise: EventLoopPromise?) { + public func read0(promise: EventLoopPromise?) { assert(eventLoop.inEventLoop) if closed { @@ -1201,7 +1280,7 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { } } - private func unregisterForReadable() { + fileprivate func unregisterForReadable() { assert(eventLoop.inEventLoop) switch interestedEvent { @@ -1214,7 +1293,7 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { } } - public func close0(error: Error, promise: EventLoopPromise?) { + public func close0(error: Error, mode: CloseMode, promise: EventLoopPromise?) { assert(eventLoop.inEventLoop) if closed { @@ -1222,6 +1301,11 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { return } + guard mode == .all else { + promise?.fail(error: ChannelError.operationUnsupported) + return + } + interestedEvent = .none do { try selectableEventLoop.deregister(channel: self) @@ -1237,7 +1321,7 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { } // Fail all pending writes and so ensure all pending promises are notified - self.pendingWrites.failAll(error: error) + self.pendingWrites.failAll(error: error, close: true) becomeInactive0() @@ -1323,25 +1407,30 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { do { try readFromSocket() } catch let err { - if let channelErr = err as? ChannelError { - // EOF is not really an error that should be forwarded to the user - if channelErr != ChannelError.eof { - pipeline.fireErrorCaught0(error: err) - } + // ChannelError.eof is not something we want to fire through the pipeline as it just means the remote + // per closed / shutdown the connection. + if let channelErr = err as? ChannelError, channelErr != ChannelError.eof { + pipeline.fireErrorCaught0(error: err) + } else if try! getOption(option: ChannelOptions.allowRemoteHalfClosure) { + // If we want to allow half closure we will just mark the input side of the Channel + // as closed. + pipeline.fireChannelReadComplete0() + close0(error: err, mode: .input, promise: nil) + readPending = false + return } else { pipeline.fireErrorCaught0(error: err) } - + // Call before triggering the close of the Channel. pipeline.fireChannelReadComplete0() - close0(error: err, promise: nil) - + close0(error: err, mode: .all, promise: nil) return } pipeline.fireChannelReadComplete0() readIfNeeded0() } - + fileprivate func connectSocket(to address: SocketAddress) throws -> Bool { fatalError("this must be overridden by sub class") } @@ -1424,7 +1513,7 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { try selectableEventLoop.reregister(channel: self) } catch let err { pipeline.fireErrorCaught0(error: err) - close0(error: err, promise: nil) + close0(error: err, mode: .all, promise: nil) } } @@ -1441,7 +1530,7 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { return true } catch let err { pipeline.fireErrorCaught0(error: err) - close0(error: err, promise: nil) + close0(error: err, mode: .all, promise: nil) return false } } @@ -1485,7 +1574,7 @@ class BaseSocketChannel : SelectableChannel, ChannelCore { } } - close0(error: err, promise: nil) + close0(error: err, mode: .all, promise: nil) // we handled all writes return true @@ -1537,6 +1626,12 @@ public enum ChannelError: Error { /// Close was called on a channel that is already closed. case alreadyClosed + + /// Output-side of the channel is closed. + case outputClosed + + /// Input-side of the channel is closed. + case inputClosed /// A read operation reached end-of-file. This usually means the remote peer closed the socket but it's still /// open locally. @@ -1556,6 +1651,10 @@ extension ChannelError: Equatable { return true case (.alreadyClosed, .alreadyClosed): return true + case (.outputClosed, .outputClosed): + return true + case (.inputClosed, .inputClosed): + return true case (.eof, .eof): return true default: @@ -1564,3 +1663,11 @@ extension ChannelError: Equatable { } } +/// An `Channel` related event that is passed through the `ChannelPipeline` to notify the user. +public enum ChannelEvent: Equatable { + /// `ChannelOptions.allowRemoteHalfClosure` is `true` and input portion of the `Channel` was closed. + case inputClosed + /// Output portion of the `Channel` was closed. + case outputClosed +} + diff --git a/Sources/NIO/ChannelHandler.swift b/Sources/NIO/ChannelHandler.swift index 45a2bf01..ed40e392 100644 --- a/Sources/NIO/ChannelHandler.swift +++ b/Sources/NIO/ChannelHandler.swift @@ -25,7 +25,7 @@ public protocol _ChannelOutboundHandler : ChannelHandler { func flush(ctx: ChannelHandlerContext, promise: EventLoopPromise?) // TODO: Think about make this more flexible in terms of influence the allocation that is used to read the next amount of data func read(ctx: ChannelHandlerContext, promise: EventLoopPromise?) - func close(ctx: ChannelHandlerContext, promise: EventLoopPromise?) + func close(ctx: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise?) func triggerUserOutboundEvent(ctx: ChannelHandlerContext, event: Any, promise: EventLoopPromise?) } @@ -79,8 +79,8 @@ extension _ChannelOutboundHandler { ctx.read(promise: promise) } - public func close(ctx: ChannelHandlerContext, promise: EventLoopPromise?) { - ctx.close(promise: promise) + public func close(ctx: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise?) { + ctx.close(mode: mode, promise: promise) } public func triggerUserOutboundEvent(ctx: ChannelHandlerContext, event: Any, promise: EventLoopPromise?) { diff --git a/Sources/NIO/ChannelInvoker.swift b/Sources/NIO/ChannelInvoker.swift index 8df1053f..f13fc060 100644 --- a/Sources/NIO/ChannelInvoker.swift +++ b/Sources/NIO/ChannelInvoker.swift @@ -121,15 +121,18 @@ public protocol ChannelOutboundInvoker { /// Close the `Channel` and so the connection if one exists. /// + /// - parameters: + /// - mode: the `CloseMode` that is used /// - returns: the future which will be notified once the operation completes. - func close() -> EventLoopFuture + func close(mode: CloseMode) -> EventLoopFuture /// Close the `Channel` and so the connection if one exists. /// /// - parameters: + /// - mode: the `CloseMode` that is used /// - promise: the `EventLoopPromise` that will be notified once the operation completes, /// or `nil` if not interested in the outcome of the operation. - func close(promise: EventLoopPromise?) + func close(mode: CloseMode, promise: EventLoopPromise?) /// Trigger a custom user outbound event which will flow through the `ChannelPipeline`. /// @@ -192,9 +195,9 @@ extension ChannelOutboundInvoker { return promise.futureResult } - public func close() -> EventLoopFuture { + public func close(mode: CloseMode = .all) -> EventLoopFuture { let promise = newPromise() - close(promise: promise) + close(mode: mode, promise: promise) return promise.futureResult } @@ -267,3 +270,17 @@ public protocol ChannelInboundInvoker { /// A protocol that signals that outbound and inbound events are triggered by this invoker. public protocol ChannelInvoker : ChannelOutboundInvoker, ChannelInboundInvoker { } + +/// Specify what kind of close operation is requested. +public enum CloseMode { + /// Close the output (writing) side of the `Channel` without closing the actual file descriptor. + /// This is an optional mode which means it may not be supported by all `Channel` implementations. + case output + + /// Close the input (reading) side of the `Channel` without closing the actual file descriptor. + /// This is an optional mode which means it may not be supported by all `Channel` implementations. + case input + + /// Close the whole `Channel (file descriptor). + case all +} diff --git a/Sources/NIO/ChannelOption.swift b/Sources/NIO/ChannelOption.swift index d5b63d50..60dd1ca5 100644 --- a/Sources/NIO/ChannelOption.swift +++ b/Sources/NIO/ChannelOption.swift @@ -147,6 +147,17 @@ public enum ConnectTimeoutOption: ChannelOption { case const(()) } +/// `AllowRemoteHalfClosureOption` allows users to configure whether the `Channel` will close itself when its remote +/// peer shuts down its send stream, or whether it will remain open. If set to `false` (the default), the `Channel` +/// will be closed automatically if the remote peer shuts down its send stream. If set to true, the `Channel` will +/// not be closed: instead, a `ChannelEvent.inboundClosed` user event will be sent on the `ChannelPipeline`, +/// and no more data will be received. +public enum AllowRemoteHalfClosureOption: ChannelOption { + public typealias AssociatedValueType = () + public typealias OptionType = Bool + + case const(()) +} /// Provides `ChannelOption`s to be used with a `Channel`, `Bootstrap` or `ServerBootstrap`. public struct ChannelOptions { @@ -176,4 +187,7 @@ public struct ChannelOptions { /// - seealso: `ConnectTimeoutOption`. public static let connectTimeout = ConnectTimeoutOption.const(()) + + /// - seealso: `AllowRemoteHalfClosureOption`. + public static let allowRemoteHalfClosure = AllowRemoteHalfClosureOption.const(()) } diff --git a/Sources/NIO/ChannelPipeline.swift b/Sources/NIO/ChannelPipeline.swift index 827518f9..cfb36036 100644 --- a/Sources/NIO/ChannelPipeline.swift +++ b/Sources/NIO/ChannelPipeline.swift @@ -478,12 +478,12 @@ public final class ChannelPipeline : ChannelInvoker { } } - public func close(promise: EventLoopPromise?) { + public func close(mode: CloseMode = .all, promise: EventLoopPromise?) { if eventLoop.inEventLoop { - close0(promise: promise) + close0(mode: mode, promise: promise) } else { eventLoop.execute { - self.close0(promise: promise) + self.close0(mode: mode, promise: promise) } } } @@ -578,9 +578,9 @@ public final class ChannelPipeline : ChannelInvoker { return self.head?.inboundNext } - func close0(promise: EventLoopPromise?) { + func close0(mode: CloseMode, promise: EventLoopPromise?) { if let firstOutboundCtx = firstOutboundCtx { - firstOutboundCtx.invokeClose(promise: promise) + firstOutboundCtx.invokeClose(mode: mode, promise: promise) } else { promise?.fail(error: ChannelError.alreadyClosed) } @@ -768,9 +768,9 @@ private final class HeadChannelHandler : _ChannelOutboundHandler { } } - func close(ctx: ChannelHandlerContext, promise: EventLoopPromise?) { + func close(ctx: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise?) { if let channel = ctx.channel { - channel._unsafe.close0(error: ChannelError.alreadyClosed, promise: promise) + channel._unsafe.close0(error: mode.error, mode: mode, promise: promise) } else { promise?.fail(error: ChannelError.alreadyClosed) } @@ -791,6 +791,20 @@ private final class HeadChannelHandler : _ChannelOutboundHandler { promise?.fail(error: ChannelError.ioOnClosedChannel) } } + +} + +private extension CloseMode { + var error: ChannelError { + switch self { + case .all: + return ChannelError.alreadyClosed + case .output: + return ChannelError.outputClosed + case .input: + return ChannelError.inputClosed + } + } } /// Special `ChannelInboundHandler` which will consume all inbound events. @@ -1070,10 +1084,11 @@ public final class ChannelHandlerContext : ChannelInvoker { /// When the `close` event reaches the `HeadChannelHandler` the socket will be closed. /// /// - parameters: + /// - mode: The `CloseMode` to use. /// - promise: The promise fulfilled when the `Channel` has been closed or failed if it the closing failed. - public func close(promise: EventLoopPromise?) { + public func close(mode: CloseMode = .all, promise: EventLoopPromise?) { if let outboundNext = outboundNext { - outboundNext.invokeClose(promise: promise) + outboundNext.invokeClose(mode: mode, promise: promise) } else { promise?.fail(error: ChannelError.alreadyClosed) } @@ -1256,11 +1271,11 @@ public final class ChannelHandlerContext : ChannelInvoker { self.outboundHandler.read(ctx: self, promise: promise) } - fileprivate func invokeClose(promise: EventLoopPromise?) { + fileprivate func invokeClose(mode: CloseMode, promise: EventLoopPromise?) { assert(inEventLoop) assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled") - self.outboundHandler.close(ctx: self, promise: promise) + self.outboundHandler.close(ctx: self, mode: mode, promise: promise) } fileprivate func invokeTriggerUserOutboundEvent(event: Any, promise: EventLoopPromise?) { diff --git a/Sources/NIO/Embedded.swift b/Sources/NIO/Embedded.swift index 7b3ed50d..60e6d6b1 100644 --- a/Sources/NIO/Embedded.swift +++ b/Sources/NIO/Embedded.swift @@ -173,7 +173,7 @@ class EmbeddedChannelCore : ChannelCore { var outboundBuffer: [IOData] = [] var inboundBuffer: [NIOAny] = [] - func close0(error: Error, promise: EventLoopPromise?) { + func close0(error: Error, mode: CloseMode, promise: EventLoopPromise?) { if closed { promise?.fail(error: ChannelError.alreadyClosed) return diff --git a/Sources/NIO/Socket.swift b/Sources/NIO/Socket.swift index c1a360d0..9f013323 100644 --- a/Sources/NIO/Socket.swift +++ b/Sources/NIO/Socket.swift @@ -116,4 +116,11 @@ final class Socket : BaseSocket { return try LinuxSocket.sendmmsg(sockfd: self.descriptor, msgvec: msgs.baseAddress!, vlen: CUnsignedInt(msgs.count), flags: 0) } #endif + + func shutdown(how: Shutdown) throws { + guard self.open else { + throw IOError(errnoCode: EBADF, reason: "can't shutdown socket as it's not open anymore.") + } + try Posix.shutdown(descriptor: self.descriptor, how: how) + } } diff --git a/Sources/NIO/System.swift b/Sources/NIO/System.swift index c49bee8d..bff7846e 100644 --- a/Sources/NIO/System.swift +++ b/Sources/NIO/System.swift @@ -76,6 +76,23 @@ internal func wrapSyscall(_ fn: () throws -> T, where func } } +enum Shutdown { + case RD + case WR + case RDWR + + fileprivate var cValue: CInt { + switch self { + case .RD: + return CInt(SHUT_RD) + case .WR: + return CInt(SHUT_WR) + case .RDWR: + return CInt(SHUT_RDWR) + } + } +} + internal enum Posix { #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) static let SOCK_STREAM: CInt = CInt(Darwin.SOCK_STREAM) @@ -91,7 +108,18 @@ internal enum Posix { fatalError("unsupported OS") } #endif - + + @inline(never) + public static func shutdown(descriptor: Int32, how: Shutdown) throws { + _ = try wrapSyscall({ () -> Int in + #if os(Linux) + return Int(Glibc.shutdown(descriptor, how.cValue)) + #else + return Int(Darwin.shutdown(descriptor, how.cValue)) + #endif + }) + } + @inline(never) public static func close(descriptor: Int32) throws { _ = try wrapSyscall({ () -> Int in diff --git a/Sources/NIOOpenSSL/OpenSSLHandler.swift b/Sources/NIOOpenSSL/OpenSSLHandler.swift index 4d55d390..8547192c 100644 --- a/Sources/NIOOpenSSL/OpenSSLHandler.swift +++ b/Sources/NIOOpenSSL/OpenSSLHandler.swift @@ -140,7 +140,13 @@ public class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandler { doUnbufferWrites(ctx: ctx) } - public func close(ctx: ChannelHandlerContext, promise: EventLoopPromise?) { + public func close(ctx: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise?) { + guard mode == .all else { + // TODO: Support also other modes ? + promise?.fail(error: ChannelError.operationUnsupported) + return + } + switch state { case .closing: // We're in the process of TLS shutdown, so let's let that happen. However, diff --git a/Tests/NIOTests/ChannelTests+XCTest.swift b/Tests/NIOTests/ChannelTests+XCTest.swift index 4cd7cdf1..565b402b 100644 --- a/Tests/NIOTests/ChannelTests+XCTest.swift +++ b/Tests/NIOTests/ChannelTests+XCTest.swift @@ -47,6 +47,9 @@ extension ChannelTests { ("testPendingWritesMoreThanWritevIOVectorLimit", testPendingWritesMoreThanWritevIOVectorLimit), ("testPendingWritesIsHappyWhenSendfileReturnsWouldBlockButWroteFully", testPendingWritesIsHappyWhenSendfileReturnsWouldBlockButWroteFully), ("testConnectTimeout", testConnectTimeout), + ("testCloseOutput", testCloseOutput), + ("testCloseInput", testCloseInput), + ("testHalfClosure", testHalfClosure), ] } } diff --git a/Tests/NIOTests/ChannelTests.swift b/Tests/NIOTests/ChannelTests.swift index c4ab4228..d70d18e3 100644 --- a/Tests/NIOTests/ChannelTests.swift +++ b/Tests/NIOTests/ChannelTests.swift @@ -613,7 +613,7 @@ public class ChannelTests: XCTestCase { promiseStates: [[false, false, false], [false, false, false]]) XCTAssertEqual(WriteResult.wouldBlock, result) - pwm.failAll(error: ChannelError.operationUnsupported) + pwm.failAll(error: ChannelError.operationUnsupported, close: true) XCTAssertTrue(ps.map { $0.futureResult.fulfilled }.reduce(true) { $0 && $1 }) } @@ -939,7 +939,7 @@ public class ChannelTests: XCTestCase { _ = pwm.add(data: .byteBuffer(buffer), promise: ps[2]) ps[0].futureResult.whenComplete { _ in - pwm.failAll(error: ChannelError.eof) + pwm.failAll(error: ChannelError.inputClosed, close: true) } let result = try assertExpectedWritability(pendingWritesManager: pwm, @@ -1031,4 +1031,194 @@ public class ChannelTests: XCTestCase { } } } + + func testCloseOutput() throws { + let group = MultiThreadedEventLoopGroup(numThreads: 1) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let server = try ServerSocket(protocolFamily: PF_INET) + defer { + XCTAssertNoThrow(try server.close()) + } + try server.bind(to: SocketAddress.newAddressResolving(host: "127.0.0.1", port: 0)) + try server.listen() + + let byteCountingHandler = ByteCountingHandler(numBytes: 4, promise: group.next().newPromise()) + + let future = ClientBootstrap(group: group) + .channelInitializer { channel in + return channel.pipeline.add(handler: ShutdownVerificationHandler(inputShutdown: false, outputShutdown: true)).then { _ in + return channel.pipeline.add(handler: byteCountingHandler) + } + } + .connect(to: server.localAddress!) + let accepted = try server.accept()! + defer { + XCTAssertNoThrow(try accepted.close()) + } + + let channel = try future.wait() + defer { + XCTAssertNoThrow(try channel.close(mode: .all).wait()) + } + + var buffer = channel.allocator.buffer(capacity: 12) + buffer.write(string: "1234") + + try channel.writeAndFlush(data: NIOAny(buffer)).wait() + try channel.close(mode: .output).wait() + + do { + try channel.writeAndFlush(data: NIOAny(buffer)).wait() + XCTFail() + } catch let err as ChannelError { + XCTAssertEqual(ChannelError.outputClosed, err) + } + let written = try buffer.withUnsafeReadableBytes { p in + try accepted.write(pointer: p.baseAddress!.assumingMemoryBound(to: UInt8.self), size: 4) + } + switch written { + case .processed(let numBytes): + XCTAssertEqual(4, numBytes) + default: + XCTFail() + } + try byteCountingHandler.assertReceived(buffer: buffer) + } + + func testCloseInput() throws { + let group = MultiThreadedEventLoopGroup(numThreads: 1) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let server = try ServerSocket(protocolFamily: PF_INET) + defer { + XCTAssertNoThrow(try server.close()) + } + try server.bind(to: SocketAddress.newAddressResolving(host: "127.0.0.1", port: 0)) + try server.listen() + + class VerifyNoReadHandler : ChannelInboundHandler { + typealias InboundIn = ByteBuffer + + public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) { + XCTFail("Received data: \(data)") + } + } + + let future = ClientBootstrap(group: group) + .channelInitializer { channel in + return channel.pipeline.add(handler: VerifyNoReadHandler()).then { _ in + return channel.pipeline.add(handler: ShutdownVerificationHandler(inputShutdown: true, outputShutdown: false)) + } + } + .connect(to: server.localAddress!) + let accepted = try server.accept()! + defer { + XCTAssertNoThrow(try accepted.close()) + } + + let channel = try future.wait() + defer { + XCTAssertNoThrow(try channel.close(mode: .all).wait()) + } + + try channel.close(mode: .input).wait() + + var buffer = channel.allocator.buffer(capacity: 12) + buffer.write(string: "1234") + + let written = try buffer.withUnsafeReadableBytes { p in + try accepted.write(pointer: p.baseAddress!.assumingMemoryBound(to: UInt8.self), size: 4) + } + + switch written { + case .processed(let numBytes): + XCTAssertEqual(4, numBytes) + default: + XCTFail() + } + + try channel.eventLoop.submit { + // Dummy task execution to give some time for an actual read to open (which should not happen as we closed the input). + }.wait() + } + + func testHalfClosure() throws { + let group = MultiThreadedEventLoopGroup(numThreads: 1) + defer { + try! group.syncShutdownGracefully() + } + + let server = try ServerSocket(protocolFamily: PF_INET) + defer { + XCTAssertNoThrow(try server.close()) + } + try server.bind(to: SocketAddress.newAddressResolving(host: "127.0.0.1", port: 0)) + try server.listen() + + let future = ClientBootstrap(group: group) + .channelInitializer { channel in + return channel.pipeline.add(handler: ShutdownVerificationHandler(inputShutdown: true, outputShutdown: false)) + } + .channelOption(ChannelOptions.allowRemoteHalfClosure, value: true) + .connect(to: server.localAddress!) + let accepted = try server.accept()! + defer { + XCTAssertNoThrow(try accepted.close()) + } + + let channel = try future.wait() + defer { + XCTAssertNoThrow(try channel.close(mode: .all).wait()) + } + + try accepted.shutdown(how: .WR) + + var buffer = channel.allocator.buffer(capacity: 12) + buffer.write(string: "1234") + + try channel.writeAndFlush(data: NIOAny(buffer)).wait() + } + + private class ShutdownVerificationHandler: ChannelInboundHandler { + typealias InboundIn = ByteBuffer + + private var inputShutdownEventReceived = false + private var outputShutdownEventReceived = false + + private let inputShutdown: Bool + private let outputShutdown: Bool + + init(inputShutdown: Bool, outputShutdown: Bool) { + self.inputShutdown = inputShutdown + self.outputShutdown = outputShutdown + } + + public func userInboundEventTriggered(ctx: ChannelHandlerContext, event: Any) { + switch event { + case let ev as ChannelEvent: + switch ev { + case .inputClosed: + XCTAssertFalse(inputShutdownEventReceived) + inputShutdownEventReceived = true + case .outputClosed: + XCTAssertFalse(outputShutdownEventReceived) + outputShutdownEventReceived = true + } + + fallthrough + default: + ctx.fireUserInboundEventTriggered(event: event) + } + } + + public func channelInactive(ctx: ChannelHandlerContext) { + XCTAssertEqual(inputShutdown, inputShutdownEventReceived) + XCTAssertEqual(outputShutdown, outputShutdownEventReceived) + } + } } diff --git a/Tests/NIOTests/EchoServerClientTest.swift b/Tests/NIOTests/EchoServerClientTest.swift index f248a5a5..0047d7a6 100644 --- a/Tests/NIOTests/EchoServerClientTest.swift +++ b/Tests/NIOTests/EchoServerClientTest.swift @@ -251,38 +251,6 @@ class EchoServerClientTest : XCTestCase { try countingHandler.assertReceived(buffer: buffer) } - - private final class ByteCountingHandler : ChannelInboundHandler { - typealias InboundIn = ByteBuffer - - private let numBytes: Int - private let promise: EventLoopPromise - private var buffer: ByteBuffer! - - init(numBytes: Int, promise: EventLoopPromise) { - self.numBytes = numBytes - self.promise = promise - } - - func handlerAdded(ctx: ChannelHandlerContext) { - buffer = ctx.channel!.allocator.buffer(capacity: numBytes) - } - - func channelRead(ctx: ChannelHandlerContext, data: NIOAny) { - var currentBuffer = self.unwrapInboundIn(data) - buffer.write(buffer: ¤tBuffer) - - if buffer.readableBytes == numBytes { - // Do something - promise.succeed(result: buffer) - } - } - - func assertReceived(buffer: ByteBuffer) throws { - let received = try promise.futureResult.wait() - XCTAssertEqual(buffer, received) - } - } private final class ChannelActiveHandler: ChannelInboundHandler { typealias InboundIn = ByteBuffer @@ -447,7 +415,6 @@ class EchoServerClientTest : XCTestCase { } func testCloseInInactive() throws { - let group = MultiThreadedEventLoopGroup(numThreads: 1) defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) diff --git a/Tests/NIOTests/FileRegionTest.swift b/Tests/NIOTests/FileRegionTest.swift index a22ff55b..4cb7e988 100644 --- a/Tests/NIOTests/FileRegionTest.swift +++ b/Tests/NIOTests/FileRegionTest.swift @@ -111,40 +111,6 @@ class FileRegionTest : XCTestCase { try futures.forEach { try $0.wait() } } - private final class ByteCountingHandler : ChannelInboundHandler { - typealias InboundIn = ByteBuffer - - private let numBytes: Int - private let promise: EventLoopPromise - private var buffer: ByteBuffer! - - init(numBytes: Int, promise: EventLoopPromise) { - self.numBytes = numBytes - self.promise = promise - } - - func handlerAdded(ctx: ChannelHandlerContext) { - buffer = ctx.channel!.allocator.buffer(capacity: numBytes) - if self.numBytes == 0 { - self.promise.succeed(result: buffer) - } - } - - func channelRead(ctx: ChannelHandlerContext, data: NIOAny) { - var currentBuffer = self.unwrapInboundIn(data) - buffer.write(buffer: ¤tBuffer) - - if buffer.readableBytes == numBytes { - promise.succeed(result: buffer) - } - } - - func assertReceived(buffer: ByteBuffer) throws { - let received = try promise.futureResult.wait() - XCTAssertEqual(buffer, received) - } - } - func testOutstandingFileRegionsWork() throws { let group = MultiThreadedEventLoopGroup(numThreads: 1) defer { diff --git a/Tests/NIOTests/TestUtils.swift b/Tests/NIOTests/TestUtils.swift new file mode 100644 index 00000000..becc6035 --- /dev/null +++ b/Tests/NIOTests/TestUtils.swift @@ -0,0 +1,50 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIO +import XCTest + +final class ByteCountingHandler : ChannelInboundHandler { + typealias InboundIn = ByteBuffer + + private let numBytes: Int + private let promise: EventLoopPromise + private var buffer: ByteBuffer! + + init(numBytes: Int, promise: EventLoopPromise) { + self.numBytes = numBytes + self.promise = promise + } + + func handlerAdded(ctx: ChannelHandlerContext) { + buffer = ctx.channel!.allocator.buffer(capacity: numBytes) + if self.numBytes == 0 { + self.promise.succeed(result: buffer) + } + } + + func channelRead(ctx: ChannelHandlerContext, data: NIOAny) { + var currentBuffer = self.unwrapInboundIn(data) + buffer.write(buffer: ¤tBuffer) + + if buffer.readableBytes == numBytes { + promise.succeed(result: buffer) + } + } + + func assertReceived(buffer: ByteBuffer) throws { + let received = try promise.futureResult.wait() + XCTAssertEqual(buffer, received) + } +}