baby steps towards a type-safe Channel pipeline

This commit is contained in:
Johannes Weiß 2017-06-30 13:04:19 +01:00
parent 15677b9fc2
commit c09178ceb2
17 changed files with 308 additions and 139 deletions

View File

@ -128,6 +128,7 @@ public final class ServerBootstrap {
}
private class AcceptHandler : ChannelInboundHandler {
public typealias InboundIn = SocketChannel
private let childHandler: ChannelHandler?
private let childOptions: ChannelOptionStorage
@ -138,7 +139,7 @@ public final class ServerBootstrap {
}
func channelRead(ctx: ChannelHandlerContext, data: IOData) {
let accepted = data.forceAsOther() as SocketChannel
let accepted = self.unwrapInboundIn(data)
do {
try self.childOptions.applyAll(channel: accepted)

View File

@ -21,33 +21,74 @@ import Glibc
import Darwin
#endif
public enum IOData {
case byteBuffer(ByteBuffer)
case other(Any)
public struct IOData {
private let storage: _IOData
public init<T>(_ value: T) {
self.storage = _IOData(value)
}
public func tryAsByteBuffer() -> ByteBuffer? {
if case .byteBuffer(let bb) = self {
enum _IOData {
case byteBuffer(ByteBuffer)
case other(Any)
init<T>(_ value: T) {
if T.self == ByteBuffer.self {
self = .byteBuffer(value as! ByteBuffer)
} else {
self = .other(value)
}
}
}
func tryAsByteBuffer() -> ByteBuffer? {
if case .byteBuffer(let bb) = self.storage {
return bb
} else {
return nil
}
}
public func forceAsByteBuffer() -> ByteBuffer {
func forceAsByteBuffer() -> ByteBuffer {
return tryAsByteBuffer()!
}
public func tryAsOther<T>(type: T.Type = T.self) -> T? {
if case .other(let any) = self {
func tryAsOther<T>(type: T.Type = T.self) -> T? {
if case .other(let any) = self.storage {
return any as? T
} else {
return nil
}
}
public func forceAsOther<T>(type: T.Type = T.self) -> T {
func forceAsOther<T>(type: T.Type = T.self) -> T {
return tryAsOther(type: type)!
}
func forceAs<T>(type: T.Type = T.self) -> T {
if T.self == ByteBuffer.self {
return self.forceAsByteBuffer() as! T
} else {
return self.forceAsOther(type: type)
}
}
func tryAs<T>(type: T.Type = T.self) -> T? {
if T.self == ByteBuffer.self {
return self.tryAsByteBuffer() as! T?
} else {
return self.tryAsOther(type: type)
}
}
func asAny() -> Any {
switch self.storage {
case .byteBuffer(let bb):
return bb
case .other(let o):
return o
}
}
}
final class PendingWrite {
@ -392,7 +433,7 @@ final class SocketChannel : BaseSocketChannel<Socket> {
readPending = false
assert(!closed)
pipeline.fireChannelRead0(data: .byteBuffer(buffer))
pipeline.fireChannelRead0(data: IOData(buffer))
// Reset reader and writerIndex and so allow to have the buffer filled again
buffer.clear()
@ -520,7 +561,7 @@ final class ServerSocketChannel : BaseSocketChannel<ServerSocket> {
readPending = false
do {
pipeline.fireChannelRead0(data: .other(try SocketChannel(socket: accepted, eventLoop: group.next() as! SelectableEventLoop)))
pipeline.fireChannelRead0(data: IOData(try SocketChannel(socket: accepted, eventLoop: group.next() as! SelectableEventLoop)))
} catch let err {
let _ = try? accepted.close()
throw err

View File

@ -19,7 +19,7 @@ public protocol ChannelHandler : class {
func handlerRemoved(ctx: ChannelHandlerContext) throws
}
public protocol ChannelOutboundHandler : ChannelHandler {
public protocol _ChannelOutboundHandler : ChannelHandler {
func register(ctx: ChannelHandlerContext, promise: Promise<Void>?)
func bind(ctx: ChannelHandlerContext, to: SocketAddress, promise: Promise<Void>?)
func connect(ctx: ChannelHandlerContext, to: SocketAddress, promise: Promise<Void>?)
@ -31,7 +31,7 @@ public protocol ChannelOutboundHandler : ChannelHandler {
func triggerUserOutboundEvent(ctx: ChannelHandlerContext, event: Any, promise: Promise<Void>?)
}
public protocol ChannelInboundHandler : ChannelHandler {
public protocol _ChannelInboundHandler : ChannelHandler {
func channelRegistered(ctx: ChannelHandlerContext) throws
func channelUnregistered(ctx: ChannelHandlerContext) throws
func channelActive(ctx: ChannelHandlerContext) throws
@ -55,7 +55,7 @@ public extension ChannelHandler {
}
}
public extension ChannelOutboundHandler {
public extension _ChannelOutboundHandler {
public func register(ctx: ChannelHandlerContext, promise: Promise<Void>?) {
ctx.register(promise: promise)
@ -91,7 +91,7 @@ public extension ChannelOutboundHandler {
}
public extension ChannelInboundHandler {
public extension _ChannelInboundHandler {
public func channelRegistered(ctx: ChannelHandlerContext) {
ctx.fireChannelRegistered()

View File

@ -21,7 +21,11 @@ import Foundation
ChannelHandler implementation which enforces back-pressure by stop reading from the remote-peer when it can not write back fast-enough and start reading again
once pending data was written.
*/
public class BackPressureHandler: ChannelInboundHandler, ChannelOutboundHandler {
public class BackPressureHandler: ChannelInboundHandler, _ChannelOutboundHandler {
public typealias InboundIn = ByteBuffer
public typealias InboundOut = ByteBuffer
public typealias OutboundOut = ByteBuffer
private enum PendingRead {
case none
case promise(promise: Promise<Void>?)
@ -75,7 +79,7 @@ public class BackPressureHandler: ChannelInboundHandler, ChannelOutboundHandler
}
}
public class ChannelInitializer: ChannelInboundHandler {
public class ChannelInitializer: _ChannelInboundHandler {
private let initChannel: (Channel) -> (Future<Void>)
public init(initChannel: @escaping (Channel) -> (Future<Void>)) {

View File

@ -56,11 +56,11 @@ public final class ChannelPipeline : ChannelInvoker {
let ctx = ChannelHandlerContext(name: name ?? nextName(), handler: handler, pipeline: self)
if first {
ctx.inboundNext = inboundChain
if handler is ChannelInboundHandler {
if handler is _ChannelInboundHandler {
inboundChain = ctx
}
if handler is ChannelOutboundHandler {
if handler is _ChannelOutboundHandler {
var c = outboundChain
if c!.handler === HeadChannelHandler.sharedInstance {
@ -83,7 +83,7 @@ public final class ChannelPipeline : ChannelInvoker {
contexts.insert(ctx, at: 0)
} else {
if handler is ChannelInboundHandler {
if handler is _ChannelInboundHandler {
var c = inboundChain
if c!.handler === TailChannelHandler.sharedInstance {
@ -105,7 +105,7 @@ public final class ChannelPipeline : ChannelInvoker {
}
ctx.outboundNext = outboundChain
if handler is ChannelOutboundHandler {
if handler is _ChannelOutboundHandler {
outboundChain = ctx
}
@ -541,7 +541,7 @@ public final class ChannelPipeline : ChannelInvoker {
}
}
private final class HeadChannelHandler : ChannelOutboundHandler {
private final class HeadChannelHandler : _ChannelOutboundHandler {
static let sharedInstance = HeadChannelHandler()
@ -580,7 +580,7 @@ private final class HeadChannelHandler : ChannelOutboundHandler {
}
}
private final class TailChannelHandler : ChannelInboundHandler {
private final class TailChannelHandler : _ChannelInboundHandler {
static let sharedInstance = TailChannelHandler()
@ -733,7 +733,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
assert(inEventLoop)
do {
try (handler as! ChannelInboundHandler).channelRegistered(ctx: self)
try (handler as! _ChannelInboundHandler).channelRegistered(ctx: self)
} catch let err {
invokeErrorCaught(error: err)
}
@ -743,7 +743,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
assert(inEventLoop)
do {
try (handler as! ChannelInboundHandler).channelUnregistered(ctx: self)
try (handler as! _ChannelInboundHandler).channelUnregistered(ctx: self)
} catch let err {
invokeErrorCaught(error: err)
}
@ -753,7 +753,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
assert(inEventLoop)
do {
try (handler as! ChannelInboundHandler).channelActive(ctx: self)
try (handler as! _ChannelInboundHandler).channelActive(ctx: self)
} catch let err {
invokeErrorCaught(error: err)
}
@ -763,7 +763,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
assert(inEventLoop)
do {
try (handler as! ChannelInboundHandler).channelInactive(ctx: self)
try (handler as! _ChannelInboundHandler).channelInactive(ctx: self)
} catch let err {
invokeErrorCaught(error: err)
}
@ -773,7 +773,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
assert(inEventLoop)
do {
try (handler as! ChannelInboundHandler).channelRead(ctx: self, data: data)
try (handler as! _ChannelInboundHandler).channelRead(ctx: self, data: data)
} catch let err {
invokeErrorCaught(error: err)
}
@ -783,7 +783,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
assert(inEventLoop)
do {
try (handler as! ChannelInboundHandler).channelReadComplete(ctx: self)
try (handler as! _ChannelInboundHandler).channelReadComplete(ctx: self)
} catch let err {
invokeErrorCaught(error: err)
}
@ -793,7 +793,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
assert(inEventLoop)
do {
try (handler as! ChannelInboundHandler).channelWritabilityChanged(ctx: self)
try (handler as! _ChannelInboundHandler).channelWritabilityChanged(ctx: self)
} catch let err {
invokeErrorCaught(error: err)
}
@ -803,7 +803,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
assert(inEventLoop)
do {
try (handler as! ChannelInboundHandler).errorCaught(ctx: self, error: error)
try (handler as! _ChannelInboundHandler).errorCaught(ctx: self, error: error)
} catch let err {
// Forward the error thrown by errorCaught through the pipeline
fireErrorCaught(error: err)
@ -814,7 +814,7 @@ public final class ChannelHandlerContext : ChannelInvoker {
assert(inEventLoop)
do {
try (handler as! ChannelInboundHandler).userInboundEventTriggered(ctx: self, event: event)
try (handler as! _ChannelInboundHandler).userInboundEventTriggered(ctx: self, event: event)
} catch let err {
invokeErrorCaught(error: err)
}
@ -824,34 +824,34 @@ public final class ChannelHandlerContext : ChannelInvoker {
assert(inEventLoop)
assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled")
(handler as! ChannelOutboundHandler).register(ctx: self, promise: promise)
(handler as! _ChannelOutboundHandler).register(ctx: self, promise: promise)
}
func invokeBind(to address: SocketAddress, promise: Promise<Void>?) {
assert(inEventLoop)
assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled")
(handler as! ChannelOutboundHandler).bind(ctx: self, to: address, promise: promise)
(handler as! _ChannelOutboundHandler).bind(ctx: self, to: address, promise: promise)
}
func invokeConnect(to address: SocketAddress, promise: Promise<Void>?) {
assert(inEventLoop)
assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled")
(handler as! ChannelOutboundHandler).connect(ctx: self, to: address, promise: promise)
(handler as! _ChannelOutboundHandler).connect(ctx: self, to: address, promise: promise)
}
func invokeWrite(data: IOData, promise: Promise<Void>?) {
assert(inEventLoop)
assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled")
(handler as! ChannelOutboundHandler).write(ctx: self, data: data, promise: promise)
(handler as! _ChannelOutboundHandler).write(ctx: self, data: data, promise: promise)
}
func invokeFlush(promise: Promise<Void>?) {
assert(inEventLoop)
(handler as! ChannelOutboundHandler).flush(ctx: self, promise: promise)
(handler as! _ChannelOutboundHandler).flush(ctx: self, promise: promise)
}
func invokeWriteAndFlush(data: IOData, promise: Promise<Void>?) {
@ -876,35 +876,35 @@ public final class ChannelHandlerContext : ChannelInvoker {
let writePromise: Promise<Void> = eventLoop.newPromise()
let flushPromise: Promise<Void> = eventLoop.newPromise()
(handler as! ChannelOutboundHandler).write(ctx: self, data: data, promise: writePromise)
(handler as! ChannelOutboundHandler).flush(ctx: self, promise: flushPromise)
(handler as! _ChannelOutboundHandler).write(ctx: self, data: data, promise: writePromise)
(handler as! _ChannelOutboundHandler).flush(ctx: self, promise: flushPromise)
writePromise.futureResult.whenComplete(callback: callback)
flushPromise.futureResult.whenComplete(callback: callback)
} else {
(handler as! ChannelOutboundHandler).write(ctx: self, data: data, promise: nil)
(handler as! ChannelOutboundHandler).flush(ctx: self, promise: nil)
(handler as! _ChannelOutboundHandler).write(ctx: self, data: data, promise: nil)
(handler as! _ChannelOutboundHandler).flush(ctx: self, promise: nil)
}
}
func invokeRead(promise: Promise<Void>?) {
assert(inEventLoop)
(handler as! ChannelOutboundHandler).read(ctx: self, promise: promise)
(handler as! _ChannelOutboundHandler).read(ctx: self, promise: promise)
}
func invokeClose(promise: Promise<Void>?) {
assert(inEventLoop)
assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled")
(handler as! ChannelOutboundHandler).close(ctx: self, promise: promise)
(handler as! _ChannelOutboundHandler).close(ctx: self, promise: promise)
}
func invokeTriggerUserOutboundEvent(event: Any, promise: Promise<Void>?) {
assert(inEventLoop)
assert(promise.map { !$0.futureResult.fulfilled } ?? true, "Promise \(promise!) already fulfilled")
(handler as! ChannelOutboundHandler).triggerUserOutboundEvent(ctx: self, event: event, promise: promise)
(handler as! _ChannelOutboundHandler).triggerUserOutboundEvent(ctx: self, event: event, promise: promise)
}
func invokeHandlerAdded() throws {

View File

@ -13,7 +13,8 @@
//===----------------------------------------------------------------------===//
public protocol ByteToMessageDecoder : ChannelInboundHandler {
public protocol ByteToMessageDecoder : ChannelInboundHandler where InboundIn == ByteBuffer {
var cumulationBuffer: ByteBuffer? { get set }
func decode(ctx: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> Bool
func decodeLast(ctx: ChannelHandlerContext, buffer: inout ByteBuffer)throws -> Bool
@ -24,7 +25,7 @@ public protocol ByteToMessageDecoder : ChannelInboundHandler {
public extension ByteToMessageDecoder {
public func channelRead(ctx: ChannelHandlerContext, data: IOData) throws {
var buffer = data.forceAsByteBuffer()
var buffer = self.unwrapInboundIn(data)
if var cum = cumulationBuffer {
var buf = ctx.channel!.allocator.buffer(capacity: cum.readableBytes + buffer.readableBytes)
@ -62,10 +63,13 @@ public extension ByteToMessageDecoder {
}
public func handlerRemoved(ctx: ChannelHandlerContext) throws {
if let buffer = cumulationBuffer {
ctx.fireChannelRead(data: .byteBuffer(buffer))
cumulationBuffer = nil
if let buffer = cumulationBuffer as? InboundOut {
ctx.fireChannelRead(data: self.wrapInboundOut(buffer))
} else {
/* please note that we're dropping the partially received bytes (if any) on the floor here as we can't
send a full message to the next handler. */
}
cumulationBuffer = nil
try decoderRemoved(ctx: ctx)
}
@ -81,17 +85,18 @@ public extension ByteToMessageDecoder {
}
}
public protocol MessageToByteEncoder : ChannelOutboundHandler {
func encode(ctx: ChannelHandlerContext, data: IOData, out: inout ByteBuffer) throws
func allocateOutBuffer(ctx: ChannelHandlerContext, data: IOData) throws -> ByteBuffer
public protocol MessageToByteEncoder : ChannelOutboundHandler where OutboundOut == ByteBuffer {
func encode(ctx: ChannelHandlerContext, data: OutboundIn, out: inout ByteBuffer) throws
func allocateOutBuffer(ctx: ChannelHandlerContext, data: OutboundIn) throws -> ByteBuffer
}
public extension MessageToByteEncoder {
public func write(ctx: ChannelHandlerContext, data: IOData, promise: Promise<Void>?) {
do {
let data = self.unwrapOutboundIn(data)
var buffer: ByteBuffer = try allocateOutBuffer(ctx: ctx, data: data)
try encode(ctx: ctx, data: data, out: &buffer)
ctx.write(data: .byteBuffer(buffer), promise: promise)
ctx.write(data: self.wrapOutboundOut(buffer), promise: promise)
} catch let err {
promise?.fail(error: err)
}

View File

@ -133,12 +133,7 @@ class EmbeddedChannelCore : ChannelCore {
}
private func addToBuffer(buffer: inout [Any], data: IOData) {
switch data {
case .byteBuffer(let buf):
buffer.append(buf)
case .other(let other):
buffer.append(other)
}
buffer.append(data.asAny())
}
}

View File

@ -0,0 +1,106 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
public protocol ChannelInboundHandler: _ChannelInboundHandler {
associatedtype InboundIn
associatedtype InboundUserEventIn = Never
associatedtype OutboundOut = Never
associatedtype InboundOut = Never
associatedtype OutboundUserEventOut = Never
associatedtype InboundUserEventOut = Never
func unwrapInboundIn(_ value: IOData) -> InboundIn
func tryUnwrapInboundIn(_ value: IOData) -> InboundIn?
func wrapInboundOut(_ value: InboundOut) -> IOData
func unwrapInboundUserEventIn(_ value: Any) -> InboundUserEventIn
func tryUnwrapInboundUserEventIn(_ value: Any) -> InboundUserEventIn?
func wrapInboundUserEventOut(_ value: InboundUserEventOut) -> Any
func wrapOutboundOut(_ value: OutboundOut) -> IOData
}
public extension ChannelInboundHandler {
func unwrapInboundIn(_ value: IOData) -> InboundIn {
return value.forceAs()
}
func tryUnwrapInboundIn(_ value: IOData) -> InboundIn? {
return value.forceAs()
}
func wrapInboundOut(_ value: InboundOut) -> IOData {
return IOData(value)
}
func unwrapInboundUserEventIn(_ value: Any) -> InboundUserEventIn {
return value as! InboundUserEventIn
}
func tryUnwrapInboundUserEventIn(_ value: Any) -> InboundUserEventIn? {
return value as? InboundUserEventIn
}
func wrapInboundUserEventOut(_ value: InboundUserEventOut) -> Any {
return value
}
func wrapOutboundOut(_ value: OutboundOut) -> IOData {
return IOData(value)
}
}
public protocol ChannelOutboundHandler: _ChannelOutboundHandler {
associatedtype OutboundIn
associatedtype OutboundUserEventIn = Never
associatedtype OutboundOut
associatedtype InboundOut = Never
associatedtype OutboundUserEventOut = Never
associatedtype InboundUserEventOut = Never
func unwrapOutboundIn(_ value: IOData) -> OutboundIn
func tryUnwrapOutboundIn(_ value: IOData) -> OutboundIn?
func wrapOutboundOut(_ value: OutboundOut) -> IOData
func unwrapOutboundUserEventIn(_ value: Any) -> OutboundUserEventIn
func tryUnwrapOutboundUserEventIn(_ value: Any) -> OutboundUserEventIn?
func wrapOutboundUserEventOut(_ value: OutboundUserEventOut) -> Any
}
public extension ChannelOutboundHandler {
func unwrapOutboundIn(_ value: IOData) -> OutboundIn {
return value.forceAs()
}
func tryUnwrapOutboundIn(_ value: IOData) -> OutboundIn? {
return value.tryAs()
}
func wrapOutboundOut(_ value: OutboundOut) -> IOData {
return IOData(value)
}
func unwrapOutboundUserEventIn(_ value: Any) -> OutboundUserEventIn {
return value as! OutboundUserEventIn
}
func tryUnwrapOutboundUserEventIn(_ value: Any) -> OutboundUserEventIn? {
return value as? OutboundUserEventIn
}
func wrapOutboundUserEventOut(_ value: OutboundUserEventOut) -> Any {
return value
}
}

View File

@ -16,6 +16,8 @@ import NIO
private final class EchoHandler: ChannelInboundHandler {
public typealias InboundIn = ByteBuffer
public typealias OutboundOut = ByteBuffer
public func channelRead(ctx: ChannelHandlerContext, data: IOData) {
// As we are not really interested getting notified on success or failure we just pass nil as promise to

View File

@ -16,33 +16,32 @@ import Foundation
import NIO
public final class HTTPResponseEncoder : ChannelOutboundHandler {
public typealias OutboundIn = HTTPResponse
public typealias OutboundOut = ByteBuffer
public init() { }
public init() {}
public func write(ctx: ChannelHandlerContext, data: IOData, promise: Promise<Void>?) {
if let response:HTTPResponseHead = data.tryAsOther() {
switch self.tryUnwrapOutboundIn(data) {
case .some(.head(let response)):
// TODO: Is 256 really a good value here ?
var buffer = ctx.channel!.allocator.buffer(capacity: 256)
response.version.write(buffer: &buffer)
response.status.write(buffer: &buffer)
response.headers.write(buffer: &buffer)
ctx.write(data: .byteBuffer(buffer), promise: promise)
} else if let content: HTTPBodyContent = data.tryAsOther() {
// TODO: Implement chunked encoding
switch content {
case .more(let buffer):
ctx.write(data: .byteBuffer(buffer), promise: promise)
case .last(let buffer):
if let buf = buffer {
ctx.write(data: .byteBuffer(buf), promise: promise)
} else if promise != nil {
// We only need to pass the promise further if the user is even interested in the result.
// Empty content so just write an empty buffer
ctx.write(data: .byteBuffer(ctx.channel!.allocator.buffer(capacity: 0)), promise: promise)
}
ctx.write(data: self.wrapOutboundOut(buffer), promise: promise)
case .some(.body(.more(let buffer))):
ctx.write(data: self.wrapOutboundOut(buffer), promise: promise)
case .some(.body(.last(let buffer))):
if let buf = buffer {
ctx.write(data: self.wrapOutboundOut(buf), promise: promise)
} else if promise != nil {
// We only need to pass the promise further if the user is even interested in the result.
// Empty content so just write an empty buffer
ctx.write(data: self.wrapOutboundOut(ctx.channel!.allocator.buffer(capacity: 0)), promise: promise)
}
} else {
case .none:
ctx.write(data: data, promise: promise)
}
}

View File

@ -16,7 +16,10 @@ import Foundation
import NIO
import CHTTPParser
public final class HTTPRequestDecoder : ByteToMessageDecoder {
public final class HTTPRequestDecoder : ChannelInboundHandler, ByteToMessageDecoder {
public typealias InboundIn = ByteBuffer
public typealias InboundOut = HTTPRequest
var parser: UnsafeMutablePointer<http_parser>?
var settings: UnsafeMutablePointer<http_parser_settings>?
public var cumulationBuffer: ByteBuffer?
@ -127,7 +130,7 @@ public final class HTTPRequestDecoder : ByteToMessageDecoder {
handler.state.dataAwaitingState = .body
ctx.fireChannelRead(data: .other(HTTPRequest.head(request)))
ctx.fireChannelRead(data: handler.wrapInboundOut(HTTPRequest.head(request)))
return 0
}
@ -138,7 +141,7 @@ public final class HTTPRequestDecoder : ByteToMessageDecoder {
// This will never return nil as we allocated the buffer with the correct size
handler.state.parserBuffer.write(int8Data: data!, len: len)
ctx.fireChannelRead(data: .other(HTTPRequest.body(HTTPBodyContent.more(buffer: handler.state.parserBuffer.readSlice(length: len)!))))
ctx.fireChannelRead(data: handler.wrapInboundOut(HTTPRequest.body(HTTPBodyContent.more(buffer: handler.state.parserBuffer.readSlice(length: len)!))))
return 0
}
@ -178,7 +181,7 @@ public final class HTTPRequestDecoder : ByteToMessageDecoder {
let ctx = evacuateContext(parser)
let handler = ctx.handler as! HTTPRequestDecoder
ctx.fireChannelRead(data: .other(HTTPRequest.body(.last(buffer: nil))))
ctx.fireChannelRead(data: handler.wrapInboundOut(HTTPRequest.body(.last(buffer: nil))))
handler.complete(state: handler.state.dataAwaitingState)
handler.state.dataAwaitingState = .messageBegin
return 0

View File

@ -65,6 +65,11 @@ public extension HTTPRequestHead {
}
}
public enum HTTPResponse {
case head(HTTPResponseHead)
case body(HTTPBodyContent)
}
public struct HTTPResponseHead {
public let status: HTTPResponseStatus
public let version: HTTPVersion

View File

@ -16,29 +16,33 @@ import NIO
import NIOHTTP1
private class HTTPHandler : ChannelInboundHandler {
public typealias InboundIn = HTTPRequest
public typealias OutboundOut = HTTPResponse
private var buffer: ByteBuffer? = nil
private var keepAlive = false
func channelRead(ctx: ChannelHandlerContext, data: IOData) throws {
if let reqPart = data.tryAsOther(type: HTTPRequest.self) {
func channelRead(ctx: ChannelHandlerContext, data: IOData) {
if let reqPart = self.tryUnwrapInboundIn(data) {
switch reqPart {
case .head(let request):
keepAlive = request.isKeepAlive
var response = HTTPResponseHead(version: request.version, status: HTTPResponseStatus.ok)
response.headers.add(name: "content-length", value: "12")
ctx.write(data: .other(response), promise: nil)
var responseHead = HTTPResponseHead(version: request.version, status: HTTPResponseStatus.ok)
responseHead.headers.add(name: "content-length", value: "12")
let response = HTTPResponse.head(responseHead)
ctx.write(data: self.wrapOutboundOut(response), promise: nil)
case .body(let content):
switch content {
case .more(_):
break
case .last:
let content = HTTPBodyContent.last(buffer: buffer!.slice())
let content = HTTPResponse.body(HTTPBodyContent.last(buffer: buffer!.slice()))
if keepAlive {
ctx.write(data: .other(content), promise: nil)
ctx.write(data: self.wrapOutboundOut(content), promise: nil)
} else {
ctx.write(data: .other(content)).whenComplete(callback: { _ in
ctx.write(data: self.wrapOutboundOut(content)).whenComplete(callback: { _ in
ctx.close(promise: nil)
})
}

View File

@ -17,15 +17,17 @@ import XCTest
@testable import NIOHTTP1
private final class TestChannelInboundHandler: ChannelInboundHandler {
public typealias InboundIn = HTTPRequest
public typealias InboundOut = HTTPRequest
private let fn: (IOData) -> IOData
private let fn: (HTTPRequest) -> HTTPRequest
init(_ fn: @escaping (IOData) -> IOData) {
init(_ fn: @escaping (HTTPRequest) -> HTTPRequest) {
self.fn = fn
}
public func channelRead(ctx: ChannelHandlerContext, data: IOData) {
ctx.fireChannelRead(data: self.fn(data))
ctx.fireChannelRead(data: self.wrapInboundOut(self.fn(self.unwrapInboundIn(data))))
}
}
@ -62,38 +64,34 @@ class HTTPTest: XCTestCase {
try channel.pipeline.add(handler: HTTPRequestDecoder()).wait()
var bodyData: Data? = nil
var allBodyDatas: [Data] = []
try channel.pipeline.add(handler: TestChannelInboundHandler { data in
if let reqPart = data.tryAsOther(type: NIOHTTP1.HTTPRequest.self) {
switch reqPart {
case .head(var req):
XCTAssertEqual((index * 2), step)
req.headers.remove(name: "Content-Length")
XCTAssertEqual(expecteds[index], req)
step += 1
case .body(let chunk):
switch chunk {
case .more(var buffer):
try channel.pipeline.add(handler: TestChannelInboundHandler { reqPart in
switch reqPart {
case .head(var req):
XCTAssertEqual((index * 2), step)
req.headers.remove(name: "Content-Length")
XCTAssertEqual(expecteds[index], req)
step += 1
case .body(let chunk):
switch chunk {
case .more(var buffer):
if bodyData == nil {
bodyData = buffer.readData(length: buffer.readableBytes)!
} else {
bodyData!.append(buffer.readData(length: buffer.readableBytes)!)
}
case .last(let buffer):
if var buffer = buffer {
if bodyData == nil {
bodyData = buffer.readData(length: buffer.readableBytes)!
} else {
bodyData!.append(buffer.readData(length: buffer.readableBytes)!)
}
case .last(let buffer):
if var buffer = buffer {
if bodyData == nil {
bodyData = buffer.readData(length: buffer.readableBytes)!
} else {
bodyData!.append(buffer.readData(length: buffer.readableBytes)!)
}
}
step += 1
XCTAssertEqual(((index + 1) * 2), step)
}
step += 1
XCTAssertEqual(((index + 1) * 2), step)
}
} else {
XCTFail("wrong type \(data)")
}
return data
return reqPart
}).wait()
for expected in expecteds {
@ -124,7 +122,7 @@ class HTTPTest: XCTestCase {
let bd1 = try sendAndCheckRequests(expecteds, body: body, sendStrategy: { (reqString, chan) in
var buf = chan.allocator.buffer(capacity: 1024)
buf.write(string: reqString)
chan.pipeline.fireChannelRead(data: .byteBuffer(buf))
chan.pipeline.fireChannelRead(data: IOData(buf))
})
/* send the bytes one by one */
@ -133,7 +131,7 @@ class HTTPTest: XCTestCase {
var buf = chan.allocator.buffer(capacity: 1024)
buf.write(string: "\(c)")
chan.pipeline.fireChannelRead(data: .byteBuffer(buf))
chan.pipeline.fireChannelRead(data: IOData(buf))
}
})

View File

@ -58,35 +58,37 @@ class ChannelPipelineTest: XCTestCase {
var buf = channel.allocator.buffer(capacity: 1024)
buf.write(string: "hello")
_ = try channel.pipeline.add(handler: TestChannelOutboundHandler({ data in
XCTAssertEqual(1, data.forceAsOther())
return .byteBuffer(buf)
_ = try channel.pipeline.add(handler: TestChannelOutboundHandler<Int, ByteBuffer>({ data in
XCTAssertEqual(1, data)
return buf
})).wait()
_ = try channel.pipeline.add(handler: TestChannelOutboundHandler({ data in
XCTAssertEqual("msg", data.forceAsOther())
return .other(1)
_ = try channel.pipeline.add(handler: TestChannelOutboundHandler<String, Int>({ data in
XCTAssertEqual("msg", data)
return 1
})).wait()
_ = channel.write(data: .other("msg"))
_ = channel.write(data: IOData("msg"))
_ = try channel.flush().wait()
XCTAssertEqual(buf, channel.readOutbound())
XCTAssertNil(channel.readOutbound())
}
private final class TestChannelOutboundHandler: ChannelOutboundHandler {
private final class TestChannelOutboundHandler<In, Out>: ChannelOutboundHandler {
typealias OutboundIn = In
typealias OutboundOut = Out
private let fn: (IOData) throws -> IOData
private let fn: (OutboundIn) throws -> OutboundOut
init(_ fn: @escaping (IOData) throws -> IOData) {
init(_ fn: @escaping (OutboundIn) throws -> OutboundOut) {
self.fn = fn
}
public func write(ctx: ChannelHandlerContext, data: IOData, promise: Promise<Void>?) {
do {
ctx.write(data: try fn(data), promise: promise)
ctx.write(data: self.wrapOutboundOut(try fn(self.unwrapOutboundIn(data))), promise: promise)
} catch let err {
promise!.fail(error: err)
}

View File

@ -18,6 +18,9 @@ import XCTest
public class ByteToMessageDecoderTest: XCTestCase {
private final class ByteToInt32Decoder : ByteToMessageDecoder {
typealias InboundIn = ByteBuffer
typealias InboundOut = Int32
var cumulationBuffer: ByteBuffer?
@ -25,7 +28,7 @@ public class ByteToMessageDecoderTest: XCTestCase {
guard buffer.readableBytes >= MemoryLayout<Int32>.size else {
return false
}
ctx.fireChannelRead(data: .other(buffer.readInteger()! as Int32))
ctx.fireChannelRead(data: self.wrapInboundOut(buffer.readInteger()!))
return true
}
}
@ -40,15 +43,15 @@ public class ByteToMessageDecoderTest: XCTestCase {
let writerIndex = buffer.writerIndex
buffer.moveWriterIndex(to: writerIndex - 1)
channel.pipeline.fireChannelRead(data: .byteBuffer(buffer))
channel.pipeline.fireChannelRead(data: IOData(buffer))
XCTAssertNil(channel.readInbound())
channel.pipeline.fireChannelRead(data: .byteBuffer(buffer.slice(at: writerIndex - 1, length: 1)!))
channel.pipeline.fireChannelRead(data: IOData(buffer.slice(at: writerIndex - 1, length: 1)!))
var buffer2 = channel.allocator.buffer(capacity: 32)
buffer2.write(integer: Int32(2))
buffer2.write(integer: Int32(3))
channel.pipeline.fireChannelRead(data: .byteBuffer(buffer2))
channel.pipeline.fireChannelRead(data: IOData(buffer2))
try channel.close().wait()
@ -62,13 +65,15 @@ public class ByteToMessageDecoderTest: XCTestCase {
public class MessageToByteEncoderTest: XCTestCase {
private final class Int32ToByteEncoder : MessageToByteEncoder {
public func encode(ctx: ChannelHandlerContext, data: IOData, out: inout ByteBuffer) throws {
typealias OutboundIn = Int32
typealias OutboundOut = ByteBuffer
public func encode(ctx: ChannelHandlerContext, data value: Int32, out: inout ByteBuffer) throws {
XCTAssertEqual(MemoryLayout<Int32>.size, out.writableBytes)
let value: Int32 = data.forceAsOther()
out.write(integer: value);
}
public func allocateOutBuffer(ctx: ChannelHandlerContext, data: IOData) throws -> ByteBuffer {
public func allocateOutBuffer(ctx: ChannelHandlerContext, data: Int32) throws -> ByteBuffer {
return ctx.channel!.allocator.buffer(capacity: MemoryLayout<Int32>.size)
}
}
@ -78,8 +83,7 @@ public class MessageToByteEncoderTest: XCTestCase {
_ = try channel.pipeline.add(handler: Int32ToByteEncoder()).wait()
_ = try channel.writeAndFlush(data: .other(Int32(5))).wait()
_ = try channel.writeAndFlush(data: IOData(Int32(5))).wait()
var buffer = channel.readOutbound() as ByteBuffer?
XCTAssertEqual(Int32(5), buffer?.readInteger())

View File

@ -22,7 +22,7 @@ class EmbeddedChannelTest: XCTestCase {
buf.write(string: "hello")
let f = channel.write(data: .byteBuffer(buf))
let f = channel.write(data: IOData(buf))
var ranBlock = false
f.whenSuccess { () -> Void in