baby steps towards a type-safe Channel pipeline

This commit is contained in:
Johannes Weiß 2017-06-30 13:04:19 +01:00
parent 15677b9fc2
commit c09178ceb2
17 changed files with 308 additions and 139 deletions

View File

@ -128,6 +128,7 @@ public final class ServerBootstrap {
} }
private class AcceptHandler : ChannelInboundHandler { private class AcceptHandler : ChannelInboundHandler {
public typealias InboundIn = SocketChannel
private let childHandler: ChannelHandler? private let childHandler: ChannelHandler?
private let childOptions: ChannelOptionStorage private let childOptions: ChannelOptionStorage
@ -138,7 +139,7 @@ public final class ServerBootstrap {
} }
func channelRead(ctx: ChannelHandlerContext, data: IOData) { func channelRead(ctx: ChannelHandlerContext, data: IOData) {
let accepted = data.forceAsOther() as SocketChannel let accepted = self.unwrapInboundIn(data)
do { do {
try self.childOptions.applyAll(channel: accepted) try self.childOptions.applyAll(channel: accepted)

View File

@ -21,33 +21,74 @@ import Glibc
import Darwin import Darwin
#endif #endif
public enum IOData { public struct IOData {
private let storage: _IOData
public init<T>(_ value: T) {
self.storage = _IOData(value)
}
enum _IOData {
case byteBuffer(ByteBuffer) case byteBuffer(ByteBuffer)
case other(Any) case other(Any)
public func tryAsByteBuffer() -> ByteBuffer? { init<T>(_ value: T) {
if case .byteBuffer(let bb) = self { if T.self == ByteBuffer.self {
self = .byteBuffer(value as! ByteBuffer)
} else {
self = .other(value)
}
}
}
func tryAsByteBuffer() -> ByteBuffer? {
if case .byteBuffer(let bb) = self.storage {
return bb return bb
} else { } else {
return nil return nil
} }
} }
public func forceAsByteBuffer() -> ByteBuffer { func forceAsByteBuffer() -> ByteBuffer {
return tryAsByteBuffer()! return tryAsByteBuffer()!
} }
public func tryAsOther<T>(type: T.Type = T.self) -> T? { func tryAsOther<T>(type: T.Type = T.self) -> T? {
if case .other(let any) = self { if case .other(let any) = self.storage {
return any as? T return any as? T
} else { } else {
return nil return nil
} }
} }
public func forceAsOther<T>(type: T.Type = T.self) -> T { func forceAsOther<T>(type: T.Type = T.self) -> T {
return tryAsOther(type: type)! return tryAsOther(type: type)!
} }
func forceAs<T>(type: T.Type = T.self) -> T {
if T.self == ByteBuffer.self {
return self.forceAsByteBuffer() as! T
} else {
return self.forceAsOther(type: type)
}
}
func tryAs<T>(type: T.Type = T.self) -> T? {
if T.self == ByteBuffer.self {
return self.tryAsByteBuffer() as! T?
} else {
return self.tryAsOther(type: type)
}
}
func asAny() -> Any {
switch self.storage {
case .byteBuffer(let bb):
return bb
case .other(let o):
return o
}
}
} }
final class PendingWrite { final class PendingWrite {
@ -392,7 +433,7 @@ final class SocketChannel : BaseSocketChannel<Socket> {
readPending = false readPending = false
assert(!closed) assert(!closed)
pipeline.fireChannelRead0(data: .byteBuffer(buffer)) pipeline.fireChannelRead0(data: IOData(buffer))
// Reset reader and writerIndex and so allow to have the buffer filled again // Reset reader and writerIndex and so allow to have the buffer filled again
buffer.clear() buffer.clear()
@ -520,7 +561,7 @@ final class ServerSocketChannel : BaseSocketChannel<ServerSocket> {
readPending = false readPending = false
do { do {
pipeline.fireChannelRead0(data: .other(try SocketChannel(socket: accepted, eventLoop: group.next() as! SelectableEventLoop))) pipeline.fireChannelRead0(data: IOData(try SocketChannel(socket: accepted, eventLoop: group.next() as! SelectableEventLoop)))
} catch let err { } catch let err {
let _ = try? accepted.close() let _ = try? accepted.close()
throw err throw err

View File

@ -19,7 +19,7 @@ public protocol ChannelHandler : class {
func handlerRemoved(ctx: ChannelHandlerContext) throws func handlerRemoved(ctx: ChannelHandlerContext) throws
} }
public protocol ChannelOutboundHandler : ChannelHandler { public protocol _ChannelOutboundHandler : ChannelHandler {
func register(ctx: ChannelHandlerContext, promise: Promise<Void>?) func register(ctx: ChannelHandlerContext, promise: Promise<Void>?)
func bind(ctx: ChannelHandlerContext, to: SocketAddress, promise: Promise<Void>?) func bind(ctx: ChannelHandlerContext, to: SocketAddress, promise: Promise<Void>?)
func connect(ctx: ChannelHandlerContext, to: SocketAddress, promise: Promise<Void>?) func connect(ctx: ChannelHandlerContext, to: SocketAddress, promise: Promise<Void>?)
@ -31,7 +31,7 @@ public protocol ChannelOutboundHandler : ChannelHandler {
func triggerUserOutboundEvent(ctx: ChannelHandlerContext, event: Any, promise: Promise<Void>?) func triggerUserOutboundEvent(ctx: ChannelHandlerContext, event: Any, promise: Promise<Void>?)
} }
public protocol ChannelInboundHandler : ChannelHandler { public protocol _ChannelInboundHandler : ChannelHandler {
func channelRegistered(ctx: ChannelHandlerContext) throws func channelRegistered(ctx: ChannelHandlerContext) throws
func channelUnregistered(ctx: ChannelHandlerContext) throws func channelUnregistered(ctx: ChannelHandlerContext) throws
func channelActive(ctx: ChannelHandlerContext) throws func channelActive(ctx: ChannelHandlerContext) throws
@ -55,7 +55,7 @@ public extension ChannelHandler {
} }
} }
public extension ChannelOutboundHandler { public extension _ChannelOutboundHandler {
public func register(ctx: ChannelHandlerContext, promise: Promise<Void>?) { public func register(ctx: ChannelHandlerContext, promise: Promise<Void>?) {
ctx.register(promise: promise) ctx.register(promise: promise)
@ -91,7 +91,7 @@ public extension ChannelOutboundHandler {
} }
public extension ChannelInboundHandler { public extension _ChannelInboundHandler {
public func channelRegistered(ctx: ChannelHandlerContext) { public func channelRegistered(ctx: ChannelHandlerContext) {
ctx.fireChannelRegistered() ctx.fireChannelRegistered()

View File

@ -21,7 +21,11 @@ import Foundation
ChannelHandler implementation which enforces back-pressure by stop reading from the remote-peer when it can not write back fast-enough and start reading again ChannelHandler implementation which enforces back-pressure by stop reading from the remote-peer when it can not write back fast-enough and start reading again
once pending data was written. once pending data was written.
*/ */
public class BackPressureHandler: ChannelInboundHandler, ChannelOutboundHandler { public class BackPressureHandler: ChannelInboundHandler, _ChannelOutboundHandler {
public typealias InboundIn = ByteBuffer
public typealias InboundOut = ByteBuffer
public typealias OutboundOut = ByteBuffer
private enum PendingRead { private enum PendingRead {
case none case none
case promise(promise: Promise<Void>?) case promise(promise: Promise<Void>?)
@ -75,7 +79,7 @@ public class BackPressureHandler: ChannelInboundHandler, ChannelOutboundHandler
} }
} }
public class ChannelInitializer: ChannelInboundHandler { public class ChannelInitializer: _ChannelInboundHandler {
private let initChannel: (Channel) -> (Future<Void>) private let initChannel: (Channel) -> (Future<Void>)
public init(initChannel: @escaping (Channel) -> (Future<Void>)) { public init(initChannel: @escaping (Channel) -> (Future<Void>)) {

View File

@ -56,11 +56,11 @@ public final class ChannelPipeline : ChannelInvoker {
let ctx = ChannelHandlerContext(name: name ?? nextName(), handler: handler, pipeline: self) let ctx = ChannelHandlerContext(name: name ?? nextName(), handler: handler, pipeline: self)
if first { if first {
ctx.inboundNext = inboundChain ctx.inboundNext = inboundChain
if handler is ChannelInboundHandler { if handler is _ChannelInboundHandler {
inboundChain = ctx inboundChain = ctx
} }
if handler is ChannelOutboundHandler { if handler is _ChannelOutboundHandler {
var c = outboundChain var c = outboundChain
if c!.handler === HeadChannelHandler.sharedInstance { if c!.handler === HeadChannelHandler.sharedInstance {
@ -83,7 +83,7 @@ public final class ChannelPipeline : ChannelInvoker {
contexts.insert(ctx, at: 0) contexts.insert(ctx, at: 0)
} else { } else {
if handler is ChannelInboundHandler { if handler is _ChannelInboundHandler {
var c = inboundChain var c = inboundChain
if c!.handler === TailChannelHandler.sharedInstance { if c!.handler === TailChannelHandler.sharedInstance {
@ -105,7 +105,7 @@ public final class ChannelPipeline : ChannelInvoker {
} }
ctx.outboundNext = outboundChain ctx.outboundNext = outboundChain
if handler is ChannelOutboundHandler { if handler is _ChannelOutboundHandler {
outboundChain = ctx outboundChain = ctx
} }
@ -541,7 +541,7 @@ public final class ChannelPipeline : ChannelInvoker {
} }
} }
private final class HeadChannelHandler : ChannelOutboundHandler { private final class HeadChannelHandler : _ChannelOutboundHandler {
static let sharedInstance = HeadChannelHandler() static let sharedInstance = HeadChannelHandler()
@ -580,7 +580,7 @@ private final class HeadChannelHandler : ChannelOutboundHandler {
} }
} }
private final class TailChannelHandler : ChannelInboundHandler { private final class TailChannelHandler : _ChannelInboundHandler {
static let sharedInstance = TailChannelHandler() static let sharedInstance = TailChannelHandler()
@ -733,7 +733,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
assert(inEventLoop) assert(inEventLoop)
do { do {
try (handler as! ChannelInboundHandler).channelRegistered(ctx: self) try (handler as! _ChannelInboundHandler).channelRegistered(ctx: self)
} catch let err { } catch let err {
invokeErrorCaught(error: err) invokeErrorCaught(error: err)
} }
@ -743,7 +743,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
assert(inEventLoop) assert(inEventLoop)
do { do {
try (handler as! ChannelInboundHandler).channelUnregistered(ctx: self) try (handler as! _ChannelInboundHandler).channelUnregistered(ctx: self)
} catch let err { } catch let err {
invokeErrorCaught(error: err) invokeErrorCaught(error: err)
} }
@ -753,7 +753,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
assert(inEventLoop) assert(inEventLoop)
do { do {
try (handler as! ChannelInboundHandler).channelActive(ctx: self) try (handler as! _ChannelInboundHandler).channelActive(ctx: self)
} catch let err { } catch let err {
invokeErrorCaught(error: err) invokeErrorCaught(error: err)
} }
@ -763,7 +763,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
assert(inEventLoop) assert(inEventLoop)
do { do {
try (handler as! ChannelInboundHandler).channelInactive(ctx: self) try (handler as! _ChannelInboundHandler).channelInactive(ctx: self)
} catch let err { } catch let err {
invokeErrorCaught(error: err) invokeErrorCaught(error: err)
} }
@ -773,7 +773,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
assert(inEventLoop) assert(inEventLoop)
do { do {
try (handler as! ChannelInboundHandler).channelRead(ctx: self, data: data) try (handler as! _ChannelInboundHandler).channelRead(ctx: self, data: data)
} catch let err { } catch let err {
invokeErrorCaught(error: err) invokeErrorCaught(error: err)
} }
@ -783,7 +783,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
assert(inEventLoop) assert(inEventLoop)
do { do {
try (handler as! ChannelInboundHandler).channelReadComplete(ctx: self) try (handler as! _ChannelInboundHandler).channelReadComplete(ctx: self)
} catch let err { } catch let err {
invokeErrorCaught(error: err) invokeErrorCaught(error: err)
} }
@ -793,7 +793,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
assert(inEventLoop) assert(inEventLoop)
do { do {
try (handler as! ChannelInboundHandler).channelWritabilityChanged(ctx: self) try (handler as! _ChannelInboundHandler).channelWritabilityChanged(ctx: self)
} catch let err { } catch let err {
invokeErrorCaught(error: err) invokeErrorCaught(error: err)
} }
@ -803,7 +803,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
assert(inEventLoop) assert(inEventLoop)
do { do {
try (handler as! ChannelInboundHandler).errorCaught(ctx: self, error: error) try (handler as! _ChannelInboundHandler).errorCaught(ctx: self, error: error)
} catch let err { } catch let err {
// Forward the error thrown by errorCaught through the pipeline // Forward the error thrown by errorCaught through the pipeline
fireErrorCaught(error: err) fireErrorCaught(error: err)
@ -814,7 +814,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
assert(inEventLoop) assert(inEventLoop)
do { do {
try (handler as! ChannelInboundHandler).userInboundEventTriggered(ctx: self, event: event) try (handler as! _ChannelInboundHandler).userInboundEventTriggered(ctx: self, event: event)
} catch let err { } catch let err {
invokeErrorCaught(error: err) invokeErrorCaught(error: err)
} }
@ -824,34 +824,34 @@ public final class ChannelHandlerContext : ChannelInvoker {
assert(inEventLoop) assert(inEventLoop)
assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled") assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled")
(handler as! ChannelOutboundHandler).register(ctx: self, promise: promise) (handler as! _ChannelOutboundHandler).register(ctx: self, promise: promise)
} }
func invokeBind(to address: SocketAddress, promise: Promise<Void>?) { func invokeBind(to address: SocketAddress, promise: Promise<Void>?) {
assert(inEventLoop) assert(inEventLoop)
assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled") assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled")
(handler as! ChannelOutboundHandler).bind(ctx: self, to: address, promise: promise) (handler as! _ChannelOutboundHandler).bind(ctx: self, to: address, promise: promise)
} }
func invokeConnect(to address: SocketAddress, promise: Promise<Void>?) { func invokeConnect(to address: SocketAddress, promise: Promise<Void>?) {
assert(inEventLoop) assert(inEventLoop)
assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled") assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled")
(handler as! ChannelOutboundHandler).connect(ctx: self, to: address, promise: promise) (handler as! _ChannelOutboundHandler).connect(ctx: self, to: address, promise: promise)
} }
func invokeWrite(data: IOData, promise: Promise<Void>?) { func invokeWrite(data: IOData, promise: Promise<Void>?) {
assert(inEventLoop) assert(inEventLoop)
assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled") assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled")
(handler as! ChannelOutboundHandler).write(ctx: self, data: data, promise: promise) (handler as! _ChannelOutboundHandler).write(ctx: self, data: data, promise: promise)
} }
func invokeFlush(promise: Promise<Void>?) { func invokeFlush(promise: Promise<Void>?) {
assert(inEventLoop) assert(inEventLoop)
(handler as! ChannelOutboundHandler).flush(ctx: self, promise: promise) (handler as! _ChannelOutboundHandler).flush(ctx: self, promise: promise)
} }
func invokeWriteAndFlush(data: IOData, promise: Promise<Void>?) { func invokeWriteAndFlush(data: IOData, promise: Promise<Void>?) {
@ -876,35 +876,35 @@ public final class ChannelHandlerContext : ChannelInvoker {
let writePromise: Promise<Void> = eventLoop.newPromise() let writePromise: Promise<Void> = eventLoop.newPromise()
let flushPromise: Promise<Void> = eventLoop.newPromise() let flushPromise: Promise<Void> = eventLoop.newPromise()
(handler as! ChannelOutboundHandler).write(ctx: self, data: data, promise: writePromise) (handler as! _ChannelOutboundHandler).write(ctx: self, data: data, promise: writePromise)
(handler as! ChannelOutboundHandler).flush(ctx: self, promise: flushPromise) (handler as! _ChannelOutboundHandler).flush(ctx: self, promise: flushPromise)
writePromise.futureResult.whenComplete(callback: callback) writePromise.futureResult.whenComplete(callback: callback)
flushPromise.futureResult.whenComplete(callback: callback) flushPromise.futureResult.whenComplete(callback: callback)
} else { } else {
(handler as! ChannelOutboundHandler).write(ctx: self, data: data, promise: nil) (handler as! _ChannelOutboundHandler).write(ctx: self, data: data, promise: nil)
(handler as! ChannelOutboundHandler).flush(ctx: self, promise: nil) (handler as! _ChannelOutboundHandler).flush(ctx: self, promise: nil)
} }
} }
func invokeRead(promise: Promise<Void>?) { func invokeRead(promise: Promise<Void>?) {
assert(inEventLoop) assert(inEventLoop)
(handler as! ChannelOutboundHandler).read(ctx: self, promise: promise) (handler as! _ChannelOutboundHandler).read(ctx: self, promise: promise)
} }
func invokeClose(promise: Promise<Void>?) { func invokeClose(promise: Promise<Void>?) {
assert(inEventLoop) assert(inEventLoop)
assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled") assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled")
(handler as! ChannelOutboundHandler).close(ctx: self, promise: promise) (handler as! _ChannelOutboundHandler).close(ctx: self, promise: promise)
} }
func invokeTriggerUserOutboundEvent(event: Any, promise: Promise<Void>?) { func invokeTriggerUserOutboundEvent(event: Any, promise: Promise<Void>?) {
assert(inEventLoop) assert(inEventLoop)
assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled") assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled")
(handler as! ChannelOutboundHandler).triggerUserOutboundEvent(ctx: self, event: event, promise: promise) (handler as! _ChannelOutboundHandler).triggerUserOutboundEvent(ctx: self, event: event, promise: promise)
} }
func invokeHandlerAdded() throws { func invokeHandlerAdded() throws {

View File

@ -13,7 +13,8 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
public protocol ByteToMessageDecoder : ChannelInboundHandler { public protocol ByteToMessageDecoder : ChannelInboundHandler where InboundIn == ByteBuffer {
var cumulationBuffer: ByteBuffer? { get set } var cumulationBuffer: ByteBuffer? { get set }
func decode(ctx: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> Bool func decode(ctx: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> Bool
func decodeLast(ctx: ChannelHandlerContext, buffer: inout ByteBuffer)throws -> Bool func decodeLast(ctx: ChannelHandlerContext, buffer: inout ByteBuffer)throws -> Bool
@ -24,7 +25,7 @@ public protocol ByteToMessageDecoder : ChannelInboundHandler {
public extension ByteToMessageDecoder { public extension ByteToMessageDecoder {
public func channelRead(ctx: ChannelHandlerContext, data: IOData) throws { public func channelRead(ctx: ChannelHandlerContext, data: IOData) throws {
var buffer = data.forceAsByteBuffer() var buffer = self.unwrapInboundIn(data)
if var cum = cumulationBuffer { if var cum = cumulationBuffer {
var buf = ctx.channel!.allocator.buffer(capacity: cum.readableBytes + buffer.readableBytes) var buf = ctx.channel!.allocator.buffer(capacity: cum.readableBytes + buffer.readableBytes)
@ -62,10 +63,13 @@ public extension ByteToMessageDecoder {
} }
public func handlerRemoved(ctx: ChannelHandlerContext) throws { public func handlerRemoved(ctx: ChannelHandlerContext) throws {
if let buffer = cumulationBuffer { if let buffer = cumulationBuffer as? InboundOut {
ctx.fireChannelRead(data: .byteBuffer(buffer)) ctx.fireChannelRead(data: self.wrapInboundOut(buffer))
cumulationBuffer = nil } else {
/* please note that we're dropping the partially received bytes (if any) on the floor here as we can't
send a full message to the next handler. */
} }
cumulationBuffer = nil
try decoderRemoved(ctx: ctx) try decoderRemoved(ctx: ctx)
} }
@ -81,17 +85,18 @@ public extension ByteToMessageDecoder {
} }
} }
public protocol MessageToByteEncoder : ChannelOutboundHandler { public protocol MessageToByteEncoder : ChannelOutboundHandler where OutboundOut == ByteBuffer {
func encode(ctx: ChannelHandlerContext, data: IOData, out: inout ByteBuffer) throws func encode(ctx: ChannelHandlerContext, data: OutboundIn, out: inout ByteBuffer) throws
func allocateOutBuffer(ctx: ChannelHandlerContext, data: IOData) throws -> ByteBuffer func allocateOutBuffer(ctx: ChannelHandlerContext, data: OutboundIn) throws -> ByteBuffer
} }
public extension MessageToByteEncoder { public extension MessageToByteEncoder {
public func write(ctx: ChannelHandlerContext, data: IOData, promise: Promise<Void>?) { public func write(ctx: ChannelHandlerContext, data: IOData, promise: Promise<Void>?) {
do { do {
let data = self.unwrapOutboundIn(data)
var buffer: ByteBuffer = try allocateOutBuffer(ctx: ctx, data: data) var buffer: ByteBuffer = try allocateOutBuffer(ctx: ctx, data: data)
try encode(ctx: ctx, data: data, out: &buffer) try encode(ctx: ctx, data: data, out: &buffer)
ctx.write(data: .byteBuffer(buffer), promise: promise) ctx.write(data: self.wrapOutboundOut(buffer), promise: promise)
} catch let err { } catch let err {
promise?.fail(error: err) promise?.fail(error: err)
} }

View File

@ -133,12 +133,7 @@ class EmbeddedChannelCore : ChannelCore {
} }
private func addToBuffer(buffer: inout [Any], data: IOData) { private func addToBuffer(buffer: inout [Any], data: IOData) {
switch data { buffer.append(data.asAny())
case .byteBuffer(let buf):
buffer.append(buf)
case .other(let other):
buffer.append(other)
}
} }
} }

View File

@ -0,0 +1,106 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
public protocol ChannelInboundHandler: _ChannelInboundHandler {
associatedtype InboundIn
associatedtype InboundUserEventIn = Never
associatedtype OutboundOut = Never
associatedtype InboundOut = Never
associatedtype OutboundUserEventOut = Never
associatedtype InboundUserEventOut = Never
func unwrapInboundIn(_ value: IOData) -> InboundIn
func tryUnwrapInboundIn(_ value: IOData) -> InboundIn?
func wrapInboundOut(_ value: InboundOut) -> IOData
func unwrapInboundUserEventIn(_ value: Any) -> InboundUserEventIn
func tryUnwrapInboundUserEventIn(_ value: Any) -> InboundUserEventIn?
func wrapInboundUserEventOut(_ value: InboundUserEventOut) -> Any
func wrapOutboundOut(_ value: OutboundOut) -> IOData
}
public extension ChannelInboundHandler {
func unwrapInboundIn(_ value: IOData) -> InboundIn {
return value.forceAs()
}
func tryUnwrapInboundIn(_ value: IOData) -> InboundIn? {
return value.forceAs()
}
func wrapInboundOut(_ value: InboundOut) -> IOData {
return IOData(value)
}
func unwrapInboundUserEventIn(_ value: Any) -> InboundUserEventIn {
return value as! InboundUserEventIn
}
func tryUnwrapInboundUserEventIn(_ value: Any) -> InboundUserEventIn? {
return value as? InboundUserEventIn
}
func wrapInboundUserEventOut(_ value: InboundUserEventOut) -> Any {
return value
}
func wrapOutboundOut(_ value: OutboundOut) -> IOData {
return IOData(value)
}
}
public protocol ChannelOutboundHandler: _ChannelOutboundHandler {
associatedtype OutboundIn
associatedtype OutboundUserEventIn = Never
associatedtype OutboundOut
associatedtype InboundOut = Never
associatedtype OutboundUserEventOut = Never
associatedtype InboundUserEventOut = Never
func unwrapOutboundIn(_ value: IOData) -> OutboundIn
func tryUnwrapOutboundIn(_ value: IOData) -> OutboundIn?
func wrapOutboundOut(_ value: OutboundOut) -> IOData
func unwrapOutboundUserEventIn(_ value: Any) -> OutboundUserEventIn
func tryUnwrapOutboundUserEventIn(_ value: Any) -> OutboundUserEventIn?
func wrapOutboundUserEventOut(_ value: OutboundUserEventOut) -> Any
}
public extension ChannelOutboundHandler {
func unwrapOutboundIn(_ value: IOData) -> OutboundIn {
return value.forceAs()
}
func tryUnwrapOutboundIn(_ value: IOData) -> OutboundIn? {
return value.tryAs()
}
func wrapOutboundOut(_ value: OutboundOut) -> IOData {
return IOData(value)
}
func unwrapOutboundUserEventIn(_ value: Any) -> OutboundUserEventIn {
return value as! OutboundUserEventIn
}
func tryUnwrapOutboundUserEventIn(_ value: Any) -> OutboundUserEventIn? {
return value as? OutboundUserEventIn
}
func wrapOutboundUserEventOut(_ value: OutboundUserEventOut) -> Any {
return value
}
}

View File

@ -16,6 +16,8 @@ import NIO
private final class EchoHandler: ChannelInboundHandler { private final class EchoHandler: ChannelInboundHandler {
public typealias InboundIn = ByteBuffer
public typealias OutboundOut = ByteBuffer
public func channelRead(ctx: ChannelHandlerContext, data: IOData) { public func channelRead(ctx: ChannelHandlerContext, data: IOData) {
// As we are not really interested getting notified on success or failure we just pass nil as promise to // As we are not really interested getting notified on success or failure we just pass nil as promise to

View File

@ -16,33 +16,32 @@ import Foundation
import NIO import NIO
public final class HTTPResponseEncoder : ChannelOutboundHandler { public final class HTTPResponseEncoder : ChannelOutboundHandler {
public typealias OutboundIn = HTTPResponse
public typealias OutboundOut = ByteBuffer
public init() { } public init() {}
public func write(ctx: ChannelHandlerContext, data: IOData, promise: Promise<Void>?) { public func write(ctx: ChannelHandlerContext, data: IOData, promise: Promise<Void>?) {
if let response:HTTPResponseHead = data.tryAsOther() { switch self.tryUnwrapOutboundIn(data) {
case .some(.head(let response)):
// TODO: Is 256 really a good value here ? // TODO: Is 256 really a good value here ?
var buffer = ctx.channel!.allocator.buffer(capacity: 256) var buffer = ctx.channel!.allocator.buffer(capacity: 256)
response.version.write(buffer: &buffer) response.version.write(buffer: &buffer)
response.status.write(buffer: &buffer) response.status.write(buffer: &buffer)
response.headers.write(buffer: &buffer) response.headers.write(buffer: &buffer)
ctx.write(data: .byteBuffer(buffer), promise: promise) ctx.write(data: self.wrapOutboundOut(buffer), promise: promise)
} else if let content: HTTPBodyContent = data.tryAsOther() { case .some(.body(.more(let buffer))):
// TODO: Implement chunked encoding ctx.write(data: self.wrapOutboundOut(buffer), promise: promise)
switch content { case .some(.body(.last(let buffer))):
case .more(let buffer):
ctx.write(data: .byteBuffer(buffer), promise: promise)
case .last(let buffer):
if let buf = buffer { if let buf = buffer {
ctx.write(data: .byteBuffer(buf), promise: promise) ctx.write(data: self.wrapOutboundOut(buf), promise: promise)
} else if promise != nil { } else if promise != nil {
// We only need to pass the promise further if the user is even interested in the result. // We only need to pass the promise further if the user is even interested in the result.
// Empty content so just write an empty buffer // Empty content so just write an empty buffer
ctx.write(data: .byteBuffer(ctx.channel!.allocator.buffer(capacity: 0)), promise: promise) ctx.write(data: self.wrapOutboundOut(ctx.channel!.allocator.buffer(capacity: 0)), promise: promise)
} }
} case .none:
} else {
ctx.write(data: data, promise: promise) ctx.write(data: data, promise: promise)
} }
} }

View File

@ -16,7 +16,10 @@ import Foundation
import NIO import NIO
import CHTTPParser import CHTTPParser
public final class HTTPRequestDecoder : ByteToMessageDecoder { public final class HTTPRequestDecoder : ChannelInboundHandler, ByteToMessageDecoder {
public typealias InboundIn = ByteBuffer
public typealias InboundOut = HTTPRequest
var parser: UnsafeMutablePointer<http_parser>? var parser: UnsafeMutablePointer<http_parser>?
var settings: UnsafeMutablePointer<http_parser_settings>? var settings: UnsafeMutablePointer<http_parser_settings>?
public var cumulationBuffer: ByteBuffer? public var cumulationBuffer: ByteBuffer?
@ -127,7 +130,7 @@ public final class HTTPRequestDecoder : ByteToMessageDecoder {
handler.state.dataAwaitingState = .body handler.state.dataAwaitingState = .body
ctx.fireChannelRead(data: .other(HTTPRequest.head(request))) ctx.fireChannelRead(data: handler.wrapInboundOut(HTTPRequest.head(request)))
return 0 return 0
} }
@ -138,7 +141,7 @@ public final class HTTPRequestDecoder : ByteToMessageDecoder {
// This will never return nil as we allocated the buffer with the correct size // This will never return nil as we allocated the buffer with the correct size
handler.state.parserBuffer.write(int8Data: data!, len: len) handler.state.parserBuffer.write(int8Data: data!, len: len)
ctx.fireChannelRead(data: .other(HTTPRequest.body(HTTPBodyContent.more(buffer: handler.state.parserBuffer.readSlice(length: len)!)))) ctx.fireChannelRead(data: handler.wrapInboundOut(HTTPRequest.body(HTTPBodyContent.more(buffer: handler.state.parserBuffer.readSlice(length: len)!))))
return 0 return 0
} }
@ -178,7 +181,7 @@ public final class HTTPRequestDecoder : ByteToMessageDecoder {
let ctx = evacuateContext(parser) let ctx = evacuateContext(parser)
let handler = ctx.handler as! HTTPRequestDecoder let handler = ctx.handler as! HTTPRequestDecoder
ctx.fireChannelRead(data: .other(HTTPRequest.body(.last(buffer: nil)))) ctx.fireChannelRead(data: handler.wrapInboundOut(HTTPRequest.body(.last(buffer: nil))))
handler.complete(state: handler.state.dataAwaitingState) handler.complete(state: handler.state.dataAwaitingState)
handler.state.dataAwaitingState = .messageBegin handler.state.dataAwaitingState = .messageBegin
return 0 return 0

View File

@ -65,6 +65,11 @@ public extension HTTPRequestHead {
} }
} }
public enum HTTPResponse {
case head(HTTPResponseHead)
case body(HTTPBodyContent)
}
public struct HTTPResponseHead { public struct HTTPResponseHead {
public let status: HTTPResponseStatus public let status: HTTPResponseStatus
public let version: HTTPVersion public let version: HTTPVersion

View File

@ -16,29 +16,33 @@ import NIO
import NIOHTTP1 import NIOHTTP1
private class HTTPHandler : ChannelInboundHandler { private class HTTPHandler : ChannelInboundHandler {
public typealias InboundIn = HTTPRequest
public typealias OutboundOut = HTTPResponse
private var buffer: ByteBuffer? = nil private var buffer: ByteBuffer? = nil
private var keepAlive = false private var keepAlive = false
func channelRead(ctx: ChannelHandlerContext, data: IOData) throws { func channelRead(ctx: ChannelHandlerContext, data: IOData) {
if let reqPart = data.tryAsOther(type: HTTPRequest.self) {
if let reqPart = self.tryUnwrapInboundIn(data) {
switch reqPart { switch reqPart {
case .head(let request): case .head(let request):
keepAlive = request.isKeepAlive keepAlive = request.isKeepAlive
var response = HTTPResponseHead(version: request.version, status: HTTPResponseStatus.ok) var responseHead = HTTPResponseHead(version: request.version, status: HTTPResponseStatus.ok)
response.headers.add(name: "content-length", value: "12") responseHead.headers.add(name: "content-length", value: "12")
ctx.write(data: .other(response), promise: nil) let response = HTTPResponse.head(responseHead)
ctx.write(data: self.wrapOutboundOut(response), promise: nil)
case .body(let content): case .body(let content):
switch content { switch content {
case .more(_): case .more(_):
break break
case .last: case .last:
let content = HTTPBodyContent.last(buffer: buffer!.slice()) let content = HTTPResponse.body(HTTPBodyContent.last(buffer: buffer!.slice()))
if keepAlive { if keepAlive {
ctx.write(data: .other(content), promise: nil) ctx.write(data: self.wrapOutboundOut(content), promise: nil)
} else { } else {
ctx.write(data: .other(content)).whenComplete(callback: { _ in ctx.write(data: self.wrapOutboundOut(content)).whenComplete(callback: { _ in
ctx.close(promise: nil) ctx.close(promise: nil)
}) })
} }

View File

@ -17,15 +17,17 @@ import XCTest
@testable import NIOHTTP1 @testable import NIOHTTP1
private final class TestChannelInboundHandler: ChannelInboundHandler { private final class TestChannelInboundHandler: ChannelInboundHandler {
public typealias InboundIn = HTTPRequest
public typealias InboundOut = HTTPRequest
private let fn: (IOData) -> IOData private let fn: (HTTPRequest) -> HTTPRequest
init(_ fn: @escaping (IOData) -> IOData) { init(_ fn: @escaping (HTTPRequest) -> HTTPRequest) {
self.fn = fn self.fn = fn
} }
public func channelRead(ctx: ChannelHandlerContext, data: IOData) { public func channelRead(ctx: ChannelHandlerContext, data: IOData) {
ctx.fireChannelRead(data: self.fn(data)) ctx.fireChannelRead(data: self.wrapInboundOut(self.fn(self.unwrapInboundIn(data))))
} }
} }
@ -62,8 +64,7 @@ class HTTPTest: XCTestCase {
try channel.pipeline.add(handler: HTTPRequestDecoder()).wait() try channel.pipeline.add(handler: HTTPRequestDecoder()).wait()
var bodyData: Data? = nil var bodyData: Data? = nil
var allBodyDatas: [Data] = [] var allBodyDatas: [Data] = []
try channel.pipeline.add(handler: TestChannelInboundHandler { data in try channel.pipeline.add(handler: TestChannelInboundHandler { reqPart in
if let reqPart = data.tryAsOther(type: NIOHTTP1.HTTPRequest.self) {
switch reqPart { switch reqPart {
case .head(var req): case .head(var req):
XCTAssertEqual((index * 2), step) XCTAssertEqual((index * 2), step)
@ -90,10 +91,7 @@ class HTTPTest: XCTestCase {
XCTAssertEqual(((index + 1) * 2), step) XCTAssertEqual(((index + 1) * 2), step)
} }
} }
} else { return reqPart
XCTFail("wrong type \(data)")
}
return data
}).wait() }).wait()
for expected in expecteds { for expected in expecteds {
@ -124,7 +122,7 @@ class HTTPTest: XCTestCase {
let bd1 = try sendAndCheckRequests(expecteds, body: body, sendStrategy: { (reqString, chan) in let bd1 = try sendAndCheckRequests(expecteds, body: body, sendStrategy: { (reqString, chan) in
var buf = chan.allocator.buffer(capacity: 1024) var buf = chan.allocator.buffer(capacity: 1024)
buf.write(string: reqString) buf.write(string: reqString)
chan.pipeline.fireChannelRead(data: .byteBuffer(buf)) chan.pipeline.fireChannelRead(data: IOData(buf))
}) })
/* send the bytes one by one */ /* send the bytes one by one */
@ -133,7 +131,7 @@ class HTTPTest: XCTestCase {
var buf = chan.allocator.buffer(capacity: 1024) var buf = chan.allocator.buffer(capacity: 1024)
buf.write(string: "\(c)") buf.write(string: "\(c)")
chan.pipeline.fireChannelRead(data: .byteBuffer(buf)) chan.pipeline.fireChannelRead(data: IOData(buf))
} }
}) })

View File

@ -58,35 +58,37 @@ class ChannelPipelineTest: XCTestCase {
var buf = channel.allocator.buffer(capacity: 1024) var buf = channel.allocator.buffer(capacity: 1024)
buf.write(string: "hello") buf.write(string: "hello")
_ = try channel.pipeline.add(handler: TestChannelOutboundHandler({ data in _ = try channel.pipeline.add(handler: TestChannelOutboundHandler<Int, ByteBuffer>({ data in
XCTAssertEqual(1, data.forceAsOther()) XCTAssertEqual(1, data)
return .byteBuffer(buf) return buf
})).wait() })).wait()
_ = try channel.pipeline.add(handler: TestChannelOutboundHandler({ data in _ = try channel.pipeline.add(handler: TestChannelOutboundHandler<String, Int>({ data in
XCTAssertEqual("msg", data.forceAsOther()) XCTAssertEqual("msg", data)
return .other(1) return 1
})).wait() })).wait()
_ = channel.write(data: .other("msg")) _ = channel.write(data: IOData("msg"))
_ = try channel.flush().wait() _ = try channel.flush().wait()
XCTAssertEqual(buf, channel.readOutbound()) XCTAssertEqual(buf, channel.readOutbound())
XCTAssertNil(channel.readOutbound()) XCTAssertNil(channel.readOutbound())
} }
private final class TestChannelOutboundHandler: ChannelOutboundHandler { private final class TestChannelOutboundHandler<In, Out>: ChannelOutboundHandler {
typealias OutboundIn = In
typealias OutboundOut = Out
private let fn: (IOData) throws -> IOData private let fn: (OutboundIn) throws -> OutboundOut
init(_ fn: @escaping (IOData) throws -> IOData) { init(_ fn: @escaping (OutboundIn) throws -> OutboundOut) {
self.fn = fn self.fn = fn
} }
public func write(ctx: ChannelHandlerContext, data: IOData, promise: Promise<Void>?) { public func write(ctx: ChannelHandlerContext, data: IOData, promise: Promise<Void>?) {
do { do {
ctx.write(data: try fn(data), promise: promise) ctx.write(data: self.wrapOutboundOut(try fn(self.unwrapOutboundIn(data))), promise: promise)
} catch let err { } catch let err {
promise!.fail(error: err) promise!.fail(error: err)
} }

View File

@ -18,6 +18,9 @@ import XCTest
public class ByteToMessageDecoderTest: XCTestCase { public class ByteToMessageDecoderTest: XCTestCase {
private final class ByteToInt32Decoder : ByteToMessageDecoder { private final class ByteToInt32Decoder : ByteToMessageDecoder {
typealias InboundIn = ByteBuffer
typealias InboundOut = Int32
var cumulationBuffer: ByteBuffer? var cumulationBuffer: ByteBuffer?
@ -25,7 +28,7 @@ public class ByteToMessageDecoderTest: XCTestCase {
guard buffer.readableBytes >= MemoryLayout<Int32>.size else { guard buffer.readableBytes >= MemoryLayout<Int32>.size else {
return false return false
} }
ctx.fireChannelRead(data: .other(buffer.readInteger()! as Int32)) ctx.fireChannelRead(data: self.wrapInboundOut(buffer.readInteger()!))
return true return true
} }
} }
@ -40,15 +43,15 @@ public class ByteToMessageDecoderTest: XCTestCase {
let writerIndex = buffer.writerIndex let writerIndex = buffer.writerIndex
buffer.moveWriterIndex(to: writerIndex - 1) buffer.moveWriterIndex(to: writerIndex - 1)
channel.pipeline.fireChannelRead(data: .byteBuffer(buffer)) channel.pipeline.fireChannelRead(data: IOData(buffer))
XCTAssertNil(channel.readInbound()) XCTAssertNil(channel.readInbound())
channel.pipeline.fireChannelRead(data: .byteBuffer(buffer.slice(at: writerIndex - 1, length: 1)!)) channel.pipeline.fireChannelRead(data: IOData(buffer.slice(at: writerIndex - 1, length: 1)!))
var buffer2 = channel.allocator.buffer(capacity: 32) var buffer2 = channel.allocator.buffer(capacity: 32)
buffer2.write(integer: Int32(2)) buffer2.write(integer: Int32(2))
buffer2.write(integer: Int32(3)) buffer2.write(integer: Int32(3))
channel.pipeline.fireChannelRead(data: .byteBuffer(buffer2)) channel.pipeline.fireChannelRead(data: IOData(buffer2))
try channel.close().wait() try channel.close().wait()
@ -62,13 +65,15 @@ public class ByteToMessageDecoderTest: XCTestCase {
public class MessageToByteEncoderTest: XCTestCase { public class MessageToByteEncoderTest: XCTestCase {
private final class Int32ToByteEncoder : MessageToByteEncoder { private final class Int32ToByteEncoder : MessageToByteEncoder {
public func encode(ctx: ChannelHandlerContext, data: IOData, out: inout ByteBuffer) throws { typealias OutboundIn = Int32
typealias OutboundOut = ByteBuffer
public func encode(ctx: ChannelHandlerContext, data value: Int32, out: inout ByteBuffer) throws {
XCTAssertEqual(MemoryLayout<Int32>.size, out.writableBytes) XCTAssertEqual(MemoryLayout<Int32>.size, out.writableBytes)
let value: Int32 = data.forceAsOther()
out.write(integer: value); out.write(integer: value);
} }
public func allocateOutBuffer(ctx: ChannelHandlerContext, data: IOData) throws -> ByteBuffer { public func allocateOutBuffer(ctx: ChannelHandlerContext, data: Int32) throws -> ByteBuffer {
return ctx.channel!.allocator.buffer(capacity: MemoryLayout<Int32>.size) return ctx.channel!.allocator.buffer(capacity: MemoryLayout<Int32>.size)
} }
} }
@ -78,8 +83,7 @@ public class MessageToByteEncoderTest: XCTestCase {
_ = try channel.pipeline.add(handler: Int32ToByteEncoder()).wait() _ = try channel.pipeline.add(handler: Int32ToByteEncoder()).wait()
_ = try channel.writeAndFlush(data: .other(Int32(5))).wait() _ = try channel.writeAndFlush(data: IOData(Int32(5))).wait()
var buffer = channel.readOutbound() as ByteBuffer? var buffer = channel.readOutbound() as ByteBuffer?
XCTAssertEqual(Int32(5), buffer?.readInteger()) XCTAssertEqual(Int32(5), buffer?.readInteger())

View File

@ -22,7 +22,7 @@ class EmbeddedChannelTest: XCTestCase {
buf.write(string: "hello") buf.write(string: "hello")
let f = channel.write(data: .byteBuffer(buf)) let f = channel.write(data: IOData(buf))
var ranBlock = false var ranBlock = false
f.whenSuccess { () -> Void in f.whenSuccess { () -> Void in