Add support for closing only the output / input side of the Channel.

This commit is contained in:
Norman Maurer 2017-12-28 13:40:31 +01:00
parent cf790be04f
commit 4d8cf1ff22
14 changed files with 493 additions and 123 deletions

View File

@ -462,8 +462,11 @@ private struct PendingWritesState {
}
/// Fail all the outstanding writes. This is useful if for example the `Channel` is closed.
func failAll(error: Error) {
self.closed = true
func failAll(error: Error, close: Bool) {
if close {
assert(!self.closed)
self.closed = true
}
self.state.failAll(error: error)()
@ -527,6 +530,9 @@ final class SocketChannel: BaseSocketChannel<Socket> {
private var connectTimeout = TimeAmount.seconds(10)
private var connectTimeoutScheduled: Scheduled<Void>?
private var allowRemoteHalfClosure: Bool = false
private var inputShutdown: Bool = false
private var outputShutdown: Bool = false
init(eventLoop: SelectableEventLoop, protocolFamily: Int32) throws {
let socket = try Socket(protocolFamily: protocolFamily)
@ -544,6 +550,8 @@ final class SocketChannel: BaseSocketChannel<Socket> {
switch option {
case _ as ConnectTimeoutOption:
connectTimeout = value as! TimeAmount
case _ as AllowRemoteHalfClosureOption:
allowRemoteHalfClosure = value as! Bool
default:
try super.setOption0(option: option, value: value)
}
@ -554,6 +562,8 @@ final class SocketChannel: BaseSocketChannel<Socket> {
switch option {
case _ as ConnectTimeoutOption:
return connectTimeout as! T.OptionType
case _ as AllowRemoteHalfClosureOption:
return allowRemoteHalfClosure as! T.OptionType
default:
return try super.getOption0(option: option)
}
@ -578,7 +588,7 @@ final class SocketChannel: BaseSocketChannel<Socket> {
var buffer = recvAllocator.buffer(allocator: allocator)
var result = ReadResult.none
for i in 1...maxMessagesPerRead {
if closed {
if closed || inputShutdown {
return result
}
switch try buffer.withMutableWritePointer(body: self.socket.read(pointer:size:)) {
@ -600,6 +610,12 @@ final class SocketChannel: BaseSocketChannel<Socket> {
}
result = .some
} else {
if inputShutdown {
// We received a EOF because we called shutdown on the fd by ourself, unregister from the Selector and return
readPending = false
unregisterForReadable()
return result
}
// end-of-file
throw ChannelError.eof
}
@ -659,7 +675,7 @@ final class SocketChannel: BaseSocketChannel<Socket> {
connectTimeoutScheduled = eventLoop.scheduleTask(in: timeout) { () -> (Void) in
if self.pendingConnect != nil {
// The connection was still not established, close the Channel which will also fail the pending promise.
self.close0(error: ChannelError.connectTimeout(timeout), promise: nil)
self.close0(error: ChannelError.connectTimeout(timeout), mode: .all, promise: nil)
}
}
return false
@ -675,13 +691,76 @@ final class SocketChannel: BaseSocketChannel<Socket> {
becomeActive0()
}
override func close0(error: Error, promise: EventLoopPromise<Void>?) {
if let timeout = connectTimeoutScheduled {
connectTimeoutScheduled = nil
timeout.cancel()
override func close0(error: Error, mode: CloseMode, promise: EventLoopPromise<Void>?) {
do {
switch mode {
case .output:
if outputShutdown {
promise?.fail(error: ChannelError.outputClosed)
return
}
try socket.shutdown(how: .WR)
outputShutdown = true
// Fail all pending writes and so ensure all pending promises are notified
pendingWrites.failAll(error: error, close: false)
unregisterForWritable()
promise?.succeed(result: ())
pipeline.fireUserInboundEventTriggered(event: ChannelEvent.outputClosed)
case .input:
if inputShutdown {
promise?.fail(error: ChannelError.inputClosed)
return
}
switch error {
case ChannelError.eof:
// No need to explicit call socket.shutdown(...) as we received an EOF and the call would only cause
// ENOTCON
break
default:
try socket.shutdown(how: .RD)
}
inputShutdown = true
unregisterForReadable()
promise?.succeed(result: ())
pipeline.fireUserInboundEventTriggered(event: ChannelEvent.inputClosed)
case .all:
if let timeout = connectTimeoutScheduled {
connectTimeoutScheduled = nil
timeout.cancel()
}
super.close0(error: error, mode: mode, promise: promise)
}
} catch let err {
promise?.fail(error: err)
}
super.close0(error: error, promise: promise)
}
@discardableResult override func readIfNeeded0() -> Bool {
if inputShutdown {
return false
}
return super.readIfNeeded0()
}
override public func read0(promise: EventLoopPromise<Void>?) {
if inputShutdown {
promise?.fail(error: ChannelError.inputClosed)
return
}
super.read0(promise: promise)
}
override public func write0(data: IOData, promise: EventLoopPromise<Void>?) {
if outputShutdown {
promise?.fail(error: ChannelError.outputClosed)
return
}
super.write0(data: data, promise: promise)
}
}
/// A `Channel` for a server socket.
@ -773,7 +852,7 @@ final class ServerSocketChannel : BaseSocketChannel<ServerSocket> {
}
override fileprivate func writeToSocket(pendingWrites: PendingWritesManager) throws -> WriteResult {
pendingWrites.failAll(error: ChannelError.operationUnsupported)
pendingWrites.failAll(error: ChannelError.operationUnsupported, close: false)
return .writtenCompletely
}
@ -805,7 +884,7 @@ public protocol ChannelCore : class {
func write0(data: IOData, promise: EventLoopPromise<Void>?)
func flush0(promise: EventLoopPromise<Void>?)
func read0(promise: EventLoopPromise<Void>?)
func close0(error: Error, promise: EventLoopPromise<Void>?)
func close0(error: Error, mode: CloseMode, promise: EventLoopPromise<Void>?)
func triggerUserOutboundEvent0(event: Any, promise: EventLoopPromise<Void>?)
func channelRead0(data: NIOAny)
func errorCaught0(error: Error)
@ -920,8 +999,8 @@ extension Channel {
pipeline.read(promise: promise)
}
public func close(promise: EventLoopPromise<Void>?) {
pipeline.close(promise: promise)
public func close(mode: CloseMode = .all, promise: EventLoopPromise<Void>?) {
pipeline.close(mode: mode, promise: promise)
}
public func register(promise: EventLoopPromise<Void>?) {
@ -953,12 +1032,13 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
let socket: T
public var interestedEvent: IOEvent = .none
/// `true` if the whole `Channel` is closed and so no more IO operation can be done.
public final var closed: Bool {
assert(eventLoop.inEventLoop)
return pendingWrites.closed
}
private let pendingWrites: PendingWritesManager
fileprivate let pendingWrites: PendingWritesManager
fileprivate var readPending = false
private var neverRegistered = true
fileprivate var pendingConnect: EventLoopPromise<Void>?
@ -1089,7 +1169,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
/// Triggers a `ChannelPipeline.read()` if `autoRead` is enabled.`
///
/// - returns: `true` if `readPending` is `true`, `false` otherwise.
@discardableResult final func readIfNeeded0() -> Bool {
@discardableResult func readIfNeeded0() -> Bool {
assert(eventLoop.inEventLoop)
if !readPending && autoRead {
@ -1136,9 +1216,8 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
}
}
private func unregisterForWritable() {
fileprivate func unregisterForWritable() {
assert(eventLoop.inEventLoop)
switch interestedEvent {
case .all:
safeReregister(interested: .read)
@ -1165,7 +1244,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
}
}
public final func read0(promise: EventLoopPromise<Void>?) {
public func read0(promise: EventLoopPromise<Void>?) {
assert(eventLoop.inEventLoop)
if closed {
@ -1201,7 +1280,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
}
}
private func unregisterForReadable() {
fileprivate func unregisterForReadable() {
assert(eventLoop.inEventLoop)
switch interestedEvent {
@ -1214,7 +1293,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
}
}
public func close0(error: Error, promise: EventLoopPromise<Void>?) {
public func close0(error: Error, mode: CloseMode, promise: EventLoopPromise<Void>?) {
assert(eventLoop.inEventLoop)
if closed {
@ -1222,6 +1301,11 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
return
}
guard mode == .all else {
promise?.fail(error: ChannelError.operationUnsupported)
return
}
interestedEvent = .none
do {
try selectableEventLoop.deregister(channel: self)
@ -1237,7 +1321,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
}
// Fail all pending writes and so ensure all pending promises are notified
self.pendingWrites.failAll(error: error)
self.pendingWrites.failAll(error: error, close: true)
becomeInactive0()
@ -1323,25 +1407,30 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
do {
try readFromSocket()
} catch let err {
if let channelErr = err as? ChannelError {
// EOF is not really an error that should be forwarded to the user
if channelErr != ChannelError.eof {
pipeline.fireErrorCaught0(error: err)
}
// ChannelError.eof is not something we want to fire through the pipeline as it just means the remote
// per closed / shutdown the connection.
if let channelErr = err as? ChannelError, channelErr != ChannelError.eof {
pipeline.fireErrorCaught0(error: err)
} else if try! getOption(option: ChannelOptions.allowRemoteHalfClosure) {
// If we want to allow half closure we will just mark the input side of the Channel
// as closed.
pipeline.fireChannelReadComplete0()
close0(error: err, mode: .input, promise: nil)
readPending = false
return
} else {
pipeline.fireErrorCaught0(error: err)
}
// Call before triggering the close of the Channel.
pipeline.fireChannelReadComplete0()
close0(error: err, promise: nil)
close0(error: err, mode: .all, promise: nil)
return
}
pipeline.fireChannelReadComplete0()
readIfNeeded0()
}
fileprivate func connectSocket(to address: SocketAddress) throws -> Bool {
fatalError("this must be overridden by sub class")
}
@ -1424,7 +1513,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
try selectableEventLoop.reregister(channel: self)
} catch let err {
pipeline.fireErrorCaught0(error: err)
close0(error: err, promise: nil)
close0(error: err, mode: .all, promise: nil)
}
}
@ -1441,7 +1530,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
return true
} catch let err {
pipeline.fireErrorCaught0(error: err)
close0(error: err, promise: nil)
close0(error: err, mode: .all, promise: nil)
return false
}
}
@ -1485,7 +1574,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
}
}
close0(error: err, promise: nil)
close0(error: err, mode: .all, promise: nil)
// we handled all writes
return true
@ -1537,6 +1626,12 @@ public enum ChannelError: Error {
/// Close was called on a channel that is already closed.
case alreadyClosed
/// Output-side of the channel is closed.
case outputClosed
/// Input-side of the channel is closed.
case inputClosed
/// A read operation reached end-of-file. This usually means the remote peer closed the socket but it's still
/// open locally.
@ -1556,6 +1651,10 @@ extension ChannelError: Equatable {
return true
case (.alreadyClosed, .alreadyClosed):
return true
case (.outputClosed, .outputClosed):
return true
case (.inputClosed, .inputClosed):
return true
case (.eof, .eof):
return true
default:
@ -1564,3 +1663,11 @@ extension ChannelError: Equatable {
}
}
/// An `Channel` related event that is passed through the `ChannelPipeline` to notify the user.
public enum ChannelEvent: Equatable {
/// `ChannelOptions.allowRemoteHalfClosure` is `true` and input portion of the `Channel` was closed.
case inputClosed
/// Output portion of the `Channel` was closed.
case outputClosed
}

View File

@ -25,7 +25,7 @@ public protocol _ChannelOutboundHandler : ChannelHandler {
func flush(ctx: ChannelHandlerContext, promise: EventLoopPromise<Void>?)
// TODO: Think about make this more flexible in terms of influence the allocation that is used to read the next amount of data
func read(ctx: ChannelHandlerContext, promise: EventLoopPromise<Void>?)
func close(ctx: ChannelHandlerContext, promise: EventLoopPromise<Void>?)
func close(ctx: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise<Void>?)
func triggerUserOutboundEvent(ctx: ChannelHandlerContext, event: Any, promise: EventLoopPromise<Void>?)
}
@ -79,8 +79,8 @@ extension _ChannelOutboundHandler {
ctx.read(promise: promise)
}
public func close(ctx: ChannelHandlerContext, promise: EventLoopPromise<Void>?) {
ctx.close(promise: promise)
public func close(ctx: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise<Void>?) {
ctx.close(mode: mode, promise: promise)
}
public func triggerUserOutboundEvent(ctx: ChannelHandlerContext, event: Any, promise: EventLoopPromise<Void>?) {

View File

@ -121,15 +121,18 @@ public protocol ChannelOutboundInvoker {
/// Close the `Channel` and so the connection if one exists.
///
/// - parameters:
/// - mode: the `CloseMode` that is used
/// - returns: the future which will be notified once the operation completes.
func close() -> EventLoopFuture<Void>
func close(mode: CloseMode) -> EventLoopFuture<Void>
/// Close the `Channel` and so the connection if one exists.
///
/// - parameters:
/// - mode: the `CloseMode` that is used
/// - promise: the `EventLoopPromise` that will be notified once the operation completes,
/// or `nil` if not interested in the outcome of the operation.
func close(promise: EventLoopPromise<Void>?)
func close(mode: CloseMode, promise: EventLoopPromise<Void>?)
/// Trigger a custom user outbound event which will flow through the `ChannelPipeline`.
///
@ -192,9 +195,9 @@ extension ChannelOutboundInvoker {
return promise.futureResult
}
public func close() -> EventLoopFuture<Void> {
public func close(mode: CloseMode = .all) -> EventLoopFuture<Void> {
let promise = newPromise()
close(promise: promise)
close(mode: mode, promise: promise)
return promise.futureResult
}
@ -267,3 +270,17 @@ public protocol ChannelInboundInvoker {
/// A protocol that signals that outbound and inbound events are triggered by this invoker.
public protocol ChannelInvoker : ChannelOutboundInvoker, ChannelInboundInvoker { }
/// Specify what kind of close operation is requested.
public enum CloseMode {
/// Close the output (writing) side of the `Channel` without closing the actual file descriptor.
/// This is an optional mode which means it may not be supported by all `Channel` implementations.
case output
/// Close the input (reading) side of the `Channel` without closing the actual file descriptor.
/// This is an optional mode which means it may not be supported by all `Channel` implementations.
case input
/// Close the whole `Channel (file descriptor).
case all
}

View File

@ -147,6 +147,17 @@ public enum ConnectTimeoutOption: ChannelOption {
case const(())
}
/// `AllowRemoteHalfClosureOption` allows users to configure whether the `Channel` will close itself when its remote
/// peer shuts down its send stream, or whether it will remain open. If set to `false` (the default), the `Channel`
/// will be closed automatically if the remote peer shuts down its send stream. If set to true, the `Channel` will
/// not be closed: instead, a `ChannelEvent.inboundClosed` user event will be sent on the `ChannelPipeline`,
/// and no more data will be received.
public enum AllowRemoteHalfClosureOption: ChannelOption {
public typealias AssociatedValueType = ()
public typealias OptionType = Bool
case const(())
}
/// Provides `ChannelOption`s to be used with a `Channel`, `Bootstrap` or `ServerBootstrap`.
public struct ChannelOptions {
@ -176,4 +187,7 @@ public struct ChannelOptions {
/// - seealso: `ConnectTimeoutOption`.
public static let connectTimeout = ConnectTimeoutOption.const(())
/// - seealso: `AllowRemoteHalfClosureOption`.
public static let allowRemoteHalfClosure = AllowRemoteHalfClosureOption.const(())
}

View File

@ -478,12 +478,12 @@ public final class ChannelPipeline : ChannelInvoker {
}
}
public func close(promise: EventLoopPromise<Void>?) {
public func close(mode: CloseMode = .all, promise: EventLoopPromise<Void>?) {
if eventLoop.inEventLoop {
close0(promise: promise)
close0(mode: mode, promise: promise)
} else {
eventLoop.execute {
self.close0(promise: promise)
self.close0(mode: mode, promise: promise)
}
}
}
@ -578,9 +578,9 @@ public final class ChannelPipeline : ChannelInvoker {
return self.head?.inboundNext
}
func close0(promise: EventLoopPromise<Void>?) {
func close0(mode: CloseMode, promise: EventLoopPromise<Void>?) {
if let firstOutboundCtx = firstOutboundCtx {
firstOutboundCtx.invokeClose(promise: promise)
firstOutboundCtx.invokeClose(mode: mode, promise: promise)
} else {
promise?.fail(error: ChannelError.alreadyClosed)
}
@ -768,9 +768,9 @@ private final class HeadChannelHandler : _ChannelOutboundHandler {
}
}
func close(ctx: ChannelHandlerContext, promise: EventLoopPromise<Void>?) {
func close(ctx: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise<Void>?) {
if let channel = ctx.channel {
channel._unsafe.close0(error: ChannelError.alreadyClosed, promise: promise)
channel._unsafe.close0(error: mode.error, mode: mode, promise: promise)
} else {
promise?.fail(error: ChannelError.alreadyClosed)
}
@ -791,6 +791,20 @@ private final class HeadChannelHandler : _ChannelOutboundHandler {
promise?.fail(error: ChannelError.ioOnClosedChannel)
}
}
}
private extension CloseMode {
var error: ChannelError {
switch self {
case .all:
return ChannelError.alreadyClosed
case .output:
return ChannelError.outputClosed
case .input:
return ChannelError.inputClosed
}
}
}
/// Special `ChannelInboundHandler` which will consume all inbound events.
@ -1070,10 +1084,11 @@ public final class ChannelHandlerContext : ChannelInvoker {
/// When the `close` event reaches the `HeadChannelHandler` the socket will be closed.
///
/// - parameters:
/// - mode: The `CloseMode` to use.
/// - promise: The promise fulfilled when the `Channel` has been closed or failed if it the closing failed.
public func close(promise: EventLoopPromise<Void>?) {
public func close(mode: CloseMode = .all, promise: EventLoopPromise<Void>?) {
if let outboundNext = outboundNext {
outboundNext.invokeClose(promise: promise)
outboundNext.invokeClose(mode: mode, promise: promise)
} else {
promise?.fail(error: ChannelError.alreadyClosed)
}
@ -1256,11 +1271,11 @@ public final class ChannelHandlerContext : ChannelInvoker {
self.outboundHandler.read(ctx: self, promise: promise)
}
fileprivate func invokeClose(promise: EventLoopPromise<Void>?) {
fileprivate func invokeClose(mode: CloseMode, promise: EventLoopPromise<Void>?) {
assert(inEventLoop)
assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled")
self.outboundHandler.close(ctx: self, promise: promise)
self.outboundHandler.close(ctx: self, mode: mode, promise: promise)
}
fileprivate func invokeTriggerUserOutboundEvent(event: Any, promise: EventLoopPromise<Void>?) {

View File

@ -173,7 +173,7 @@ class EmbeddedChannelCore : ChannelCore {
var outboundBuffer: [IOData] = []
var inboundBuffer: [NIOAny] = []
func close0(error: Error, promise: EventLoopPromise<Void>?) {
func close0(error: Error, mode: CloseMode, promise: EventLoopPromise<Void>?) {
if closed {
promise?.fail(error: ChannelError.alreadyClosed)
return

View File

@ -116,4 +116,11 @@ final class Socket : BaseSocket {
return try LinuxSocket.sendmmsg(sockfd: self.descriptor, msgvec: msgs.baseAddress!, vlen: CUnsignedInt(msgs.count), flags: 0)
}
#endif
func shutdown(how: Shutdown) throws {
guard self.open else {
throw IOError(errnoCode: EBADF, reason: "can't shutdown socket as it's not open anymore.")
}
try Posix.shutdown(descriptor: self.descriptor, how: how)
}
}

View File

@ -76,6 +76,23 @@ internal func wrapSyscall<T: FixedWidthInteger>(_ fn: () throws -> T, where func
}
}
enum Shutdown {
case RD
case WR
case RDWR
fileprivate var cValue: CInt {
switch self {
case .RD:
return CInt(SHUT_RD)
case .WR:
return CInt(SHUT_WR)
case .RDWR:
return CInt(SHUT_RDWR)
}
}
}
internal enum Posix {
#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS)
static let SOCK_STREAM: CInt = CInt(Darwin.SOCK_STREAM)
@ -91,7 +108,18 @@ internal enum Posix {
fatalError("unsupported OS")
}
#endif
@inline(never)
public static func shutdown(descriptor: Int32, how: Shutdown) throws {
_ = try wrapSyscall({ () -> Int in
#if os(Linux)
return Int(Glibc.shutdown(descriptor, how.cValue))
#else
return Int(Darwin.shutdown(descriptor, how.cValue))
#endif
})
}
@inline(never)
public static func close(descriptor: Int32) throws {
_ = try wrapSyscall({ () -> Int in

View File

@ -140,7 +140,13 @@ public class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandler {
doUnbufferWrites(ctx: ctx)
}
public func close(ctx: ChannelHandlerContext, promise: EventLoopPromise<Void>?) {
public func close(ctx: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise<Void>?) {
guard mode == .all else {
// TODO: Support also other modes ?
promise?.fail(error: ChannelError.operationUnsupported)
return
}
switch state {
case .closing:
// We're in the process of TLS shutdown, so let's let that happen. However,

View File

@ -47,6 +47,9 @@ extension ChannelTests {
("testPendingWritesMoreThanWritevIOVectorLimit", testPendingWritesMoreThanWritevIOVectorLimit),
("testPendingWritesIsHappyWhenSendfileReturnsWouldBlockButWroteFully", testPendingWritesIsHappyWhenSendfileReturnsWouldBlockButWroteFully),
("testConnectTimeout", testConnectTimeout),
("testCloseOutput", testCloseOutput),
("testCloseInput", testCloseInput),
("testHalfClosure", testHalfClosure),
]
}
}

View File

@ -613,7 +613,7 @@ public class ChannelTests: XCTestCase {
promiseStates: [[false, false, false], [false, false, false]])
XCTAssertEqual(WriteResult.wouldBlock, result)
pwm.failAll(error: ChannelError.operationUnsupported)
pwm.failAll(error: ChannelError.operationUnsupported, close: true)
XCTAssertTrue(ps.map { $0.futureResult.fulfilled }.reduce(true) { $0 && $1 })
}
@ -939,7 +939,7 @@ public class ChannelTests: XCTestCase {
_ = pwm.add(data: .byteBuffer(buffer), promise: ps[2])
ps[0].futureResult.whenComplete { _ in
pwm.failAll(error: ChannelError.eof)
pwm.failAll(error: ChannelError.inputClosed, close: true)
}
let result = try assertExpectedWritability(pendingWritesManager: pwm,
@ -1031,4 +1031,194 @@ public class ChannelTests: XCTestCase {
}
}
}
func testCloseOutput() throws {
let group = MultiThreadedEventLoopGroup(numThreads: 1)
defer {
XCTAssertNoThrow(try group.syncShutdownGracefully())
}
let server = try ServerSocket(protocolFamily: PF_INET)
defer {
XCTAssertNoThrow(try server.close())
}
try server.bind(to: SocketAddress.newAddressResolving(host: "127.0.0.1", port: 0))
try server.listen()
let byteCountingHandler = ByteCountingHandler(numBytes: 4, promise: group.next().newPromise())
let future = ClientBootstrap(group: group)
.channelInitializer { channel in
return channel.pipeline.add(handler: ShutdownVerificationHandler(inputShutdown: false, outputShutdown: true)).then { _ in
return channel.pipeline.add(handler: byteCountingHandler)
}
}
.connect(to: server.localAddress!)
let accepted = try server.accept()!
defer {
XCTAssertNoThrow(try accepted.close())
}
let channel = try future.wait()
defer {
XCTAssertNoThrow(try channel.close(mode: .all).wait())
}
var buffer = channel.allocator.buffer(capacity: 12)
buffer.write(string: "1234")
try channel.writeAndFlush(data: NIOAny(buffer)).wait()
try channel.close(mode: .output).wait()
do {
try channel.writeAndFlush(data: NIOAny(buffer)).wait()
XCTFail()
} catch let err as ChannelError {
XCTAssertEqual(ChannelError.outputClosed, err)
}
let written = try buffer.withUnsafeReadableBytes { p in
try accepted.write(pointer: p.baseAddress!.assumingMemoryBound(to: UInt8.self), size: 4)
}
switch written {
case .processed(let numBytes):
XCTAssertEqual(4, numBytes)
default:
XCTFail()
}
try byteCountingHandler.assertReceived(buffer: buffer)
}
func testCloseInput() throws {
let group = MultiThreadedEventLoopGroup(numThreads: 1)
defer {
XCTAssertNoThrow(try group.syncShutdownGracefully())
}
let server = try ServerSocket(protocolFamily: PF_INET)
defer {
XCTAssertNoThrow(try server.close())
}
try server.bind(to: SocketAddress.newAddressResolving(host: "127.0.0.1", port: 0))
try server.listen()
class VerifyNoReadHandler : ChannelInboundHandler {
typealias InboundIn = ByteBuffer
public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
XCTFail("Received data: \(data)")
}
}
let future = ClientBootstrap(group: group)
.channelInitializer { channel in
return channel.pipeline.add(handler: VerifyNoReadHandler()).then { _ in
return channel.pipeline.add(handler: ShutdownVerificationHandler(inputShutdown: true, outputShutdown: false))
}
}
.connect(to: server.localAddress!)
let accepted = try server.accept()!
defer {
XCTAssertNoThrow(try accepted.close())
}
let channel = try future.wait()
defer {
XCTAssertNoThrow(try channel.close(mode: .all).wait())
}
try channel.close(mode: .input).wait()
var buffer = channel.allocator.buffer(capacity: 12)
buffer.write(string: "1234")
let written = try buffer.withUnsafeReadableBytes { p in
try accepted.write(pointer: p.baseAddress!.assumingMemoryBound(to: UInt8.self), size: 4)
}
switch written {
case .processed(let numBytes):
XCTAssertEqual(4, numBytes)
default:
XCTFail()
}
try channel.eventLoop.submit {
// Dummy task execution to give some time for an actual read to open (which should not happen as we closed the input).
}.wait()
}
func testHalfClosure() throws {
let group = MultiThreadedEventLoopGroup(numThreads: 1)
defer {
try! group.syncShutdownGracefully()
}
let server = try ServerSocket(protocolFamily: PF_INET)
defer {
XCTAssertNoThrow(try server.close())
}
try server.bind(to: SocketAddress.newAddressResolving(host: "127.0.0.1", port: 0))
try server.listen()
let future = ClientBootstrap(group: group)
.channelInitializer { channel in
return channel.pipeline.add(handler: ShutdownVerificationHandler(inputShutdown: true, outputShutdown: false))
}
.channelOption(ChannelOptions.allowRemoteHalfClosure, value: true)
.connect(to: server.localAddress!)
let accepted = try server.accept()!
defer {
XCTAssertNoThrow(try accepted.close())
}
let channel = try future.wait()
defer {
XCTAssertNoThrow(try channel.close(mode: .all).wait())
}
try accepted.shutdown(how: .WR)
var buffer = channel.allocator.buffer(capacity: 12)
buffer.write(string: "1234")
try channel.writeAndFlush(data: NIOAny(buffer)).wait()
}
private class ShutdownVerificationHandler: ChannelInboundHandler {
typealias InboundIn = ByteBuffer
private var inputShutdownEventReceived = false
private var outputShutdownEventReceived = false
private let inputShutdown: Bool
private let outputShutdown: Bool
init(inputShutdown: Bool, outputShutdown: Bool) {
self.inputShutdown = inputShutdown
self.outputShutdown = outputShutdown
}
public func userInboundEventTriggered(ctx: ChannelHandlerContext, event: Any) {
switch event {
case let ev as ChannelEvent:
switch ev {
case .inputClosed:
XCTAssertFalse(inputShutdownEventReceived)
inputShutdownEventReceived = true
case .outputClosed:
XCTAssertFalse(outputShutdownEventReceived)
outputShutdownEventReceived = true
}
fallthrough
default:
ctx.fireUserInboundEventTriggered(event: event)
}
}
public func channelInactive(ctx: ChannelHandlerContext) {
XCTAssertEqual(inputShutdown, inputShutdownEventReceived)
XCTAssertEqual(outputShutdown, outputShutdownEventReceived)
}
}
}

View File

@ -251,38 +251,6 @@ class EchoServerClientTest : XCTestCase {
try countingHandler.assertReceived(buffer: buffer)
}
private final class ByteCountingHandler : ChannelInboundHandler {
typealias InboundIn = ByteBuffer
private let numBytes: Int
private let promise: EventLoopPromise<ByteBuffer>
private var buffer: ByteBuffer!
init(numBytes: Int, promise: EventLoopPromise<ByteBuffer>) {
self.numBytes = numBytes
self.promise = promise
}
func handlerAdded(ctx: ChannelHandlerContext) {
buffer = ctx.channel!.allocator.buffer(capacity: numBytes)
}
func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
var currentBuffer = self.unwrapInboundIn(data)
buffer.write(buffer: &currentBuffer)
if buffer.readableBytes == numBytes {
// Do something
promise.succeed(result: buffer)
}
}
func assertReceived(buffer: ByteBuffer) throws {
let received = try promise.futureResult.wait()
XCTAssertEqual(buffer, received)
}
}
private final class ChannelActiveHandler: ChannelInboundHandler {
typealias InboundIn = ByteBuffer
@ -447,7 +415,6 @@ class EchoServerClientTest : XCTestCase {
}
func testCloseInInactive() throws {
let group = MultiThreadedEventLoopGroup(numThreads: 1)
defer {
XCTAssertNoThrow(try group.syncShutdownGracefully())

View File

@ -111,40 +111,6 @@ class FileRegionTest : XCTestCase {
try futures.forEach { try $0.wait() }
}
private final class ByteCountingHandler : ChannelInboundHandler {
typealias InboundIn = ByteBuffer
private let numBytes: Int
private let promise: EventLoopPromise<ByteBuffer>
private var buffer: ByteBuffer!
init(numBytes: Int, promise: EventLoopPromise<ByteBuffer>) {
self.numBytes = numBytes
self.promise = promise
}
func handlerAdded(ctx: ChannelHandlerContext) {
buffer = ctx.channel!.allocator.buffer(capacity: numBytes)
if self.numBytes == 0 {
self.promise.succeed(result: buffer)
}
}
func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
var currentBuffer = self.unwrapInboundIn(data)
buffer.write(buffer: &currentBuffer)
if buffer.readableBytes == numBytes {
promise.succeed(result: buffer)
}
}
func assertReceived(buffer: ByteBuffer) throws {
let received = try promise.futureResult.wait()
XCTAssertEqual(buffer, received)
}
}
func testOutstandingFileRegionsWork() throws {
let group = MultiThreadedEventLoopGroup(numThreads: 1)
defer {

View File

@ -0,0 +1,50 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import NIO
import XCTest
final class ByteCountingHandler : ChannelInboundHandler {
typealias InboundIn = ByteBuffer
private let numBytes: Int
private let promise: EventLoopPromise<ByteBuffer>
private var buffer: ByteBuffer!
init(numBytes: Int, promise: EventLoopPromise<ByteBuffer>) {
self.numBytes = numBytes
self.promise = promise
}
func handlerAdded(ctx: ChannelHandlerContext) {
buffer = ctx.channel!.allocator.buffer(capacity: numBytes)
if self.numBytes == 0 {
self.promise.succeed(result: buffer)
}
}
func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
var currentBuffer = self.unwrapInboundIn(data)
buffer.write(buffer: &currentBuffer)
if buffer.readableBytes == numBytes {
promise.succeed(result: buffer)
}
}
func assertReceived(buffer: ByteBuffer) throws {
let received = try promise.futureResult.wait()
XCTAssertEqual(buffer, received)
}
}