Add support for closing only the output / input side of the Channel.
This commit is contained in:
parent
cf790be04f
commit
4d8cf1ff22
|
@ -462,8 +462,11 @@ private struct PendingWritesState {
|
|||
}
|
||||
|
||||
/// Fail all the outstanding writes. This is useful if for example the `Channel` is closed.
|
||||
func failAll(error: Error) {
|
||||
self.closed = true
|
||||
func failAll(error: Error, close: Bool) {
|
||||
if close {
|
||||
assert(!self.closed)
|
||||
self.closed = true
|
||||
}
|
||||
|
||||
self.state.failAll(error: error)()
|
||||
|
||||
|
@ -527,6 +530,9 @@ final class SocketChannel: BaseSocketChannel<Socket> {
|
|||
|
||||
private var connectTimeout = TimeAmount.seconds(10)
|
||||
private var connectTimeoutScheduled: Scheduled<Void>?
|
||||
private var allowRemoteHalfClosure: Bool = false
|
||||
private var inputShutdown: Bool = false
|
||||
private var outputShutdown: Bool = false
|
||||
|
||||
init(eventLoop: SelectableEventLoop, protocolFamily: Int32) throws {
|
||||
let socket = try Socket(protocolFamily: protocolFamily)
|
||||
|
@ -544,6 +550,8 @@ final class SocketChannel: BaseSocketChannel<Socket> {
|
|||
switch option {
|
||||
case _ as ConnectTimeoutOption:
|
||||
connectTimeout = value as! TimeAmount
|
||||
case _ as AllowRemoteHalfClosureOption:
|
||||
allowRemoteHalfClosure = value as! Bool
|
||||
default:
|
||||
try super.setOption0(option: option, value: value)
|
||||
}
|
||||
|
@ -554,6 +562,8 @@ final class SocketChannel: BaseSocketChannel<Socket> {
|
|||
switch option {
|
||||
case _ as ConnectTimeoutOption:
|
||||
return connectTimeout as! T.OptionType
|
||||
case _ as AllowRemoteHalfClosureOption:
|
||||
return allowRemoteHalfClosure as! T.OptionType
|
||||
default:
|
||||
return try super.getOption0(option: option)
|
||||
}
|
||||
|
@ -578,7 +588,7 @@ final class SocketChannel: BaseSocketChannel<Socket> {
|
|||
var buffer = recvAllocator.buffer(allocator: allocator)
|
||||
var result = ReadResult.none
|
||||
for i in 1...maxMessagesPerRead {
|
||||
if closed {
|
||||
if closed || inputShutdown {
|
||||
return result
|
||||
}
|
||||
switch try buffer.withMutableWritePointer(body: self.socket.read(pointer:size:)) {
|
||||
|
@ -600,6 +610,12 @@ final class SocketChannel: BaseSocketChannel<Socket> {
|
|||
}
|
||||
result = .some
|
||||
} else {
|
||||
if inputShutdown {
|
||||
// We received a EOF because we called shutdown on the fd by ourself, unregister from the Selector and return
|
||||
readPending = false
|
||||
unregisterForReadable()
|
||||
return result
|
||||
}
|
||||
// end-of-file
|
||||
throw ChannelError.eof
|
||||
}
|
||||
|
@ -659,7 +675,7 @@ final class SocketChannel: BaseSocketChannel<Socket> {
|
|||
connectTimeoutScheduled = eventLoop.scheduleTask(in: timeout) { () -> (Void) in
|
||||
if self.pendingConnect != nil {
|
||||
// The connection was still not established, close the Channel which will also fail the pending promise.
|
||||
self.close0(error: ChannelError.connectTimeout(timeout), promise: nil)
|
||||
self.close0(error: ChannelError.connectTimeout(timeout), mode: .all, promise: nil)
|
||||
}
|
||||
}
|
||||
return false
|
||||
|
@ -675,13 +691,76 @@ final class SocketChannel: BaseSocketChannel<Socket> {
|
|||
becomeActive0()
|
||||
}
|
||||
|
||||
override func close0(error: Error, promise: EventLoopPromise<Void>?) {
|
||||
if let timeout = connectTimeoutScheduled {
|
||||
connectTimeoutScheduled = nil
|
||||
timeout.cancel()
|
||||
override func close0(error: Error, mode: CloseMode, promise: EventLoopPromise<Void>?) {
|
||||
do {
|
||||
switch mode {
|
||||
case .output:
|
||||
if outputShutdown {
|
||||
promise?.fail(error: ChannelError.outputClosed)
|
||||
return
|
||||
}
|
||||
try socket.shutdown(how: .WR)
|
||||
outputShutdown = true
|
||||
// Fail all pending writes and so ensure all pending promises are notified
|
||||
pendingWrites.failAll(error: error, close: false)
|
||||
unregisterForWritable()
|
||||
promise?.succeed(result: ())
|
||||
|
||||
pipeline.fireUserInboundEventTriggered(event: ChannelEvent.outputClosed)
|
||||
|
||||
case .input:
|
||||
if inputShutdown {
|
||||
promise?.fail(error: ChannelError.inputClosed)
|
||||
return
|
||||
}
|
||||
switch error {
|
||||
case ChannelError.eof:
|
||||
// No need to explicit call socket.shutdown(...) as we received an EOF and the call would only cause
|
||||
// ENOTCON
|
||||
break
|
||||
default:
|
||||
try socket.shutdown(how: .RD)
|
||||
}
|
||||
inputShutdown = true
|
||||
unregisterForReadable()
|
||||
promise?.succeed(result: ())
|
||||
|
||||
pipeline.fireUserInboundEventTriggered(event: ChannelEvent.inputClosed)
|
||||
case .all:
|
||||
if let timeout = connectTimeoutScheduled {
|
||||
connectTimeoutScheduled = nil
|
||||
timeout.cancel()
|
||||
}
|
||||
super.close0(error: error, mode: mode, promise: promise)
|
||||
}
|
||||
} catch let err {
|
||||
promise?.fail(error: err)
|
||||
}
|
||||
super.close0(error: error, promise: promise)
|
||||
}
|
||||
|
||||
@discardableResult override func readIfNeeded0() -> Bool {
|
||||
if inputShutdown {
|
||||
return false
|
||||
}
|
||||
return super.readIfNeeded0()
|
||||
}
|
||||
|
||||
override public func read0(promise: EventLoopPromise<Void>?) {
|
||||
if inputShutdown {
|
||||
promise?.fail(error: ChannelError.inputClosed)
|
||||
return
|
||||
}
|
||||
super.read0(promise: promise)
|
||||
}
|
||||
|
||||
override public func write0(data: IOData, promise: EventLoopPromise<Void>?) {
|
||||
if outputShutdown {
|
||||
promise?.fail(error: ChannelError.outputClosed)
|
||||
return
|
||||
}
|
||||
super.write0(data: data, promise: promise)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// A `Channel` for a server socket.
|
||||
|
@ -773,7 +852,7 @@ final class ServerSocketChannel : BaseSocketChannel<ServerSocket> {
|
|||
}
|
||||
|
||||
override fileprivate func writeToSocket(pendingWrites: PendingWritesManager) throws -> WriteResult {
|
||||
pendingWrites.failAll(error: ChannelError.operationUnsupported)
|
||||
pendingWrites.failAll(error: ChannelError.operationUnsupported, close: false)
|
||||
return .writtenCompletely
|
||||
}
|
||||
|
||||
|
@ -805,7 +884,7 @@ public protocol ChannelCore : class {
|
|||
func write0(data: IOData, promise: EventLoopPromise<Void>?)
|
||||
func flush0(promise: EventLoopPromise<Void>?)
|
||||
func read0(promise: EventLoopPromise<Void>?)
|
||||
func close0(error: Error, promise: EventLoopPromise<Void>?)
|
||||
func close0(error: Error, mode: CloseMode, promise: EventLoopPromise<Void>?)
|
||||
func triggerUserOutboundEvent0(event: Any, promise: EventLoopPromise<Void>?)
|
||||
func channelRead0(data: NIOAny)
|
||||
func errorCaught0(error: Error)
|
||||
|
@ -920,8 +999,8 @@ extension Channel {
|
|||
pipeline.read(promise: promise)
|
||||
}
|
||||
|
||||
public func close(promise: EventLoopPromise<Void>?) {
|
||||
pipeline.close(promise: promise)
|
||||
public func close(mode: CloseMode = .all, promise: EventLoopPromise<Void>?) {
|
||||
pipeline.close(mode: mode, promise: promise)
|
||||
}
|
||||
|
||||
public func register(promise: EventLoopPromise<Void>?) {
|
||||
|
@ -953,12 +1032,13 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
|||
let socket: T
|
||||
public var interestedEvent: IOEvent = .none
|
||||
|
||||
/// `true` if the whole `Channel` is closed and so no more IO operation can be done.
|
||||
public final var closed: Bool {
|
||||
assert(eventLoop.inEventLoop)
|
||||
return pendingWrites.closed
|
||||
}
|
||||
|
||||
private let pendingWrites: PendingWritesManager
|
||||
fileprivate let pendingWrites: PendingWritesManager
|
||||
fileprivate var readPending = false
|
||||
private var neverRegistered = true
|
||||
fileprivate var pendingConnect: EventLoopPromise<Void>?
|
||||
|
@ -1089,7 +1169,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
|||
/// Triggers a `ChannelPipeline.read()` if `autoRead` is enabled.`
|
||||
///
|
||||
/// - returns: `true` if `readPending` is `true`, `false` otherwise.
|
||||
@discardableResult final func readIfNeeded0() -> Bool {
|
||||
@discardableResult func readIfNeeded0() -> Bool {
|
||||
assert(eventLoop.inEventLoop)
|
||||
|
||||
if !readPending && autoRead {
|
||||
|
@ -1136,9 +1216,8 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
|||
}
|
||||
}
|
||||
|
||||
private func unregisterForWritable() {
|
||||
fileprivate func unregisterForWritable() {
|
||||
assert(eventLoop.inEventLoop)
|
||||
|
||||
switch interestedEvent {
|
||||
case .all:
|
||||
safeReregister(interested: .read)
|
||||
|
@ -1165,7 +1244,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
|||
}
|
||||
}
|
||||
|
||||
public final func read0(promise: EventLoopPromise<Void>?) {
|
||||
public func read0(promise: EventLoopPromise<Void>?) {
|
||||
assert(eventLoop.inEventLoop)
|
||||
|
||||
if closed {
|
||||
|
@ -1201,7 +1280,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
|||
}
|
||||
}
|
||||
|
||||
private func unregisterForReadable() {
|
||||
fileprivate func unregisterForReadable() {
|
||||
assert(eventLoop.inEventLoop)
|
||||
|
||||
switch interestedEvent {
|
||||
|
@ -1214,7 +1293,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
|||
}
|
||||
}
|
||||
|
||||
public func close0(error: Error, promise: EventLoopPromise<Void>?) {
|
||||
public func close0(error: Error, mode: CloseMode, promise: EventLoopPromise<Void>?) {
|
||||
assert(eventLoop.inEventLoop)
|
||||
|
||||
if closed {
|
||||
|
@ -1222,6 +1301,11 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
|||
return
|
||||
}
|
||||
|
||||
guard mode == .all else {
|
||||
promise?.fail(error: ChannelError.operationUnsupported)
|
||||
return
|
||||
}
|
||||
|
||||
interestedEvent = .none
|
||||
do {
|
||||
try selectableEventLoop.deregister(channel: self)
|
||||
|
@ -1237,7 +1321,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
|||
}
|
||||
|
||||
// Fail all pending writes and so ensure all pending promises are notified
|
||||
self.pendingWrites.failAll(error: error)
|
||||
self.pendingWrites.failAll(error: error, close: true)
|
||||
|
||||
becomeInactive0()
|
||||
|
||||
|
@ -1323,25 +1407,30 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
|||
do {
|
||||
try readFromSocket()
|
||||
} catch let err {
|
||||
if let channelErr = err as? ChannelError {
|
||||
// EOF is not really an error that should be forwarded to the user
|
||||
if channelErr != ChannelError.eof {
|
||||
pipeline.fireErrorCaught0(error: err)
|
||||
}
|
||||
// ChannelError.eof is not something we want to fire through the pipeline as it just means the remote
|
||||
// per closed / shutdown the connection.
|
||||
if let channelErr = err as? ChannelError, channelErr != ChannelError.eof {
|
||||
pipeline.fireErrorCaught0(error: err)
|
||||
} else if try! getOption(option: ChannelOptions.allowRemoteHalfClosure) {
|
||||
// If we want to allow half closure we will just mark the input side of the Channel
|
||||
// as closed.
|
||||
pipeline.fireChannelReadComplete0()
|
||||
close0(error: err, mode: .input, promise: nil)
|
||||
readPending = false
|
||||
return
|
||||
} else {
|
||||
pipeline.fireErrorCaught0(error: err)
|
||||
}
|
||||
|
||||
|
||||
// Call before triggering the close of the Channel.
|
||||
pipeline.fireChannelReadComplete0()
|
||||
close0(error: err, promise: nil)
|
||||
|
||||
close0(error: err, mode: .all, promise: nil)
|
||||
return
|
||||
}
|
||||
pipeline.fireChannelReadComplete0()
|
||||
readIfNeeded0()
|
||||
}
|
||||
|
||||
|
||||
fileprivate func connectSocket(to address: SocketAddress) throws -> Bool {
|
||||
fatalError("this must be overridden by sub class")
|
||||
}
|
||||
|
@ -1424,7 +1513,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
|||
try selectableEventLoop.reregister(channel: self)
|
||||
} catch let err {
|
||||
pipeline.fireErrorCaught0(error: err)
|
||||
close0(error: err, promise: nil)
|
||||
close0(error: err, mode: .all, promise: nil)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1441,7 +1530,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
|||
return true
|
||||
} catch let err {
|
||||
pipeline.fireErrorCaught0(error: err)
|
||||
close0(error: err, promise: nil)
|
||||
close0(error: err, mode: .all, promise: nil)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
@ -1485,7 +1574,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
|
|||
}
|
||||
}
|
||||
|
||||
close0(error: err, promise: nil)
|
||||
close0(error: err, mode: .all, promise: nil)
|
||||
|
||||
// we handled all writes
|
||||
return true
|
||||
|
@ -1537,6 +1626,12 @@ public enum ChannelError: Error {
|
|||
|
||||
/// Close was called on a channel that is already closed.
|
||||
case alreadyClosed
|
||||
|
||||
/// Output-side of the channel is closed.
|
||||
case outputClosed
|
||||
|
||||
/// Input-side of the channel is closed.
|
||||
case inputClosed
|
||||
|
||||
/// A read operation reached end-of-file. This usually means the remote peer closed the socket but it's still
|
||||
/// open locally.
|
||||
|
@ -1556,6 +1651,10 @@ extension ChannelError: Equatable {
|
|||
return true
|
||||
case (.alreadyClosed, .alreadyClosed):
|
||||
return true
|
||||
case (.outputClosed, .outputClosed):
|
||||
return true
|
||||
case (.inputClosed, .inputClosed):
|
||||
return true
|
||||
case (.eof, .eof):
|
||||
return true
|
||||
default:
|
||||
|
@ -1564,3 +1663,11 @@ extension ChannelError: Equatable {
|
|||
}
|
||||
}
|
||||
|
||||
/// An `Channel` related event that is passed through the `ChannelPipeline` to notify the user.
|
||||
public enum ChannelEvent: Equatable {
|
||||
/// `ChannelOptions.allowRemoteHalfClosure` is `true` and input portion of the `Channel` was closed.
|
||||
case inputClosed
|
||||
/// Output portion of the `Channel` was closed.
|
||||
case outputClosed
|
||||
}
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ public protocol _ChannelOutboundHandler : ChannelHandler {
|
|||
func flush(ctx: ChannelHandlerContext, promise: EventLoopPromise<Void>?)
|
||||
// TODO: Think about make this more flexible in terms of influence the allocation that is used to read the next amount of data
|
||||
func read(ctx: ChannelHandlerContext, promise: EventLoopPromise<Void>?)
|
||||
func close(ctx: ChannelHandlerContext, promise: EventLoopPromise<Void>?)
|
||||
func close(ctx: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise<Void>?)
|
||||
func triggerUserOutboundEvent(ctx: ChannelHandlerContext, event: Any, promise: EventLoopPromise<Void>?)
|
||||
}
|
||||
|
||||
|
@ -79,8 +79,8 @@ extension _ChannelOutboundHandler {
|
|||
ctx.read(promise: promise)
|
||||
}
|
||||
|
||||
public func close(ctx: ChannelHandlerContext, promise: EventLoopPromise<Void>?) {
|
||||
ctx.close(promise: promise)
|
||||
public func close(ctx: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise<Void>?) {
|
||||
ctx.close(mode: mode, promise: promise)
|
||||
}
|
||||
|
||||
public func triggerUserOutboundEvent(ctx: ChannelHandlerContext, event: Any, promise: EventLoopPromise<Void>?) {
|
||||
|
|
|
@ -121,15 +121,18 @@ public protocol ChannelOutboundInvoker {
|
|||
|
||||
/// Close the `Channel` and so the connection if one exists.
|
||||
///
|
||||
/// - parameters:
|
||||
/// - mode: the `CloseMode` that is used
|
||||
/// - returns: the future which will be notified once the operation completes.
|
||||
func close() -> EventLoopFuture<Void>
|
||||
func close(mode: CloseMode) -> EventLoopFuture<Void>
|
||||
|
||||
/// Close the `Channel` and so the connection if one exists.
|
||||
///
|
||||
/// - parameters:
|
||||
/// - mode: the `CloseMode` that is used
|
||||
/// - promise: the `EventLoopPromise` that will be notified once the operation completes,
|
||||
/// or `nil` if not interested in the outcome of the operation.
|
||||
func close(promise: EventLoopPromise<Void>?)
|
||||
func close(mode: CloseMode, promise: EventLoopPromise<Void>?)
|
||||
|
||||
/// Trigger a custom user outbound event which will flow through the `ChannelPipeline`.
|
||||
///
|
||||
|
@ -192,9 +195,9 @@ extension ChannelOutboundInvoker {
|
|||
return promise.futureResult
|
||||
}
|
||||
|
||||
public func close() -> EventLoopFuture<Void> {
|
||||
public func close(mode: CloseMode = .all) -> EventLoopFuture<Void> {
|
||||
let promise = newPromise()
|
||||
close(promise: promise)
|
||||
close(mode: mode, promise: promise)
|
||||
return promise.futureResult
|
||||
}
|
||||
|
||||
|
@ -267,3 +270,17 @@ public protocol ChannelInboundInvoker {
|
|||
|
||||
/// A protocol that signals that outbound and inbound events are triggered by this invoker.
|
||||
public protocol ChannelInvoker : ChannelOutboundInvoker, ChannelInboundInvoker { }
|
||||
|
||||
/// Specify what kind of close operation is requested.
|
||||
public enum CloseMode {
|
||||
/// Close the output (writing) side of the `Channel` without closing the actual file descriptor.
|
||||
/// This is an optional mode which means it may not be supported by all `Channel` implementations.
|
||||
case output
|
||||
|
||||
/// Close the input (reading) side of the `Channel` without closing the actual file descriptor.
|
||||
/// This is an optional mode which means it may not be supported by all `Channel` implementations.
|
||||
case input
|
||||
|
||||
/// Close the whole `Channel (file descriptor).
|
||||
case all
|
||||
}
|
||||
|
|
|
@ -147,6 +147,17 @@ public enum ConnectTimeoutOption: ChannelOption {
|
|||
case const(())
|
||||
}
|
||||
|
||||
/// `AllowRemoteHalfClosureOption` allows users to configure whether the `Channel` will close itself when its remote
|
||||
/// peer shuts down its send stream, or whether it will remain open. If set to `false` (the default), the `Channel`
|
||||
/// will be closed automatically if the remote peer shuts down its send stream. If set to true, the `Channel` will
|
||||
/// not be closed: instead, a `ChannelEvent.inboundClosed` user event will be sent on the `ChannelPipeline`,
|
||||
/// and no more data will be received.
|
||||
public enum AllowRemoteHalfClosureOption: ChannelOption {
|
||||
public typealias AssociatedValueType = ()
|
||||
public typealias OptionType = Bool
|
||||
|
||||
case const(())
|
||||
}
|
||||
|
||||
/// Provides `ChannelOption`s to be used with a `Channel`, `Bootstrap` or `ServerBootstrap`.
|
||||
public struct ChannelOptions {
|
||||
|
@ -176,4 +187,7 @@ public struct ChannelOptions {
|
|||
|
||||
/// - seealso: `ConnectTimeoutOption`.
|
||||
public static let connectTimeout = ConnectTimeoutOption.const(())
|
||||
|
||||
/// - seealso: `AllowRemoteHalfClosureOption`.
|
||||
public static let allowRemoteHalfClosure = AllowRemoteHalfClosureOption.const(())
|
||||
}
|
||||
|
|
|
@ -478,12 +478,12 @@ public final class ChannelPipeline : ChannelInvoker {
|
|||
}
|
||||
}
|
||||
|
||||
public func close(promise: EventLoopPromise<Void>?) {
|
||||
public func close(mode: CloseMode = .all, promise: EventLoopPromise<Void>?) {
|
||||
if eventLoop.inEventLoop {
|
||||
close0(promise: promise)
|
||||
close0(mode: mode, promise: promise)
|
||||
} else {
|
||||
eventLoop.execute {
|
||||
self.close0(promise: promise)
|
||||
self.close0(mode: mode, promise: promise)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -578,9 +578,9 @@ public final class ChannelPipeline : ChannelInvoker {
|
|||
return self.head?.inboundNext
|
||||
}
|
||||
|
||||
func close0(promise: EventLoopPromise<Void>?) {
|
||||
func close0(mode: CloseMode, promise: EventLoopPromise<Void>?) {
|
||||
if let firstOutboundCtx = firstOutboundCtx {
|
||||
firstOutboundCtx.invokeClose(promise: promise)
|
||||
firstOutboundCtx.invokeClose(mode: mode, promise: promise)
|
||||
} else {
|
||||
promise?.fail(error: ChannelError.alreadyClosed)
|
||||
}
|
||||
|
@ -768,9 +768,9 @@ private final class HeadChannelHandler : _ChannelOutboundHandler {
|
|||
}
|
||||
}
|
||||
|
||||
func close(ctx: ChannelHandlerContext, promise: EventLoopPromise<Void>?) {
|
||||
func close(ctx: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise<Void>?) {
|
||||
if let channel = ctx.channel {
|
||||
channel._unsafe.close0(error: ChannelError.alreadyClosed, promise: promise)
|
||||
channel._unsafe.close0(error: mode.error, mode: mode, promise: promise)
|
||||
} else {
|
||||
promise?.fail(error: ChannelError.alreadyClosed)
|
||||
}
|
||||
|
@ -791,6 +791,20 @@ private final class HeadChannelHandler : _ChannelOutboundHandler {
|
|||
promise?.fail(error: ChannelError.ioOnClosedChannel)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private extension CloseMode {
|
||||
var error: ChannelError {
|
||||
switch self {
|
||||
case .all:
|
||||
return ChannelError.alreadyClosed
|
||||
case .output:
|
||||
return ChannelError.outputClosed
|
||||
case .input:
|
||||
return ChannelError.inputClosed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Special `ChannelInboundHandler` which will consume all inbound events.
|
||||
|
@ -1070,10 +1084,11 @@ public final class ChannelHandlerContext : ChannelInvoker {
|
|||
/// When the `close` event reaches the `HeadChannelHandler` the socket will be closed.
|
||||
///
|
||||
/// - parameters:
|
||||
/// - mode: The `CloseMode` to use.
|
||||
/// - promise: The promise fulfilled when the `Channel` has been closed or failed if it the closing failed.
|
||||
public func close(promise: EventLoopPromise<Void>?) {
|
||||
public func close(mode: CloseMode = .all, promise: EventLoopPromise<Void>?) {
|
||||
if let outboundNext = outboundNext {
|
||||
outboundNext.invokeClose(promise: promise)
|
||||
outboundNext.invokeClose(mode: mode, promise: promise)
|
||||
} else {
|
||||
promise?.fail(error: ChannelError.alreadyClosed)
|
||||
}
|
||||
|
@ -1256,11 +1271,11 @@ public final class ChannelHandlerContext : ChannelInvoker {
|
|||
self.outboundHandler.read(ctx: self, promise: promise)
|
||||
}
|
||||
|
||||
fileprivate func invokeClose(promise: EventLoopPromise<Void>?) {
|
||||
fileprivate func invokeClose(mode: CloseMode, promise: EventLoopPromise<Void>?) {
|
||||
assert(inEventLoop)
|
||||
assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled")
|
||||
|
||||
self.outboundHandler.close(ctx: self, promise: promise)
|
||||
self.outboundHandler.close(ctx: self, mode: mode, promise: promise)
|
||||
}
|
||||
|
||||
fileprivate func invokeTriggerUserOutboundEvent(event: Any, promise: EventLoopPromise<Void>?) {
|
||||
|
|
|
@ -173,7 +173,7 @@ class EmbeddedChannelCore : ChannelCore {
|
|||
var outboundBuffer: [IOData] = []
|
||||
var inboundBuffer: [NIOAny] = []
|
||||
|
||||
func close0(error: Error, promise: EventLoopPromise<Void>?) {
|
||||
func close0(error: Error, mode: CloseMode, promise: EventLoopPromise<Void>?) {
|
||||
if closed {
|
||||
promise?.fail(error: ChannelError.alreadyClosed)
|
||||
return
|
||||
|
|
|
@ -116,4 +116,11 @@ final class Socket : BaseSocket {
|
|||
return try LinuxSocket.sendmmsg(sockfd: self.descriptor, msgvec: msgs.baseAddress!, vlen: CUnsignedInt(msgs.count), flags: 0)
|
||||
}
|
||||
#endif
|
||||
|
||||
func shutdown(how: Shutdown) throws {
|
||||
guard self.open else {
|
||||
throw IOError(errnoCode: EBADF, reason: "can't shutdown socket as it's not open anymore.")
|
||||
}
|
||||
try Posix.shutdown(descriptor: self.descriptor, how: how)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -76,6 +76,23 @@ internal func wrapSyscall<T: FixedWidthInteger>(_ fn: () throws -> T, where func
|
|||
}
|
||||
}
|
||||
|
||||
enum Shutdown {
|
||||
case RD
|
||||
case WR
|
||||
case RDWR
|
||||
|
||||
fileprivate var cValue: CInt {
|
||||
switch self {
|
||||
case .RD:
|
||||
return CInt(SHUT_RD)
|
||||
case .WR:
|
||||
return CInt(SHUT_WR)
|
||||
case .RDWR:
|
||||
return CInt(SHUT_RDWR)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal enum Posix {
|
||||
#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS)
|
||||
static let SOCK_STREAM: CInt = CInt(Darwin.SOCK_STREAM)
|
||||
|
@ -91,7 +108,18 @@ internal enum Posix {
|
|||
fatalError("unsupported OS")
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@inline(never)
|
||||
public static func shutdown(descriptor: Int32, how: Shutdown) throws {
|
||||
_ = try wrapSyscall({ () -> Int in
|
||||
#if os(Linux)
|
||||
return Int(Glibc.shutdown(descriptor, how.cValue))
|
||||
#else
|
||||
return Int(Darwin.shutdown(descriptor, how.cValue))
|
||||
#endif
|
||||
})
|
||||
}
|
||||
|
||||
@inline(never)
|
||||
public static func close(descriptor: Int32) throws {
|
||||
_ = try wrapSyscall({ () -> Int in
|
||||
|
|
|
@ -140,7 +140,13 @@ public class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandler {
|
|||
doUnbufferWrites(ctx: ctx)
|
||||
}
|
||||
|
||||
public func close(ctx: ChannelHandlerContext, promise: EventLoopPromise<Void>?) {
|
||||
public func close(ctx: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise<Void>?) {
|
||||
guard mode == .all else {
|
||||
// TODO: Support also other modes ?
|
||||
promise?.fail(error: ChannelError.operationUnsupported)
|
||||
return
|
||||
}
|
||||
|
||||
switch state {
|
||||
case .closing:
|
||||
// We're in the process of TLS shutdown, so let's let that happen. However,
|
||||
|
|
|
@ -47,6 +47,9 @@ extension ChannelTests {
|
|||
("testPendingWritesMoreThanWritevIOVectorLimit", testPendingWritesMoreThanWritevIOVectorLimit),
|
||||
("testPendingWritesIsHappyWhenSendfileReturnsWouldBlockButWroteFully", testPendingWritesIsHappyWhenSendfileReturnsWouldBlockButWroteFully),
|
||||
("testConnectTimeout", testConnectTimeout),
|
||||
("testCloseOutput", testCloseOutput),
|
||||
("testCloseInput", testCloseInput),
|
||||
("testHalfClosure", testHalfClosure),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
|
@ -613,7 +613,7 @@ public class ChannelTests: XCTestCase {
|
|||
promiseStates: [[false, false, false], [false, false, false]])
|
||||
XCTAssertEqual(WriteResult.wouldBlock, result)
|
||||
|
||||
pwm.failAll(error: ChannelError.operationUnsupported)
|
||||
pwm.failAll(error: ChannelError.operationUnsupported, close: true)
|
||||
|
||||
XCTAssertTrue(ps.map { $0.futureResult.fulfilled }.reduce(true) { $0 && $1 })
|
||||
}
|
||||
|
@ -939,7 +939,7 @@ public class ChannelTests: XCTestCase {
|
|||
_ = pwm.add(data: .byteBuffer(buffer), promise: ps[2])
|
||||
|
||||
ps[0].futureResult.whenComplete { _ in
|
||||
pwm.failAll(error: ChannelError.eof)
|
||||
pwm.failAll(error: ChannelError.inputClosed, close: true)
|
||||
}
|
||||
|
||||
let result = try assertExpectedWritability(pendingWritesManager: pwm,
|
||||
|
@ -1031,4 +1031,194 @@ public class ChannelTests: XCTestCase {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testCloseOutput() throws {
|
||||
let group = MultiThreadedEventLoopGroup(numThreads: 1)
|
||||
defer {
|
||||
XCTAssertNoThrow(try group.syncShutdownGracefully())
|
||||
}
|
||||
|
||||
let server = try ServerSocket(protocolFamily: PF_INET)
|
||||
defer {
|
||||
XCTAssertNoThrow(try server.close())
|
||||
}
|
||||
try server.bind(to: SocketAddress.newAddressResolving(host: "127.0.0.1", port: 0))
|
||||
try server.listen()
|
||||
|
||||
let byteCountingHandler = ByteCountingHandler(numBytes: 4, promise: group.next().newPromise())
|
||||
|
||||
let future = ClientBootstrap(group: group)
|
||||
.channelInitializer { channel in
|
||||
return channel.pipeline.add(handler: ShutdownVerificationHandler(inputShutdown: false, outputShutdown: true)).then { _ in
|
||||
return channel.pipeline.add(handler: byteCountingHandler)
|
||||
}
|
||||
}
|
||||
.connect(to: server.localAddress!)
|
||||
let accepted = try server.accept()!
|
||||
defer {
|
||||
XCTAssertNoThrow(try accepted.close())
|
||||
}
|
||||
|
||||
let channel = try future.wait()
|
||||
defer {
|
||||
XCTAssertNoThrow(try channel.close(mode: .all).wait())
|
||||
}
|
||||
|
||||
var buffer = channel.allocator.buffer(capacity: 12)
|
||||
buffer.write(string: "1234")
|
||||
|
||||
try channel.writeAndFlush(data: NIOAny(buffer)).wait()
|
||||
try channel.close(mode: .output).wait()
|
||||
|
||||
do {
|
||||
try channel.writeAndFlush(data: NIOAny(buffer)).wait()
|
||||
XCTFail()
|
||||
} catch let err as ChannelError {
|
||||
XCTAssertEqual(ChannelError.outputClosed, err)
|
||||
}
|
||||
let written = try buffer.withUnsafeReadableBytes { p in
|
||||
try accepted.write(pointer: p.baseAddress!.assumingMemoryBound(to: UInt8.self), size: 4)
|
||||
}
|
||||
switch written {
|
||||
case .processed(let numBytes):
|
||||
XCTAssertEqual(4, numBytes)
|
||||
default:
|
||||
XCTFail()
|
||||
}
|
||||
try byteCountingHandler.assertReceived(buffer: buffer)
|
||||
}
|
||||
|
||||
func testCloseInput() throws {
|
||||
let group = MultiThreadedEventLoopGroup(numThreads: 1)
|
||||
defer {
|
||||
XCTAssertNoThrow(try group.syncShutdownGracefully())
|
||||
}
|
||||
|
||||
let server = try ServerSocket(protocolFamily: PF_INET)
|
||||
defer {
|
||||
XCTAssertNoThrow(try server.close())
|
||||
}
|
||||
try server.bind(to: SocketAddress.newAddressResolving(host: "127.0.0.1", port: 0))
|
||||
try server.listen()
|
||||
|
||||
class VerifyNoReadHandler : ChannelInboundHandler {
|
||||
typealias InboundIn = ByteBuffer
|
||||
|
||||
public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
|
||||
XCTFail("Received data: \(data)")
|
||||
}
|
||||
}
|
||||
|
||||
let future = ClientBootstrap(group: group)
|
||||
.channelInitializer { channel in
|
||||
return channel.pipeline.add(handler: VerifyNoReadHandler()).then { _ in
|
||||
return channel.pipeline.add(handler: ShutdownVerificationHandler(inputShutdown: true, outputShutdown: false))
|
||||
}
|
||||
}
|
||||
.connect(to: server.localAddress!)
|
||||
let accepted = try server.accept()!
|
||||
defer {
|
||||
XCTAssertNoThrow(try accepted.close())
|
||||
}
|
||||
|
||||
let channel = try future.wait()
|
||||
defer {
|
||||
XCTAssertNoThrow(try channel.close(mode: .all).wait())
|
||||
}
|
||||
|
||||
try channel.close(mode: .input).wait()
|
||||
|
||||
var buffer = channel.allocator.buffer(capacity: 12)
|
||||
buffer.write(string: "1234")
|
||||
|
||||
let written = try buffer.withUnsafeReadableBytes { p in
|
||||
try accepted.write(pointer: p.baseAddress!.assumingMemoryBound(to: UInt8.self), size: 4)
|
||||
}
|
||||
|
||||
switch written {
|
||||
case .processed(let numBytes):
|
||||
XCTAssertEqual(4, numBytes)
|
||||
default:
|
||||
XCTFail()
|
||||
}
|
||||
|
||||
try channel.eventLoop.submit {
|
||||
// Dummy task execution to give some time for an actual read to open (which should not happen as we closed the input).
|
||||
}.wait()
|
||||
}
|
||||
|
||||
func testHalfClosure() throws {
|
||||
let group = MultiThreadedEventLoopGroup(numThreads: 1)
|
||||
defer {
|
||||
try! group.syncShutdownGracefully()
|
||||
}
|
||||
|
||||
let server = try ServerSocket(protocolFamily: PF_INET)
|
||||
defer {
|
||||
XCTAssertNoThrow(try server.close())
|
||||
}
|
||||
try server.bind(to: SocketAddress.newAddressResolving(host: "127.0.0.1", port: 0))
|
||||
try server.listen()
|
||||
|
||||
let future = ClientBootstrap(group: group)
|
||||
.channelInitializer { channel in
|
||||
return channel.pipeline.add(handler: ShutdownVerificationHandler(inputShutdown: true, outputShutdown: false))
|
||||
}
|
||||
.channelOption(ChannelOptions.allowRemoteHalfClosure, value: true)
|
||||
.connect(to: server.localAddress!)
|
||||
let accepted = try server.accept()!
|
||||
defer {
|
||||
XCTAssertNoThrow(try accepted.close())
|
||||
}
|
||||
|
||||
let channel = try future.wait()
|
||||
defer {
|
||||
XCTAssertNoThrow(try channel.close(mode: .all).wait())
|
||||
}
|
||||
|
||||
try accepted.shutdown(how: .WR)
|
||||
|
||||
var buffer = channel.allocator.buffer(capacity: 12)
|
||||
buffer.write(string: "1234")
|
||||
|
||||
try channel.writeAndFlush(data: NIOAny(buffer)).wait()
|
||||
}
|
||||
|
||||
private class ShutdownVerificationHandler: ChannelInboundHandler {
|
||||
typealias InboundIn = ByteBuffer
|
||||
|
||||
private var inputShutdownEventReceived = false
|
||||
private var outputShutdownEventReceived = false
|
||||
|
||||
private let inputShutdown: Bool
|
||||
private let outputShutdown: Bool
|
||||
|
||||
init(inputShutdown: Bool, outputShutdown: Bool) {
|
||||
self.inputShutdown = inputShutdown
|
||||
self.outputShutdown = outputShutdown
|
||||
}
|
||||
|
||||
public func userInboundEventTriggered(ctx: ChannelHandlerContext, event: Any) {
|
||||
switch event {
|
||||
case let ev as ChannelEvent:
|
||||
switch ev {
|
||||
case .inputClosed:
|
||||
XCTAssertFalse(inputShutdownEventReceived)
|
||||
inputShutdownEventReceived = true
|
||||
case .outputClosed:
|
||||
XCTAssertFalse(outputShutdownEventReceived)
|
||||
outputShutdownEventReceived = true
|
||||
}
|
||||
|
||||
fallthrough
|
||||
default:
|
||||
ctx.fireUserInboundEventTriggered(event: event)
|
||||
}
|
||||
}
|
||||
|
||||
public func channelInactive(ctx: ChannelHandlerContext) {
|
||||
XCTAssertEqual(inputShutdown, inputShutdownEventReceived)
|
||||
XCTAssertEqual(outputShutdown, outputShutdownEventReceived)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -251,38 +251,6 @@ class EchoServerClientTest : XCTestCase {
|
|||
|
||||
try countingHandler.assertReceived(buffer: buffer)
|
||||
}
|
||||
|
||||
private final class ByteCountingHandler : ChannelInboundHandler {
|
||||
typealias InboundIn = ByteBuffer
|
||||
|
||||
private let numBytes: Int
|
||||
private let promise: EventLoopPromise<ByteBuffer>
|
||||
private var buffer: ByteBuffer!
|
||||
|
||||
init(numBytes: Int, promise: EventLoopPromise<ByteBuffer>) {
|
||||
self.numBytes = numBytes
|
||||
self.promise = promise
|
||||
}
|
||||
|
||||
func handlerAdded(ctx: ChannelHandlerContext) {
|
||||
buffer = ctx.channel!.allocator.buffer(capacity: numBytes)
|
||||
}
|
||||
|
||||
func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
|
||||
var currentBuffer = self.unwrapInboundIn(data)
|
||||
buffer.write(buffer: ¤tBuffer)
|
||||
|
||||
if buffer.readableBytes == numBytes {
|
||||
// Do something
|
||||
promise.succeed(result: buffer)
|
||||
}
|
||||
}
|
||||
|
||||
func assertReceived(buffer: ByteBuffer) throws {
|
||||
let received = try promise.futureResult.wait()
|
||||
XCTAssertEqual(buffer, received)
|
||||
}
|
||||
}
|
||||
|
||||
private final class ChannelActiveHandler: ChannelInboundHandler {
|
||||
typealias InboundIn = ByteBuffer
|
||||
|
@ -447,7 +415,6 @@ class EchoServerClientTest : XCTestCase {
|
|||
}
|
||||
|
||||
func testCloseInInactive() throws {
|
||||
|
||||
let group = MultiThreadedEventLoopGroup(numThreads: 1)
|
||||
defer {
|
||||
XCTAssertNoThrow(try group.syncShutdownGracefully())
|
||||
|
|
|
@ -111,40 +111,6 @@ class FileRegionTest : XCTestCase {
|
|||
try futures.forEach { try $0.wait() }
|
||||
}
|
||||
|
||||
private final class ByteCountingHandler : ChannelInboundHandler {
|
||||
typealias InboundIn = ByteBuffer
|
||||
|
||||
private let numBytes: Int
|
||||
private let promise: EventLoopPromise<ByteBuffer>
|
||||
private var buffer: ByteBuffer!
|
||||
|
||||
init(numBytes: Int, promise: EventLoopPromise<ByteBuffer>) {
|
||||
self.numBytes = numBytes
|
||||
self.promise = promise
|
||||
}
|
||||
|
||||
func handlerAdded(ctx: ChannelHandlerContext) {
|
||||
buffer = ctx.channel!.allocator.buffer(capacity: numBytes)
|
||||
if self.numBytes == 0 {
|
||||
self.promise.succeed(result: buffer)
|
||||
}
|
||||
}
|
||||
|
||||
func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
|
||||
var currentBuffer = self.unwrapInboundIn(data)
|
||||
buffer.write(buffer: ¤tBuffer)
|
||||
|
||||
if buffer.readableBytes == numBytes {
|
||||
promise.succeed(result: buffer)
|
||||
}
|
||||
}
|
||||
|
||||
func assertReceived(buffer: ByteBuffer) throws {
|
||||
let received = try promise.futureResult.wait()
|
||||
XCTAssertEqual(buffer, received)
|
||||
}
|
||||
}
|
||||
|
||||
func testOutstandingFileRegionsWork() throws {
|
||||
let group = MultiThreadedEventLoopGroup(numThreads: 1)
|
||||
defer {
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This source file is part of the SwiftNIO open source project
|
||||
//
|
||||
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
|
||||
// Licensed under Apache License v2.0
|
||||
//
|
||||
// See LICENSE.txt for license information
|
||||
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||
//
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
import NIO
|
||||
import XCTest
|
||||
|
||||
final class ByteCountingHandler : ChannelInboundHandler {
|
||||
typealias InboundIn = ByteBuffer
|
||||
|
||||
private let numBytes: Int
|
||||
private let promise: EventLoopPromise<ByteBuffer>
|
||||
private var buffer: ByteBuffer!
|
||||
|
||||
init(numBytes: Int, promise: EventLoopPromise<ByteBuffer>) {
|
||||
self.numBytes = numBytes
|
||||
self.promise = promise
|
||||
}
|
||||
|
||||
func handlerAdded(ctx: ChannelHandlerContext) {
|
||||
buffer = ctx.channel!.allocator.buffer(capacity: numBytes)
|
||||
if self.numBytes == 0 {
|
||||
self.promise.succeed(result: buffer)
|
||||
}
|
||||
}
|
||||
|
||||
func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
|
||||
var currentBuffer = self.unwrapInboundIn(data)
|
||||
buffer.write(buffer: ¤tBuffer)
|
||||
|
||||
if buffer.readableBytes == numBytes {
|
||||
promise.succeed(result: buffer)
|
||||
}
|
||||
}
|
||||
|
||||
func assertReceived(buffer: ByteBuffer) throws {
|
||||
let received = try promise.futureResult.wait()
|
||||
XCTAssertEqual(buffer, received)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue