baby steps towards a type-safe Channel pipeline
This commit is contained in:
parent
15677b9fc2
commit
c09178ceb2
|
@ -128,6 +128,7 @@ public final class ServerBootstrap {
|
|||
}
|
||||
|
||||
private class AcceptHandler : ChannelInboundHandler {
|
||||
public typealias InboundIn = SocketChannel
|
||||
|
||||
private let childHandler: ChannelHandler?
|
||||
private let childOptions: ChannelOptionStorage
|
||||
|
@ -138,7 +139,7 @@ public final class ServerBootstrap {
|
|||
}
|
||||
|
||||
func channelRead(ctx: ChannelHandlerContext, data: IOData) {
|
||||
let accepted = data.forceAsOther() as SocketChannel
|
||||
let accepted = self.unwrapInboundIn(data)
|
||||
do {
|
||||
try self.childOptions.applyAll(channel: accepted)
|
||||
|
||||
|
|
|
@ -21,33 +21,74 @@ import Glibc
|
|||
import Darwin
|
||||
#endif
|
||||
|
||||
public enum IOData {
|
||||
case byteBuffer(ByteBuffer)
|
||||
case other(Any)
|
||||
public struct IOData {
|
||||
private let storage: _IOData
|
||||
public init<T>(_ value: T) {
|
||||
self.storage = _IOData(value)
|
||||
}
|
||||
|
||||
public func tryAsByteBuffer() -> ByteBuffer? {
|
||||
if case .byteBuffer(let bb) = self {
|
||||
enum _IOData {
|
||||
case byteBuffer(ByteBuffer)
|
||||
case other(Any)
|
||||
|
||||
init<T>(_ value: T) {
|
||||
if T.self == ByteBuffer.self {
|
||||
self = .byteBuffer(value as! ByteBuffer)
|
||||
} else {
|
||||
self = .other(value)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func tryAsByteBuffer() -> ByteBuffer? {
|
||||
if case .byteBuffer(let bb) = self.storage {
|
||||
return bb
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
public func forceAsByteBuffer() -> ByteBuffer {
|
||||
func forceAsByteBuffer() -> ByteBuffer {
|
||||
return tryAsByteBuffer()!
|
||||
}
|
||||
|
||||
public func tryAsOther<T>(type: T.Type = T.self) -> T? {
|
||||
if case .other(let any) = self {
|
||||
func tryAsOther<T>(type: T.Type = T.self) -> T? {
|
||||
if case .other(let any) = self.storage {
|
||||
return any as? T
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
public func forceAsOther<T>(type: T.Type = T.self) -> T {
|
||||
func forceAsOther<T>(type: T.Type = T.self) -> T {
|
||||
return tryAsOther(type: type)!
|
||||
}
|
||||
|
||||
func forceAs<T>(type: T.Type = T.self) -> T {
|
||||
if T.self == ByteBuffer.self {
|
||||
return self.forceAsByteBuffer() as! T
|
||||
} else {
|
||||
return self.forceAsOther(type: type)
|
||||
}
|
||||
}
|
||||
|
||||
func tryAs<T>(type: T.Type = T.self) -> T? {
|
||||
if T.self == ByteBuffer.self {
|
||||
return self.tryAsByteBuffer() as! T?
|
||||
} else {
|
||||
return self.tryAsOther(type: type)
|
||||
}
|
||||
}
|
||||
|
||||
func asAny() -> Any {
|
||||
switch self.storage {
|
||||
case .byteBuffer(let bb):
|
||||
return bb
|
||||
case .other(let o):
|
||||
return o
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
final class PendingWrite {
|
||||
|
@ -392,7 +433,7 @@ final class SocketChannel : BaseSocketChannel<Socket> {
|
|||
readPending = false
|
||||
|
||||
assert(!closed)
|
||||
pipeline.fireChannelRead0(data: .byteBuffer(buffer))
|
||||
pipeline.fireChannelRead0(data: IOData(buffer))
|
||||
|
||||
// Reset reader and writerIndex and so allow to have the buffer filled again
|
||||
buffer.clear()
|
||||
|
@ -520,7 +561,7 @@ final class ServerSocketChannel : BaseSocketChannel<ServerSocket> {
|
|||
readPending = false
|
||||
|
||||
do {
|
||||
pipeline.fireChannelRead0(data: .other(try SocketChannel(socket: accepted, eventLoop: group.next() as! SelectableEventLoop)))
|
||||
pipeline.fireChannelRead0(data: IOData(try SocketChannel(socket: accepted, eventLoop: group.next() as! SelectableEventLoop)))
|
||||
} catch let err {
|
||||
let _ = try? accepted.close()
|
||||
throw err
|
||||
|
|
|
@ -19,7 +19,7 @@ public protocol ChannelHandler : class {
|
|||
func handlerRemoved(ctx: ChannelHandlerContext) throws
|
||||
}
|
||||
|
||||
public protocol ChannelOutboundHandler : ChannelHandler {
|
||||
public protocol _ChannelOutboundHandler : ChannelHandler {
|
||||
func register(ctx: ChannelHandlerContext, promise: Promise<Void>?)
|
||||
func bind(ctx: ChannelHandlerContext, to: SocketAddress, promise: Promise<Void>?)
|
||||
func connect(ctx: ChannelHandlerContext, to: SocketAddress, promise: Promise<Void>?)
|
||||
|
@ -31,7 +31,7 @@ public protocol ChannelOutboundHandler : ChannelHandler {
|
|||
func triggerUserOutboundEvent(ctx: ChannelHandlerContext, event: Any, promise: Promise<Void>?)
|
||||
}
|
||||
|
||||
public protocol ChannelInboundHandler : ChannelHandler {
|
||||
public protocol _ChannelInboundHandler : ChannelHandler {
|
||||
func channelRegistered(ctx: ChannelHandlerContext) throws
|
||||
func channelUnregistered(ctx: ChannelHandlerContext) throws
|
||||
func channelActive(ctx: ChannelHandlerContext) throws
|
||||
|
@ -55,7 +55,7 @@ public extension ChannelHandler {
|
|||
}
|
||||
}
|
||||
|
||||
public extension ChannelOutboundHandler {
|
||||
public extension _ChannelOutboundHandler {
|
||||
|
||||
public func register(ctx: ChannelHandlerContext, promise: Promise<Void>?) {
|
||||
ctx.register(promise: promise)
|
||||
|
@ -91,7 +91,7 @@ public extension ChannelOutboundHandler {
|
|||
}
|
||||
|
||||
|
||||
public extension ChannelInboundHandler {
|
||||
public extension _ChannelInboundHandler {
|
||||
|
||||
public func channelRegistered(ctx: ChannelHandlerContext) {
|
||||
ctx.fireChannelRegistered()
|
||||
|
|
|
@ -21,7 +21,11 @@ import Foundation
|
|||
ChannelHandler implementation which enforces back-pressure by stop reading from the remote-peer when it can not write back fast-enough and start reading again
|
||||
once pending data was written.
|
||||
*/
|
||||
public class BackPressureHandler: ChannelInboundHandler, ChannelOutboundHandler {
|
||||
public class BackPressureHandler: ChannelInboundHandler, _ChannelOutboundHandler {
|
||||
public typealias InboundIn = ByteBuffer
|
||||
public typealias InboundOut = ByteBuffer
|
||||
public typealias OutboundOut = ByteBuffer
|
||||
|
||||
private enum PendingRead {
|
||||
case none
|
||||
case promise(promise: Promise<Void>?)
|
||||
|
@ -75,7 +79,7 @@ public class BackPressureHandler: ChannelInboundHandler, ChannelOutboundHandler
|
|||
}
|
||||
}
|
||||
|
||||
public class ChannelInitializer: ChannelInboundHandler {
|
||||
public class ChannelInitializer: _ChannelInboundHandler {
|
||||
private let initChannel: (Channel) -> (Future<Void>)
|
||||
|
||||
public init(initChannel: @escaping (Channel) -> (Future<Void>)) {
|
||||
|
|
|
@ -56,11 +56,11 @@ public final class ChannelPipeline : ChannelInvoker {
|
|||
let ctx = ChannelHandlerContext(name: name ?? nextName(), handler: handler, pipeline: self)
|
||||
if first {
|
||||
ctx.inboundNext = inboundChain
|
||||
if handler is ChannelInboundHandler {
|
||||
if handler is _ChannelInboundHandler {
|
||||
inboundChain = ctx
|
||||
}
|
||||
|
||||
if handler is ChannelOutboundHandler {
|
||||
if handler is _ChannelOutboundHandler {
|
||||
var c = outboundChain
|
||||
|
||||
if c!.handler === HeadChannelHandler.sharedInstance {
|
||||
|
@ -83,7 +83,7 @@ public final class ChannelPipeline : ChannelInvoker {
|
|||
|
||||
contexts.insert(ctx, at: 0)
|
||||
} else {
|
||||
if handler is ChannelInboundHandler {
|
||||
if handler is _ChannelInboundHandler {
|
||||
var c = inboundChain
|
||||
|
||||
if c!.handler === TailChannelHandler.sharedInstance {
|
||||
|
@ -105,7 +105,7 @@ public final class ChannelPipeline : ChannelInvoker {
|
|||
}
|
||||
|
||||
ctx.outboundNext = outboundChain
|
||||
if handler is ChannelOutboundHandler {
|
||||
if handler is _ChannelOutboundHandler {
|
||||
outboundChain = ctx
|
||||
}
|
||||
|
||||
|
@ -541,7 +541,7 @@ public final class ChannelPipeline : ChannelInvoker {
|
|||
}
|
||||
}
|
||||
|
||||
private final class HeadChannelHandler : ChannelOutboundHandler {
|
||||
private final class HeadChannelHandler : _ChannelOutboundHandler {
|
||||
|
||||
static let sharedInstance = HeadChannelHandler()
|
||||
|
||||
|
@ -580,7 +580,7 @@ private final class HeadChannelHandler : ChannelOutboundHandler {
|
|||
}
|
||||
}
|
||||
|
||||
private final class TailChannelHandler : ChannelInboundHandler {
|
||||
private final class TailChannelHandler : _ChannelInboundHandler {
|
||||
|
||||
static let sharedInstance = TailChannelHandler()
|
||||
|
||||
|
@ -733,7 +733,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
|
|||
assert(inEventLoop)
|
||||
|
||||
do {
|
||||
try (handler as! ChannelInboundHandler).channelRegistered(ctx: self)
|
||||
try (handler as! _ChannelInboundHandler).channelRegistered(ctx: self)
|
||||
} catch let err {
|
||||
invokeErrorCaught(error: err)
|
||||
}
|
||||
|
@ -743,7 +743,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
|
|||
assert(inEventLoop)
|
||||
|
||||
do {
|
||||
try (handler as! ChannelInboundHandler).channelUnregistered(ctx: self)
|
||||
try (handler as! _ChannelInboundHandler).channelUnregistered(ctx: self)
|
||||
} catch let err {
|
||||
invokeErrorCaught(error: err)
|
||||
}
|
||||
|
@ -753,7 +753,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
|
|||
assert(inEventLoop)
|
||||
|
||||
do {
|
||||
try (handler as! ChannelInboundHandler).channelActive(ctx: self)
|
||||
try (handler as! _ChannelInboundHandler).channelActive(ctx: self)
|
||||
} catch let err {
|
||||
invokeErrorCaught(error: err)
|
||||
}
|
||||
|
@ -763,7 +763,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
|
|||
assert(inEventLoop)
|
||||
|
||||
do {
|
||||
try (handler as! ChannelInboundHandler).channelInactive(ctx: self)
|
||||
try (handler as! _ChannelInboundHandler).channelInactive(ctx: self)
|
||||
} catch let err {
|
||||
invokeErrorCaught(error: err)
|
||||
}
|
||||
|
@ -773,7 +773,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
|
|||
assert(inEventLoop)
|
||||
|
||||
do {
|
||||
try (handler as! ChannelInboundHandler).channelRead(ctx: self, data: data)
|
||||
try (handler as! _ChannelInboundHandler).channelRead(ctx: self, data: data)
|
||||
} catch let err {
|
||||
invokeErrorCaught(error: err)
|
||||
}
|
||||
|
@ -783,7 +783,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
|
|||
assert(inEventLoop)
|
||||
|
||||
do {
|
||||
try (handler as! ChannelInboundHandler).channelReadComplete(ctx: self)
|
||||
try (handler as! _ChannelInboundHandler).channelReadComplete(ctx: self)
|
||||
} catch let err {
|
||||
invokeErrorCaught(error: err)
|
||||
}
|
||||
|
@ -793,7 +793,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
|
|||
assert(inEventLoop)
|
||||
|
||||
do {
|
||||
try (handler as! ChannelInboundHandler).channelWritabilityChanged(ctx: self)
|
||||
try (handler as! _ChannelInboundHandler).channelWritabilityChanged(ctx: self)
|
||||
} catch let err {
|
||||
invokeErrorCaught(error: err)
|
||||
}
|
||||
|
@ -803,7 +803,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
|
|||
assert(inEventLoop)
|
||||
|
||||
do {
|
||||
try (handler as! ChannelInboundHandler).errorCaught(ctx: self, error: error)
|
||||
try (handler as! _ChannelInboundHandler).errorCaught(ctx: self, error: error)
|
||||
} catch let err {
|
||||
// Forward the error thrown by errorCaught through the pipeline
|
||||
fireErrorCaught(error: err)
|
||||
|
@ -814,7 +814,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
|
|||
assert(inEventLoop)
|
||||
|
||||
do {
|
||||
try (handler as! ChannelInboundHandler).userInboundEventTriggered(ctx: self, event: event)
|
||||
try (handler as! _ChannelInboundHandler).userInboundEventTriggered(ctx: self, event: event)
|
||||
} catch let err {
|
||||
invokeErrorCaught(error: err)
|
||||
}
|
||||
|
@ -824,34 +824,34 @@ public final class ChannelHandlerContext : ChannelInvoker {
|
|||
assert(inEventLoop)
|
||||
assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled")
|
||||
|
||||
(handler as! ChannelOutboundHandler).register(ctx: self, promise: promise)
|
||||
(handler as! _ChannelOutboundHandler).register(ctx: self, promise: promise)
|
||||
}
|
||||
|
||||
func invokeBind(to address: SocketAddress, promise: Promise<Void>?) {
|
||||
assert(inEventLoop)
|
||||
assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled")
|
||||
|
||||
(handler as! ChannelOutboundHandler).bind(ctx: self, to: address, promise: promise)
|
||||
(handler as! _ChannelOutboundHandler).bind(ctx: self, to: address, promise: promise)
|
||||
}
|
||||
|
||||
func invokeConnect(to address: SocketAddress, promise: Promise<Void>?) {
|
||||
assert(inEventLoop)
|
||||
assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled")
|
||||
|
||||
(handler as! ChannelOutboundHandler).connect(ctx: self, to: address, promise: promise)
|
||||
(handler as! _ChannelOutboundHandler).connect(ctx: self, to: address, promise: promise)
|
||||
}
|
||||
|
||||
func invokeWrite(data: IOData, promise: Promise<Void>?) {
|
||||
assert(inEventLoop)
|
||||
assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled")
|
||||
|
||||
(handler as! ChannelOutboundHandler).write(ctx: self, data: data, promise: promise)
|
||||
(handler as! _ChannelOutboundHandler).write(ctx: self, data: data, promise: promise)
|
||||
}
|
||||
|
||||
func invokeFlush(promise: Promise<Void>?) {
|
||||
assert(inEventLoop)
|
||||
|
||||
(handler as! ChannelOutboundHandler).flush(ctx: self, promise: promise)
|
||||
(handler as! _ChannelOutboundHandler).flush(ctx: self, promise: promise)
|
||||
}
|
||||
|
||||
func invokeWriteAndFlush(data: IOData, promise: Promise<Void>?) {
|
||||
|
@ -876,35 +876,35 @@ public final class ChannelHandlerContext : ChannelInvoker {
|
|||
let writePromise: Promise<Void> = eventLoop.newPromise()
|
||||
let flushPromise: Promise<Void> = eventLoop.newPromise()
|
||||
|
||||
(handler as! ChannelOutboundHandler).write(ctx: self, data: data, promise: writePromise)
|
||||
(handler as! ChannelOutboundHandler).flush(ctx: self, promise: flushPromise)
|
||||
(handler as! _ChannelOutboundHandler).write(ctx: self, data: data, promise: writePromise)
|
||||
(handler as! _ChannelOutboundHandler).flush(ctx: self, promise: flushPromise)
|
||||
|
||||
writePromise.futureResult.whenComplete(callback: callback)
|
||||
flushPromise.futureResult.whenComplete(callback: callback)
|
||||
} else {
|
||||
(handler as! ChannelOutboundHandler).write(ctx: self, data: data, promise: nil)
|
||||
(handler as! ChannelOutboundHandler).flush(ctx: self, promise: nil)
|
||||
(handler as! _ChannelOutboundHandler).write(ctx: self, data: data, promise: nil)
|
||||
(handler as! _ChannelOutboundHandler).flush(ctx: self, promise: nil)
|
||||
}
|
||||
}
|
||||
|
||||
func invokeRead(promise: Promise<Void>?) {
|
||||
assert(inEventLoop)
|
||||
|
||||
(handler as! ChannelOutboundHandler).read(ctx: self, promise: promise)
|
||||
(handler as! _ChannelOutboundHandler).read(ctx: self, promise: promise)
|
||||
}
|
||||
|
||||
func invokeClose(promise: Promise<Void>?) {
|
||||
assert(inEventLoop)
|
||||
assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled")
|
||||
|
||||
(handler as! ChannelOutboundHandler).close(ctx: self, promise: promise)
|
||||
(handler as! _ChannelOutboundHandler).close(ctx: self, promise: promise)
|
||||
}
|
||||
|
||||
func invokeTriggerUserOutboundEvent(event: Any, promise: Promise<Void>?) {
|
||||
assert(inEventLoop)
|
||||
assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled")
|
||||
|
||||
(handler as! ChannelOutboundHandler).triggerUserOutboundEvent(ctx: self, event: event, promise: promise)
|
||||
(handler as! _ChannelOutboundHandler).triggerUserOutboundEvent(ctx: self, event: event, promise: promise)
|
||||
}
|
||||
|
||||
func invokeHandlerAdded() throws {
|
||||
|
|
|
@ -13,7 +13,8 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
public protocol ByteToMessageDecoder : ChannelInboundHandler {
|
||||
public protocol ByteToMessageDecoder : ChannelInboundHandler where InboundIn == ByteBuffer {
|
||||
|
||||
var cumulationBuffer: ByteBuffer? { get set }
|
||||
func decode(ctx: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> Bool
|
||||
func decodeLast(ctx: ChannelHandlerContext, buffer: inout ByteBuffer)throws -> Bool
|
||||
|
@ -24,7 +25,7 @@ public protocol ByteToMessageDecoder : ChannelInboundHandler {
|
|||
public extension ByteToMessageDecoder {
|
||||
|
||||
public func channelRead(ctx: ChannelHandlerContext, data: IOData) throws {
|
||||
var buffer = data.forceAsByteBuffer()
|
||||
var buffer = self.unwrapInboundIn(data)
|
||||
|
||||
if var cum = cumulationBuffer {
|
||||
var buf = ctx.channel!.allocator.buffer(capacity: cum.readableBytes + buffer.readableBytes)
|
||||
|
@ -62,10 +63,13 @@ public extension ByteToMessageDecoder {
|
|||
}
|
||||
|
||||
public func handlerRemoved(ctx: ChannelHandlerContext) throws {
|
||||
if let buffer = cumulationBuffer {
|
||||
ctx.fireChannelRead(data: .byteBuffer(buffer))
|
||||
cumulationBuffer = nil
|
||||
if let buffer = cumulationBuffer as? InboundOut {
|
||||
ctx.fireChannelRead(data: self.wrapInboundOut(buffer))
|
||||
} else {
|
||||
/* please note that we're dropping the partially received bytes (if any) on the floor here as we can't
|
||||
send a full message to the next handler. */
|
||||
}
|
||||
cumulationBuffer = nil
|
||||
try decoderRemoved(ctx: ctx)
|
||||
}
|
||||
|
||||
|
@ -81,17 +85,18 @@ public extension ByteToMessageDecoder {
|
|||
}
|
||||
}
|
||||
|
||||
public protocol MessageToByteEncoder : ChannelOutboundHandler {
|
||||
func encode(ctx: ChannelHandlerContext, data: IOData, out: inout ByteBuffer) throws
|
||||
func allocateOutBuffer(ctx: ChannelHandlerContext, data: IOData) throws -> ByteBuffer
|
||||
public protocol MessageToByteEncoder : ChannelOutboundHandler where OutboundOut == ByteBuffer {
|
||||
func encode(ctx: ChannelHandlerContext, data: OutboundIn, out: inout ByteBuffer) throws
|
||||
func allocateOutBuffer(ctx: ChannelHandlerContext, data: OutboundIn) throws -> ByteBuffer
|
||||
}
|
||||
|
||||
public extension MessageToByteEncoder {
|
||||
public func write(ctx: ChannelHandlerContext, data: IOData, promise: Promise<Void>?) {
|
||||
do {
|
||||
let data = self.unwrapOutboundIn(data)
|
||||
var buffer: ByteBuffer = try allocateOutBuffer(ctx: ctx, data: data)
|
||||
try encode(ctx: ctx, data: data, out: &buffer)
|
||||
ctx.write(data: .byteBuffer(buffer), promise: promise)
|
||||
ctx.write(data: self.wrapOutboundOut(buffer), promise: promise)
|
||||
} catch let err {
|
||||
promise?.fail(error: err)
|
||||
}
|
||||
|
|
|
@ -133,12 +133,7 @@ class EmbeddedChannelCore : ChannelCore {
|
|||
}
|
||||
|
||||
private func addToBuffer(buffer: inout [Any], data: IOData) {
|
||||
switch data {
|
||||
case .byteBuffer(let buf):
|
||||
buffer.append(buf)
|
||||
case .other(let other):
|
||||
buffer.append(other)
|
||||
}
|
||||
buffer.append(data.asAny())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,106 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This source file is part of the SwiftNIO open source project
|
||||
//
|
||||
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
|
||||
// Licensed under Apache License v2.0
|
||||
//
|
||||
// See LICENSE.txt for license information
|
||||
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||
//
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
public protocol ChannelInboundHandler: _ChannelInboundHandler {
|
||||
associatedtype InboundIn
|
||||
associatedtype InboundUserEventIn = Never
|
||||
|
||||
associatedtype OutboundOut = Never
|
||||
associatedtype InboundOut = Never
|
||||
associatedtype OutboundUserEventOut = Never
|
||||
associatedtype InboundUserEventOut = Never
|
||||
|
||||
func unwrapInboundIn(_ value: IOData) -> InboundIn
|
||||
func tryUnwrapInboundIn(_ value: IOData) -> InboundIn?
|
||||
func wrapInboundOut(_ value: InboundOut) -> IOData
|
||||
func unwrapInboundUserEventIn(_ value: Any) -> InboundUserEventIn
|
||||
func tryUnwrapInboundUserEventIn(_ value: Any) -> InboundUserEventIn?
|
||||
func wrapInboundUserEventOut(_ value: InboundUserEventOut) -> Any
|
||||
func wrapOutboundOut(_ value: OutboundOut) -> IOData
|
||||
}
|
||||
|
||||
public extension ChannelInboundHandler {
|
||||
func unwrapInboundIn(_ value: IOData) -> InboundIn {
|
||||
return value.forceAs()
|
||||
}
|
||||
|
||||
func tryUnwrapInboundIn(_ value: IOData) -> InboundIn? {
|
||||
return value.forceAs()
|
||||
}
|
||||
|
||||
func wrapInboundOut(_ value: InboundOut) -> IOData {
|
||||
return IOData(value)
|
||||
}
|
||||
|
||||
func unwrapInboundUserEventIn(_ value: Any) -> InboundUserEventIn {
|
||||
return value as! InboundUserEventIn
|
||||
}
|
||||
|
||||
func tryUnwrapInboundUserEventIn(_ value: Any) -> InboundUserEventIn? {
|
||||
return value as? InboundUserEventIn
|
||||
}
|
||||
|
||||
func wrapInboundUserEventOut(_ value: InboundUserEventOut) -> Any {
|
||||
return value
|
||||
}
|
||||
|
||||
func wrapOutboundOut(_ value: OutboundOut) -> IOData {
|
||||
return IOData(value)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public protocol ChannelOutboundHandler: _ChannelOutboundHandler {
|
||||
associatedtype OutboundIn
|
||||
associatedtype OutboundUserEventIn = Never
|
||||
|
||||
associatedtype OutboundOut
|
||||
associatedtype InboundOut = Never
|
||||
associatedtype OutboundUserEventOut = Never
|
||||
associatedtype InboundUserEventOut = Never
|
||||
|
||||
func unwrapOutboundIn(_ value: IOData) -> OutboundIn
|
||||
func tryUnwrapOutboundIn(_ value: IOData) -> OutboundIn?
|
||||
func wrapOutboundOut(_ value: OutboundOut) -> IOData
|
||||
|
||||
func unwrapOutboundUserEventIn(_ value: Any) -> OutboundUserEventIn
|
||||
func tryUnwrapOutboundUserEventIn(_ value: Any) -> OutboundUserEventIn?
|
||||
func wrapOutboundUserEventOut(_ value: OutboundUserEventOut) -> Any
|
||||
}
|
||||
|
||||
public extension ChannelOutboundHandler {
|
||||
func unwrapOutboundIn(_ value: IOData) -> OutboundIn {
|
||||
return value.forceAs()
|
||||
}
|
||||
|
||||
func tryUnwrapOutboundIn(_ value: IOData) -> OutboundIn? {
|
||||
return value.tryAs()
|
||||
}
|
||||
|
||||
func wrapOutboundOut(_ value: OutboundOut) -> IOData {
|
||||
return IOData(value)
|
||||
}
|
||||
|
||||
func unwrapOutboundUserEventIn(_ value: Any) -> OutboundUserEventIn {
|
||||
return value as! OutboundUserEventIn
|
||||
}
|
||||
|
||||
func tryUnwrapOutboundUserEventIn(_ value: Any) -> OutboundUserEventIn? {
|
||||
return value as? OutboundUserEventIn
|
||||
}
|
||||
|
||||
func wrapOutboundUserEventOut(_ value: OutboundUserEventOut) -> Any {
|
||||
return value
|
||||
}
|
||||
}
|
|
@ -16,6 +16,8 @@ import NIO
|
|||
|
||||
|
||||
private final class EchoHandler: ChannelInboundHandler {
|
||||
public typealias InboundIn = ByteBuffer
|
||||
public typealias OutboundOut = ByteBuffer
|
||||
|
||||
public func channelRead(ctx: ChannelHandlerContext, data: IOData) {
|
||||
// As we are not really interested getting notified on success or failure we just pass nil as promise to
|
||||
|
|
|
@ -16,33 +16,32 @@ import Foundation
|
|||
import NIO
|
||||
|
||||
public final class HTTPResponseEncoder : ChannelOutboundHandler {
|
||||
public typealias OutboundIn = HTTPResponse
|
||||
public typealias OutboundOut = ByteBuffer
|
||||
|
||||
public init() { }
|
||||
public init() {}
|
||||
|
||||
public func write(ctx: ChannelHandlerContext, data: IOData, promise: Promise<Void>?) {
|
||||
if let response:HTTPResponseHead = data.tryAsOther() {
|
||||
switch self.tryUnwrapOutboundIn(data) {
|
||||
case .some(.head(let response)):
|
||||
// TODO: Is 256 really a good value here ?
|
||||
var buffer = ctx.channel!.allocator.buffer(capacity: 256)
|
||||
response.version.write(buffer: &buffer)
|
||||
response.status.write(buffer: &buffer)
|
||||
response.headers.write(buffer: &buffer)
|
||||
|
||||
ctx.write(data: .byteBuffer(buffer), promise: promise)
|
||||
} else if let content: HTTPBodyContent = data.tryAsOther() {
|
||||
// TODO: Implement chunked encoding
|
||||
switch content {
|
||||
case .more(let buffer):
|
||||
ctx.write(data: .byteBuffer(buffer), promise: promise)
|
||||
case .last(let buffer):
|
||||
if let buf = buffer {
|
||||
ctx.write(data: .byteBuffer(buf), promise: promise)
|
||||
} else if promise != nil {
|
||||
// We only need to pass the promise further if the user is even interested in the result.
|
||||
// Empty content so just write an empty buffer
|
||||
ctx.write(data: .byteBuffer(ctx.channel!.allocator.buffer(capacity: 0)), promise: promise)
|
||||
}
|
||||
ctx.write(data: self.wrapOutboundOut(buffer), promise: promise)
|
||||
case .some(.body(.more(let buffer))):
|
||||
ctx.write(data: self.wrapOutboundOut(buffer), promise: promise)
|
||||
case .some(.body(.last(let buffer))):
|
||||
if let buf = buffer {
|
||||
ctx.write(data: self.wrapOutboundOut(buf), promise: promise)
|
||||
} else if promise != nil {
|
||||
// We only need to pass the promise further if the user is even interested in the result.
|
||||
// Empty content so just write an empty buffer
|
||||
ctx.write(data: self.wrapOutboundOut(ctx.channel!.allocator.buffer(capacity: 0)), promise: promise)
|
||||
}
|
||||
} else {
|
||||
case .none:
|
||||
ctx.write(data: data, promise: promise)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,7 +16,10 @@ import Foundation
|
|||
import NIO
|
||||
import CHTTPParser
|
||||
|
||||
public final class HTTPRequestDecoder : ByteToMessageDecoder {
|
||||
public final class HTTPRequestDecoder : ChannelInboundHandler, ByteToMessageDecoder {
|
||||
public typealias InboundIn = ByteBuffer
|
||||
public typealias InboundOut = HTTPRequest
|
||||
|
||||
var parser: UnsafeMutablePointer<http_parser>?
|
||||
var settings: UnsafeMutablePointer<http_parser_settings>?
|
||||
public var cumulationBuffer: ByteBuffer?
|
||||
|
@ -127,7 +130,7 @@ public final class HTTPRequestDecoder : ByteToMessageDecoder {
|
|||
|
||||
handler.state.dataAwaitingState = .body
|
||||
|
||||
ctx.fireChannelRead(data: .other(HTTPRequest.head(request)))
|
||||
ctx.fireChannelRead(data: handler.wrapInboundOut(HTTPRequest.head(request)))
|
||||
return 0
|
||||
}
|
||||
|
||||
|
@ -138,7 +141,7 @@ public final class HTTPRequestDecoder : ByteToMessageDecoder {
|
|||
|
||||
// This will never return nil as we allocated the buffer with the correct size
|
||||
handler.state.parserBuffer.write(int8Data: data!, len: len)
|
||||
ctx.fireChannelRead(data: .other(HTTPRequest.body(HTTPBodyContent.more(buffer: handler.state.parserBuffer.readSlice(length: len)!))))
|
||||
ctx.fireChannelRead(data: handler.wrapInboundOut(HTTPRequest.body(HTTPBodyContent.more(buffer: handler.state.parserBuffer.readSlice(length: len)!))))
|
||||
|
||||
return 0
|
||||
}
|
||||
|
@ -178,7 +181,7 @@ public final class HTTPRequestDecoder : ByteToMessageDecoder {
|
|||
let ctx = evacuateContext(parser)
|
||||
let handler = ctx.handler as! HTTPRequestDecoder
|
||||
|
||||
ctx.fireChannelRead(data: .other(HTTPRequest.body(.last(buffer: nil))))
|
||||
ctx.fireChannelRead(data: handler.wrapInboundOut(HTTPRequest.body(.last(buffer: nil))))
|
||||
handler.complete(state: handler.state.dataAwaitingState)
|
||||
handler.state.dataAwaitingState = .messageBegin
|
||||
return 0
|
||||
|
|
|
@ -65,6 +65,11 @@ public extension HTTPRequestHead {
|
|||
}
|
||||
}
|
||||
|
||||
public enum HTTPResponse {
|
||||
case head(HTTPResponseHead)
|
||||
case body(HTTPBodyContent)
|
||||
}
|
||||
|
||||
public struct HTTPResponseHead {
|
||||
public let status: HTTPResponseStatus
|
||||
public let version: HTTPVersion
|
||||
|
|
|
@ -16,29 +16,33 @@ import NIO
|
|||
import NIOHTTP1
|
||||
|
||||
private class HTTPHandler : ChannelInboundHandler {
|
||||
public typealias InboundIn = HTTPRequest
|
||||
public typealias OutboundOut = HTTPResponse
|
||||
|
||||
private var buffer: ByteBuffer? = nil
|
||||
private var keepAlive = false
|
||||
|
||||
func channelRead(ctx: ChannelHandlerContext, data: IOData) throws {
|
||||
if let reqPart = data.tryAsOther(type: HTTPRequest.self) {
|
||||
func channelRead(ctx: ChannelHandlerContext, data: IOData) {
|
||||
|
||||
if let reqPart = self.tryUnwrapInboundIn(data) {
|
||||
switch reqPart {
|
||||
case .head(let request):
|
||||
keepAlive = request.isKeepAlive
|
||||
|
||||
var response = HTTPResponseHead(version: request.version, status: HTTPResponseStatus.ok)
|
||||
response.headers.add(name: "content-length", value: "12")
|
||||
ctx.write(data: .other(response), promise: nil)
|
||||
var responseHead = HTTPResponseHead(version: request.version, status: HTTPResponseStatus.ok)
|
||||
responseHead.headers.add(name: "content-length", value: "12")
|
||||
let response = HTTPResponse.head(responseHead)
|
||||
ctx.write(data: self.wrapOutboundOut(response), promise: nil)
|
||||
case .body(let content):
|
||||
switch content {
|
||||
case .more(_):
|
||||
break
|
||||
case .last:
|
||||
let content = HTTPBodyContent.last(buffer: buffer!.slice())
|
||||
let content = HTTPResponse.body(HTTPBodyContent.last(buffer: buffer!.slice()))
|
||||
if keepAlive {
|
||||
ctx.write(data: .other(content), promise: nil)
|
||||
ctx.write(data: self.wrapOutboundOut(content), promise: nil)
|
||||
} else {
|
||||
ctx.write(data: .other(content)).whenComplete(callback: { _ in
|
||||
ctx.write(data: self.wrapOutboundOut(content)).whenComplete(callback: { _ in
|
||||
ctx.close(promise: nil)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -17,15 +17,17 @@ import XCTest
|
|||
@testable import NIOHTTP1
|
||||
|
||||
private final class TestChannelInboundHandler: ChannelInboundHandler {
|
||||
public typealias InboundIn = HTTPRequest
|
||||
public typealias InboundOut = HTTPRequest
|
||||
|
||||
private let fn: (IOData) -> IOData
|
||||
private let fn: (HTTPRequest) -> HTTPRequest
|
||||
|
||||
init(_ fn: @escaping (IOData) -> IOData) {
|
||||
init(_ fn: @escaping (HTTPRequest) -> HTTPRequest) {
|
||||
self.fn = fn
|
||||
}
|
||||
|
||||
public func channelRead(ctx: ChannelHandlerContext, data: IOData) {
|
||||
ctx.fireChannelRead(data: self.fn(data))
|
||||
ctx.fireChannelRead(data: self.wrapInboundOut(self.fn(self.unwrapInboundIn(data))))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -62,38 +64,34 @@ class HTTPTest: XCTestCase {
|
|||
try channel.pipeline.add(handler: HTTPRequestDecoder()).wait()
|
||||
var bodyData: Data? = nil
|
||||
var allBodyDatas: [Data] = []
|
||||
try channel.pipeline.add(handler: TestChannelInboundHandler { data in
|
||||
if let reqPart = data.tryAsOther(type: NIOHTTP1.HTTPRequest.self) {
|
||||
switch reqPart {
|
||||
case .head(var req):
|
||||
XCTAssertEqual((index * 2), step)
|
||||
req.headers.remove(name: "Content-Length")
|
||||
XCTAssertEqual(expecteds[index], req)
|
||||
step += 1
|
||||
case .body(let chunk):
|
||||
switch chunk {
|
||||
case .more(var buffer):
|
||||
try channel.pipeline.add(handler: TestChannelInboundHandler { reqPart in
|
||||
switch reqPart {
|
||||
case .head(var req):
|
||||
XCTAssertEqual((index * 2), step)
|
||||
req.headers.remove(name: "Content-Length")
|
||||
XCTAssertEqual(expecteds[index], req)
|
||||
step += 1
|
||||
case .body(let chunk):
|
||||
switch chunk {
|
||||
case .more(var buffer):
|
||||
if bodyData == nil {
|
||||
bodyData = buffer.readData(length: buffer.readableBytes)!
|
||||
} else {
|
||||
bodyData!.append(buffer.readData(length: buffer.readableBytes)!)
|
||||
}
|
||||
case .last(let buffer):
|
||||
if var buffer = buffer {
|
||||
if bodyData == nil {
|
||||
bodyData = buffer.readData(length: buffer.readableBytes)!
|
||||
} else {
|
||||
bodyData!.append(buffer.readData(length: buffer.readableBytes)!)
|
||||
}
|
||||
case .last(let buffer):
|
||||
if var buffer = buffer {
|
||||
if bodyData == nil {
|
||||
bodyData = buffer.readData(length: buffer.readableBytes)!
|
||||
} else {
|
||||
bodyData!.append(buffer.readData(length: buffer.readableBytes)!)
|
||||
}
|
||||
}
|
||||
step += 1
|
||||
XCTAssertEqual(((index + 1) * 2), step)
|
||||
}
|
||||
step += 1
|
||||
XCTAssertEqual(((index + 1) * 2), step)
|
||||
}
|
||||
} else {
|
||||
XCTFail("wrong type \(data)")
|
||||
}
|
||||
return data
|
||||
return reqPart
|
||||
}).wait()
|
||||
|
||||
for expected in expecteds {
|
||||
|
@ -124,7 +122,7 @@ class HTTPTest: XCTestCase {
|
|||
let bd1 = try sendAndCheckRequests(expecteds, body: body, sendStrategy: { (reqString, chan) in
|
||||
var buf = chan.allocator.buffer(capacity: 1024)
|
||||
buf.write(string: reqString)
|
||||
chan.pipeline.fireChannelRead(data: .byteBuffer(buf))
|
||||
chan.pipeline.fireChannelRead(data: IOData(buf))
|
||||
})
|
||||
|
||||
/* send the bytes one by one */
|
||||
|
@ -133,7 +131,7 @@ class HTTPTest: XCTestCase {
|
|||
var buf = chan.allocator.buffer(capacity: 1024)
|
||||
|
||||
buf.write(string: "\(c)")
|
||||
chan.pipeline.fireChannelRead(data: .byteBuffer(buf))
|
||||
chan.pipeline.fireChannelRead(data: IOData(buf))
|
||||
}
|
||||
})
|
||||
|
||||
|
|
|
@ -58,35 +58,37 @@ class ChannelPipelineTest: XCTestCase {
|
|||
var buf = channel.allocator.buffer(capacity: 1024)
|
||||
buf.write(string: "hello")
|
||||
|
||||
_ = try channel.pipeline.add(handler: TestChannelOutboundHandler({ data in
|
||||
XCTAssertEqual(1, data.forceAsOther())
|
||||
return .byteBuffer(buf)
|
||||
_ = try channel.pipeline.add(handler: TestChannelOutboundHandler<Int, ByteBuffer>({ data in
|
||||
XCTAssertEqual(1, data)
|
||||
return buf
|
||||
})).wait()
|
||||
|
||||
_ = try channel.pipeline.add(handler: TestChannelOutboundHandler({ data in
|
||||
XCTAssertEqual("msg", data.forceAsOther())
|
||||
return .other(1)
|
||||
_ = try channel.pipeline.add(handler: TestChannelOutboundHandler<String, Int>({ data in
|
||||
XCTAssertEqual("msg", data)
|
||||
return 1
|
||||
})).wait()
|
||||
|
||||
|
||||
_ = channel.write(data: .other("msg"))
|
||||
_ = channel.write(data: IOData("msg"))
|
||||
_ = try channel.flush().wait()
|
||||
XCTAssertEqual(buf, channel.readOutbound())
|
||||
XCTAssertNil(channel.readOutbound())
|
||||
|
||||
}
|
||||
|
||||
private final class TestChannelOutboundHandler: ChannelOutboundHandler {
|
||||
private final class TestChannelOutboundHandler<In, Out>: ChannelOutboundHandler {
|
||||
typealias OutboundIn = In
|
||||
typealias OutboundOut = Out
|
||||
|
||||
private let fn: (IOData) throws -> IOData
|
||||
private let fn: (OutboundIn) throws -> OutboundOut
|
||||
|
||||
init(_ fn: @escaping (IOData) throws -> IOData) {
|
||||
init(_ fn: @escaping (OutboundIn) throws -> OutboundOut) {
|
||||
self.fn = fn
|
||||
}
|
||||
|
||||
public func write(ctx: ChannelHandlerContext, data: IOData, promise: Promise<Void>?) {
|
||||
do {
|
||||
ctx.write(data: try fn(data), promise: promise)
|
||||
ctx.write(data: self.wrapOutboundOut(try fn(self.unwrapOutboundIn(data))), promise: promise)
|
||||
} catch let err {
|
||||
promise!.fail(error: err)
|
||||
}
|
||||
|
|
|
@ -18,6 +18,9 @@ import XCTest
|
|||
|
||||
public class ByteToMessageDecoderTest: XCTestCase {
|
||||
private final class ByteToInt32Decoder : ByteToMessageDecoder {
|
||||
typealias InboundIn = ByteBuffer
|
||||
typealias InboundOut = Int32
|
||||
|
||||
var cumulationBuffer: ByteBuffer?
|
||||
|
||||
|
||||
|
@ -25,7 +28,7 @@ public class ByteToMessageDecoderTest: XCTestCase {
|
|||
guard buffer.readableBytes >= MemoryLayout<Int32>.size else {
|
||||
return false
|
||||
}
|
||||
ctx.fireChannelRead(data: .other(buffer.readInteger()! as Int32))
|
||||
ctx.fireChannelRead(data: self.wrapInboundOut(buffer.readInteger()!))
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
@ -40,15 +43,15 @@ public class ByteToMessageDecoderTest: XCTestCase {
|
|||
let writerIndex = buffer.writerIndex
|
||||
buffer.moveWriterIndex(to: writerIndex - 1)
|
||||
|
||||
channel.pipeline.fireChannelRead(data: .byteBuffer(buffer))
|
||||
channel.pipeline.fireChannelRead(data: IOData(buffer))
|
||||
XCTAssertNil(channel.readInbound())
|
||||
|
||||
channel.pipeline.fireChannelRead(data: .byteBuffer(buffer.slice(at: writerIndex - 1, length: 1)!))
|
||||
channel.pipeline.fireChannelRead(data: IOData(buffer.slice(at: writerIndex - 1, length: 1)!))
|
||||
|
||||
var buffer2 = channel.allocator.buffer(capacity: 32)
|
||||
buffer2.write(integer: Int32(2))
|
||||
buffer2.write(integer: Int32(3))
|
||||
channel.pipeline.fireChannelRead(data: .byteBuffer(buffer2))
|
||||
channel.pipeline.fireChannelRead(data: IOData(buffer2))
|
||||
|
||||
try channel.close().wait()
|
||||
|
||||
|
@ -62,13 +65,15 @@ public class ByteToMessageDecoderTest: XCTestCase {
|
|||
public class MessageToByteEncoderTest: XCTestCase {
|
||||
|
||||
private final class Int32ToByteEncoder : MessageToByteEncoder {
|
||||
public func encode(ctx: ChannelHandlerContext, data: IOData, out: inout ByteBuffer) throws {
|
||||
typealias OutboundIn = Int32
|
||||
typealias OutboundOut = ByteBuffer
|
||||
|
||||
public func encode(ctx: ChannelHandlerContext, data value: Int32, out: inout ByteBuffer) throws {
|
||||
XCTAssertEqual(MemoryLayout<Int32>.size, out.writableBytes)
|
||||
let value: Int32 = data.forceAsOther()
|
||||
out.write(integer: value);
|
||||
}
|
||||
|
||||
public func allocateOutBuffer(ctx: ChannelHandlerContext, data: IOData) throws -> ByteBuffer {
|
||||
public func allocateOutBuffer(ctx: ChannelHandlerContext, data: Int32) throws -> ByteBuffer {
|
||||
return ctx.channel!.allocator.buffer(capacity: MemoryLayout<Int32>.size)
|
||||
}
|
||||
}
|
||||
|
@ -78,8 +83,7 @@ public class MessageToByteEncoderTest: XCTestCase {
|
|||
|
||||
_ = try channel.pipeline.add(handler: Int32ToByteEncoder()).wait()
|
||||
|
||||
_ = try channel.writeAndFlush(data: .other(Int32(5))).wait()
|
||||
|
||||
_ = try channel.writeAndFlush(data: IOData(Int32(5))).wait()
|
||||
|
||||
var buffer = channel.readOutbound() as ByteBuffer?
|
||||
XCTAssertEqual(Int32(5), buffer?.readInteger())
|
||||
|
|
|
@ -22,7 +22,7 @@ class EmbeddedChannelTest: XCTestCase {
|
|||
buf.write(string: "hello")
|
||||
|
||||
|
||||
let f = channel.write(data: .byteBuffer(buf))
|
||||
let f = channel.write(data: IOData(buf))
|
||||
|
||||
var ranBlock = false
|
||||
f.whenSuccess { () -> Void in
|
||||
|
|
Loading…
Reference in New Issue