Add support for HTTP Upgrade

This commit is contained in:
Cory Benfield 2017-10-12 17:24:00 -07:00
parent cf352ac0c3
commit b2cbe766f0
6 changed files with 753 additions and 25 deletions

View File

@ -44,10 +44,23 @@ public struct HTTPRequestHead: Equatable {
}
}
public enum HTTPRequestPart {
public enum HTTPRequestPart: Equatable {
case head(HTTPRequestHead)
case body(ByteBuffer)
case end(HTTPHeaders?)
public static func ==(lhs: HTTPRequestPart, rhs: HTTPRequestPart) -> Bool {
switch (lhs, rhs) {
case (.head(let h1), .head(let h2)):
return h1 == h2
case (.body(let b1), .body(let b2)):
return b1 == b2
case (.end(let h1), .end(let h2)):
return h1 == h2
default:
return false
}
}
}
public extension HTTPRequestHead {

View File

@ -0,0 +1,162 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import NIO
public enum HTTPUpgradeErrors: Error {
case invalidHTTPOrdering
}
public enum HTTPUpgradeEvents {
case upgradeComplete(toProtocol: String, upgradeRequest: HTTPRequestHead)
}
/// An object that implements protocol upgrader knows how to handle HTTP upgrade to
/// a protocol.
public protocol HTTPProtocolUpgrader {
/// The protocol this upgrader knows how to support.
var supportedProtocol: String { get }
/// All the header fields the protocol needs in the request to successfully upgrade. These header fields
/// will be provided to the handler when it is asked to handle the upgrade. They will also be validated
/// against the inbound request's Connection header field.
var requiredUpgradeHeaders: [String] { get }
/// Builds the upgrade response headers. Should return any headers that need to be supplied to the client
/// in the 101 Switching Protocols response. If upgrade cannot proceed for any reason, this function should
/// throw.
func buildUpgradeResponse(upgradeRequest: HTTPRequestHead, initialResponseHeaders: HTTPHeaders) throws -> HTTPHeaders
/// Called when the upgrade response has been flushed. At this time it is safe to mutate the channel pipeline
/// to add whatever channel handlers are required.
func upgrade(ctx: ChannelHandlerContext, upgradeRequest: HTTPRequestHead) -> Void
}
/// A server-side channel handler that receives HTTP requests and optionally performs a HTTP-upgrade.
/// Removes itself from the channel pipeline after the first inbound request on the connection, regardless of
/// whether the upgrade succeeded or not.
///
/// This handler behaves a bit differently from its Netty counterpart because it does not allow upgrade
/// on any request but the first on a connection. This is primarily to handle clients that pipeline: it's
/// sufficiently difficult to ensure that the upgrade happens at a safe time while dealing with pipelined
/// requests that we choose to punt on it entirely and not allow it. As it happens this is mostly fine:
/// the odds of someone needing to upgrade midway through the lifetime of a connection are very low.
public class HTTPServerUpgradeHandler: ChannelInboundHandler {
public typealias InboundIn = HTTPRequestPart
public typealias InboundOut = HTTPRequestPart
public typealias OutboundOut = HTTPResponsePart
private let upgraders: [String: HTTPProtocolUpgrader]
private let upgradeCompletionHandler: (ChannelHandlerContext) -> Void
/// Whether we've already seen the first request.
private var seenFirstRequest = false
public init(upgraders: [HTTPProtocolUpgrader], upgradeCompletionHandler: @escaping (ChannelHandlerContext) -> Void) {
var upgraderMap = [String: HTTPProtocolUpgrader]()
for upgrader in upgraders {
upgraderMap[upgrader.supportedProtocol] = upgrader
}
self.upgraders = upgraderMap
self.upgradeCompletionHandler = upgradeCompletionHandler
}
public func channelRead(ctx: ChannelHandlerContext, data: IOData) {
// We're trying to remove ourselves from the pipeline, so just pass this on.
guard !seenFirstRequest else {
ctx.fireChannelRead(data: data)
return
}
let requestPart = unwrapInboundIn(data)
seenFirstRequest = true
// We should only ever see a request header: by the time the body comes in we should
// be out of the pipeline. Anything else is an error.
guard case .head(let request) = requestPart else {
ctx.fireErrorCaught(error: HTTPUpgradeErrors.invalidHTTPOrdering)
notUpgrading(ctx: ctx, data: data)
return
}
// Ok, we have a HTTP request. Check if it's an upgrade. If it's not, we want to pass it on and remove ourselves
// from the channel pipeline.
let requestedProtocols = request.headers.getCanonicalForm("upgrade")
guard requestedProtocols.count > 0 else {
notUpgrading(ctx: ctx, data: data)
return
}
// Cool, this is an upgrade! Let's go.
if !handleUpgrade(ctx: ctx, request: request, requestedProtocols: requestedProtocols) {
notUpgrading(ctx: ctx, data: data)
}
}
/// The core of the upgrade handling logic.
private func handleUpgrade(ctx: ChannelHandlerContext, request: HTTPRequestHead, requestedProtocols: [String]) -> Bool {
let connectionHeader = Set(request.headers.getCanonicalForm("connection").map { $0.lowercased() })
let allHeaderNames = Set(request.headers.map { $0.name.lowercased() })
for proto in requestedProtocols {
guard let upgrader = upgraders[proto] else {
continue
}
let requiredHeaders = Set(upgrader.requiredUpgradeHeaders)
guard requiredHeaders.isSubset(of: allHeaderNames) && requiredHeaders.isSubset(of: connectionHeader) else {
continue
}
var responseHeaders = buildUpgradeHeaders(protocol: proto)
do {
responseHeaders = try upgrader.buildUpgradeResponse(upgradeRequest: request, initialResponseHeaders: responseHeaders)
} catch {
// We should fire this error so the user can log it, but keep going.
ctx.fireErrorCaught(error: error)
continue
}
sendUpgradeResponse(ctx: ctx, upgradeRequest: request, responseHeaders: responseHeaders).whenSuccess {
self.upgradeCompletionHandler(ctx)
upgrader.upgrade(ctx: ctx, upgradeRequest: request)
ctx.fireUserInboundEventTriggered(event: HTTPUpgradeEvents.upgradeComplete(toProtocol: proto, upgradeRequest: request))
let _ = ctx.pipeline!.remove(ctx: ctx)
}
return true
}
return false
}
/// Sends the 101 Switching Protocols response for the pipeline.
private func sendUpgradeResponse(ctx: ChannelHandlerContext, upgradeRequest: HTTPRequestHead, responseHeaders: HTTPHeaders) -> Future<Void> {
var response = HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .switchingProtocols)
response.headers = responseHeaders
return ctx.writeAndFlush(data: wrapOutboundOut(HTTPResponsePart.head(response)))
}
/// Called when we know we're not upgrading. Passes the data on and then removes this object from the pipeline.
private func notUpgrading(ctx: ChannelHandlerContext, data: IOData) {
ctx.fireChannelRead(data: data)
let _ = ctx.pipeline!.remove(ctx: ctx)
}
/// Builds the initial mandatory HTTP headers for HTTP ugprade responses.
private func buildUpgradeHeaders(`protocol`: String) -> HTTPHeaders {
return HTTPHeaders([("connection", "upgrade"), ("upgrade", `protocol`)])
}
}

View File

@ -42,6 +42,7 @@ import XCTest
testCase(HTTPHeadersTest.allTests),
testCase(HTTPServerClientTest.allTests),
testCase(HTTPTest.allTests),
testCase(HTTPUpgradeTestCase.allTests),
testCase(MarkedCircularBufferTests.allTests),
testCase(MessageToByteEncoderTest.allTests),
testCase(OpenSSLIntegrationTest.allTests),

View File

@ -44,6 +44,30 @@ extension Array where Array.Element == ByteBuffer {
}
}
internal class ArrayAccumulationHandler<T>: ChannelInboundHandler {
typealias InboundIn = T
private var receiveds: [T] = []
private var allDoneBlock: DispatchWorkItem! = nil
public init(completion: @escaping ([T]) -> Void) {
self.allDoneBlock = DispatchWorkItem { [unowned self] () -> Void in
completion(self.receiveds)
}
}
public func channelRead(ctx: ChannelHandlerContext, data: IOData) {
self.receiveds.append(self.unwrapInboundIn(data))
}
public func channelUnregistered(ctx: ChannelHandlerContext) {
self.allDoneBlock.perform()
}
public func syncWaitForCompletion() {
self.allDoneBlock.wait()
}
}
class HTTPServerClientTest : XCTestCase {
/* needs to be something reasonably large and odd so it has good odds producing incomplete writes even on the loopback interface */
@ -171,30 +195,6 @@ class HTTPServerClientTest : XCTestCase {
}
}
}
private class ArrayAccumulationHandler<T>: ChannelInboundHandler {
typealias InboundIn = T
private var receiveds: [T] = []
private var allDoneBlock: DispatchWorkItem! = nil
public init(completion: @escaping ([T]) -> Void) {
self.allDoneBlock = DispatchWorkItem { [unowned self] () -> Void in
completion(self.receiveds)
}
}
public func channelRead(ctx: ChannelHandlerContext, data: IOData) {
self.receiveds.append(self.unwrapInboundIn(data))
}
public func channelUnregistered(ctx: ChannelHandlerContext) {
self.allDoneBlock.perform()
}
public func syncWaitForCompletion() {
self.allDoneBlock.wait()
}
}
func testSimpleGet() throws {
let group = try MultiThreadedEventLoopGroup(numThreads: 1)

View File

@ -0,0 +1,42 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
///
/// HTTPUpgradeTests+XCTest.swift
///
import XCTest
///
/// NOTE: This file was generated by generate_linux_tests.rb
///
/// Do NOT edit this file directly as it will be regenerated automatically when needed.
///
extension HTTPUpgradeTestCase {
static var allTests : [(String, (HTTPUpgradeTestCase) -> () throws -> Void)] {
return [
("testUpgradeWithoutUpgrade", testUpgradeWithoutUpgrade),
("testUpgradeAfterInitialRequest", testUpgradeAfterInitialRequest),
("testUpgradeHandlerBarfsOnUnexpectedOrdering", testUpgradeHandlerBarfsOnUnexpectedOrdering),
("testSimpleUpgradeSucceeds", testSimpleUpgradeSucceeds),
("testUpgradeRequiresCorrectHeaders", testUpgradeRequiresCorrectHeaders),
("testUpgradeRequiresHeadersInConnection", testUpgradeRequiresHeadersInConnection),
("testUpgradeOnlyHandlesKnownProtocols", testUpgradeOnlyHandlesKnownProtocols),
("testUpgradeRespectsClientPreference", testUpgradeRespectsClientPreference),
("testUpgradeFiresUserEvent", testUpgradeFiresUserEvent),
("testUpgraderCanRejectUpgradeForPersonalReasons", testUpgraderCanRejectUpgradeForPersonalReasons),
]
}
}

View File

@ -0,0 +1,510 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import XCTest
@testable import NIO
@testable import NIOHTTP1
private extension ChannelPipeline {
func assertDoesNotContain(handler: ChannelHandler) throws {
do {
let _ = try self.context(handler: handler).wait()
XCTFail("Found handler")
} catch ChannelPipelineError.notFound {
// Nothing to see here
}
}
}
private func serverHTTPChannel(group: EventLoopGroup, handlers: [ChannelHandler]) -> Channel {
return try! ServerBootstrap(group: group)
.option(option: ChannelOptions.Socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.handler(childHandler: ChannelInitializer(initChannel: { channel in
channel.pipeline.add(handler: HTTPRequestDecoder()).then {
channel.pipeline.add(handler: HTTPResponseEncoder(allocator: channel.allocator)).then {
let futureResults = handlers.map { channel.pipeline.add(handler: $0) }
return Future<Void>.andAll(futureResults, eventLoop: channel.eventLoop)
}
}
})).bind(to: "127.0.0.1", on: 0).wait()
}
private func connectedClientChannel(group: EventLoopGroup, serverAddress: SocketAddress) -> Channel {
return try! ClientBootstrap(group: group)
.connect(to: serverAddress)
.wait()
}
private func setUpTest(withHandlers handlers: [ChannelHandler]) -> (EventLoopGroup, Channel, Channel) {
let group = try! MultiThreadedEventLoopGroup(numThreads: 1)
let serverChannel = serverHTTPChannel(group: group, handlers: handlers)
let clientChannel = connectedClientChannel(group: group, serverAddress: serverChannel.localAddress!)
return (group, serverChannel, clientChannel)
}
private func assertResponseIs(response: String, expectedResponseLine: String, expectedResponseHeaders: [String]) {
var lines = response.split(separator: "\r\n", omittingEmptySubsequences: false).map { String($0) }
// We never expect a response body here. This means we need the last two entries to be empty strings.
XCTAssertEqual("", lines.removeLast())
XCTAssertEqual("", lines.removeLast())
// Check the response line is correct.
let actualResponseLine = lines.removeFirst()
XCTAssertEqual(expectedResponseLine, actualResponseLine)
// For each header, find it in the actual response headers and remove it.
for expectedHeader in expectedResponseHeaders {
guard let index = lines.index(of: expectedHeader) else {
XCTFail("Could not find header \"\(expectedHeader)\"")
return
}
lines.remove(at: index)
}
// That should be all the headers.
XCTAssertEqual(lines.count, 0)
}
private class ExplodingUpgrader: HTTPProtocolUpgrader {
let supportedProtocol: String
let requiredUpgradeHeaders: [String]
private enum Explosion: Error {
case KABOOM
}
public init(forProtocol `protocol`: String, requiringHeaders: [String] = []) {
self.supportedProtocol = `protocol`
self.requiredUpgradeHeaders = requiringHeaders
}
public func buildUpgradeResponse(upgradeRequest: HTTPRequestHead, initialResponseHeaders: HTTPHeaders) throws -> HTTPHeaders {
XCTFail("buildUpgradeResponse called")
throw Explosion.KABOOM
}
public func upgrade(ctx: ChannelHandlerContext, upgradeRequest: HTTPRequestHead) {
XCTFail("upgrade called")
}
}
private class UpgraderSaysNo: HTTPProtocolUpgrader {
let supportedProtocol: String
let requiredUpgradeHeaders: [String] = []
public enum No: Error {
case no
}
public init(forProtocol `protocol`: String) {
self.supportedProtocol = `protocol`
}
public func buildUpgradeResponse(upgradeRequest: HTTPRequestHead, initialResponseHeaders: HTTPHeaders) throws -> HTTPHeaders {
throw No.no
}
public func upgrade(ctx: ChannelHandlerContext, upgradeRequest: HTTPRequestHead) {
XCTFail("upgrade called")
}
}
private class SuccessfulUpgrader: HTTPProtocolUpgrader {
let supportedProtocol: String
let requiredUpgradeHeaders: [String]
private let onUpgradeComplete: (HTTPRequestHead) -> ()
public init(forProtocol `protocol`: String, requiringHeaders headers: [String], onUpgradeComplete: @escaping (HTTPRequestHead) -> ()) {
self.supportedProtocol = `protocol`
self.requiredUpgradeHeaders = headers
self.onUpgradeComplete = onUpgradeComplete
}
public func buildUpgradeResponse(upgradeRequest: HTTPRequestHead, initialResponseHeaders: HTTPHeaders) throws -> HTTPHeaders {
var headers = initialResponseHeaders
headers.add(name: "X-Upgrade-Complete", value: "true")
return headers
}
public func upgrade(ctx: ChannelHandlerContext, upgradeRequest: HTTPRequestHead) {
self.onUpgradeComplete(upgradeRequest)
}
}
private class UserEventSaver<EventType>: ChannelInboundHandler {
public typealias InboundIn = Any
public typealias InboundUserEventIn = EventType
public var events: [EventType] = []
public func userInboundEventTriggered(ctx: ChannelHandlerContext, event: Any) {
events.append(unwrapInboundUserEventIn(event))
ctx.fireUserInboundEventTriggered(event: event)
}
}
private class ErrorSaver: ChannelInboundHandler {
public typealias InboundIn = Any
public typealias InboundOut = Any
public var errors: [Error] = []
public func errorCaught(ctx: ChannelHandlerContext, error: Error) {
errors.append(error)
ctx.fireErrorCaught(error: error)
}
}
private extension ByteBuffer {
static func forString(_ string: String) -> ByteBuffer {
var buf = ByteBufferAllocator().buffer(capacity: string.utf8.count)
buf.write(string: string)
return buf
}
}
class HTTPUpgradeTestCase: XCTestCase {
func testUpgradeWithoutUpgrade() throws {
let handler = HTTPServerUpgradeHandler(upgraders: [ExplodingUpgrader(forProtocol: "myproto")]) { _ in
XCTFail("upgrade completed")
}
let (group, server, client) = setUpTest(withHandlers: [handler])
defer {
try! client.close().wait()
try! server.close().wait()
try! group.syncShutdownGracefully()
}
let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\n\r\n"
try! client.writeAndFlush(data: IOData(ByteBuffer.forString(request))).wait()
// At this time the channel pipeline should not contain our handler: it should have removed itself.
try client.pipeline.assertDoesNotContain(handler: handler)
}
func testUpgradeAfterInitialRequest() throws {
let handler = HTTPServerUpgradeHandler(upgraders: [ExplodingUpgrader(forProtocol: "myproto")]) { _ in
XCTFail("upgrade completed")
}
let (group, server, client) = setUpTest(withHandlers: [handler])
defer {
try! client.close().wait()
try! server.close().wait()
try! group.syncShutdownGracefully()
}
// This request fires a subsequent upgrade in immediately. It should also be ignored.
let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\n\r\nOPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nConnection: upgrade\r\n\r\n"
try! client.writeAndFlush(data: IOData(ByteBuffer.forString(request))).wait()
// At this time the channel pipeline should not contain our handler: it should have removed itself.
try client.pipeline.assertDoesNotContain(handler: handler)
}
func testUpgradeHandlerBarfsOnUnexpectedOrdering() throws {
let channel = EmbeddedChannel()
defer {
XCTAssertFalse(try! channel.finish())
}
let handler = HTTPServerUpgradeHandler(upgraders: [ExplodingUpgrader(forProtocol: "myproto")]) { _ in
XCTFail("upgrade completed")
}
let data = HTTPRequestPart.body(ByteBuffer.forString("hello"))
try! channel.pipeline.add(handler: handler).wait()
do {
try channel.writeInbound(data: data)
XCTFail("Writing of bad data did not error")
} catch HTTPUpgradeErrors.invalidHTTPOrdering {
// Nothing to see here.
}
// The handler removed itself from the pipeline and passed the unexpected
// data on.
try channel.pipeline.assertDoesNotContain(handler: handler)
let receivedData: HTTPRequestPart = channel.readInbound()!
XCTAssertEqual(data, receivedData)
}
func testSimpleUpgradeSucceeds() throws {
var upgradeRequest: HTTPRequestHead? = nil
var upgradeHandlerCbFired = false
var upgraderCbFired = false
let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { req in
upgradeRequest = req
XCTAssert(upgradeHandlerCbFired)
upgraderCbFired = true
}
let handler = HTTPServerUpgradeHandler(upgraders: [upgrader]) { ctx in
// This is called before the upgrader gets called.
XCTAssertNil(upgradeRequest)
upgradeHandlerCbFired = true
// We're closing the connection now.
ctx.close(promise: nil)
}
let (group, server, client) = setUpTest(withHandlers: [handler])
defer {
try! group.syncShutdownGracefully()
}
let completePromise: Promise<Void> = group.next().newPromise()
let clientHandler = ArrayAccumulationHandler<ByteBuffer> { buffers in
let resultString = buffers.map { $0.string(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "")
assertResponseIs(response: resultString,
expectedResponseLine: "HTTP/1.1 101 Switching Protocols",
expectedResponseHeaders: ["x-upgrade-complete: true", "upgrade: myproto", "connection: upgrade"])
completePromise.succeed(result: ())
}
try! client.pipeline.add(handler: clientHandler).wait()
// This request is safe to upgrade.
let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n"
try! client.writeAndFlush(data: IOData(ByteBuffer.forString(request))).wait()
// Let the machinery do its thing.
try! completePromise.futureResult.wait()
// At this time we want to assert that everything got called. Their own callbacks assert
// that the ordering was correct.
XCTAssert(upgradeHandlerCbFired)
XCTAssert(upgraderCbFired)
// We also want to confirm that the upgrade handler is no longer in the pipeline.
try client.pipeline.assertDoesNotContain(handler: handler)
}
func testUpgradeRequiresCorrectHeaders() throws {
let handler = HTTPServerUpgradeHandler(upgraders: [ExplodingUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"])]) { _ in
XCTFail("upgrade completed")
}
let (group, server, client) = setUpTest(withHandlers: [handler])
defer {
try! client.close().wait()
try! server.close().wait()
try! group.syncShutdownGracefully()
}
let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nConnection: upgrade\r\nUpgrade: myproto\r\n\r\n"
try! client.writeAndFlush(data: IOData(ByteBuffer.forString(request))).wait()
// At this time the channel pipeline should not contain our handler: it should have removed itself.
try client.pipeline.assertDoesNotContain(handler: handler)
}
func testUpgradeRequiresHeadersInConnection() throws {
let handler = HTTPServerUpgradeHandler(upgraders: [ExplodingUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"])]) { _ in
XCTFail("upgrade completed")
}
let (group, server, client) = setUpTest(withHandlers: [handler])
defer {
try! client.close().wait()
try! server.close().wait()
try! group.syncShutdownGracefully()
}
let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nConnection: upgrade\r\nUpgrade: myproto\r\nKafkaesque: true\r\n\r\n"
try! client.writeAndFlush(data: IOData(ByteBuffer.forString(request))).wait()
// At this time the channel pipeline should not contain our handler: it should have removed itself.
try client.pipeline.assertDoesNotContain(handler: handler)
}
func testUpgradeOnlyHandlesKnownProtocols() throws {
let handler = HTTPServerUpgradeHandler(upgraders: [ExplodingUpgrader(forProtocol: "myproto")]) { _ in
XCTFail("upgrade completed")
}
let (group, server, client) = setUpTest(withHandlers: [handler])
defer {
try! client.close().wait()
try! server.close().wait()
try! group.syncShutdownGracefully()
}
let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nConnection: upgrade\r\nUpgrade: something-else\r\n\r\n"
try! client.writeAndFlush(data: IOData(ByteBuffer.forString(request))).wait()
// At this time the channel pipeline should not contain our handler: it should have removed itself.
try client.pipeline.assertDoesNotContain(handler: handler)
}
func testUpgradeRespectsClientPreference() throws {
var upgradeRequest: HTTPRequestHead? = nil
var upgradeHandlerCbFired = false
var upgraderCbFired = false
let explodingUpgrader = ExplodingUpgrader(forProtocol: "exploder")
let successfulUpgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { req in
upgradeRequest = req
XCTAssert(upgradeHandlerCbFired)
upgraderCbFired = true
}
let handler = HTTPServerUpgradeHandler(upgraders: [explodingUpgrader, successfulUpgrader]) { ctx in
// This is called before the upgrader gets called.
XCTAssertNil(upgradeRequest)
upgradeHandlerCbFired = true
// We're closing the connection now.
ctx.close(promise: nil)
}
let (group, server, client) = setUpTest(withHandlers: [handler])
defer {
try! group.syncShutdownGracefully()
}
let completePromise: Promise<Void> = group.next().newPromise()
let clientHandler = ArrayAccumulationHandler<ByteBuffer> { buffers in
let resultString = buffers.map { $0.string(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "")
assertResponseIs(response: resultString,
expectedResponseLine: "HTTP/1.1 101 Switching Protocols",
expectedResponseHeaders: ["x-upgrade-complete: true", "upgrade: myproto", "connection: upgrade"])
completePromise.succeed(result: ())
}
try! client.pipeline.add(handler: clientHandler).wait()
// This request is safe to upgrade.
let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto, exploder\r\nKafkaesque: yup\r\nConnection: upgrade, kafkaesque\r\n\r\n"
try! client.writeAndFlush(data: IOData(ByteBuffer.forString(request))).wait()
// Let the machinery do its thing.
try! completePromise.futureResult.wait()
// At this time we want to assert that everything got called. Their own callbacks assert
// that the ordering was correct.
XCTAssert(upgradeHandlerCbFired)
XCTAssert(upgraderCbFired)
// We also want to confirm that the upgrade handler is no longer in the pipeline.
try client.pipeline.assertDoesNotContain(handler: handler)
}
func testUpgradeFiresUserEvent() throws {
// The user event is fired last, so we don't see it until both other callbacks
// have fired.
let eventSaver = UserEventSaver<HTTPUpgradeEvents>()
let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: []) { req in
XCTAssertEqual(eventSaver.events.count, 0)
}
let handler = HTTPServerUpgradeHandler(upgraders: [upgrader]) { ctx in
XCTAssertEqual(eventSaver.events.count, 0)
ctx.close(promise: nil)
}
let (group, server, client) = setUpTest(withHandlers: [handler, eventSaver])
defer {
try! group.syncShutdownGracefully()
}
let completePromise: Promise<Void> = group.next().newPromise()
let clientHandler = ArrayAccumulationHandler<ByteBuffer> { buffers in
let resultString = buffers.map { $0.string(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "")
assertResponseIs(response: resultString,
expectedResponseLine: "HTTP/1.1 101 Switching Protocols",
expectedResponseHeaders: ["x-upgrade-complete: true", "upgrade: myproto", "connection: upgrade"])
completePromise.succeed(result: ())
}
try! client.pipeline.add(handler: clientHandler).wait()
// This request is safe to upgrade.
let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade,kafkaesque\r\n\r\n"
try! client.writeAndFlush(data: IOData(ByteBuffer.forString(request))).wait()
// Let the machinery do its thing.
try! completePromise.futureResult.wait()
// At this time we should have received one user event. We schedule this onto the
// event loop to guarantee thread safety.
try! group.next().scheduleTask(in: .nanoseconds(0)) {
XCTAssertEqual(eventSaver.events.count, 1)
if case .upgradeComplete(let proto, let req) = eventSaver.events[0] {
XCTAssertEqual(proto, "myproto")
XCTAssertEqual(req.method, .OPTIONS)
XCTAssertEqual(req.uri, "*")
XCTAssertEqual(req.version, HTTPVersion(major: 1, minor: 1))
} else {
XCTFail("Unexpected event: \(eventSaver.events[0])")
}
}.futureResult.wait()
// We also want to confirm that the upgrade handler is no longer in the pipeline.
try client.pipeline.assertDoesNotContain(handler: handler)
}
func testUpgraderCanRejectUpgradeForPersonalReasons() throws {
var upgradeRequest: HTTPRequestHead? = nil
var upgradeHandlerCbFired = false
var upgraderCbFired = false
let explodingUpgrader = UpgraderSaysNo(forProtocol: "noproto")
let successfulUpgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { req in
upgradeRequest = req
XCTAssert(upgradeHandlerCbFired)
upgraderCbFired = true
}
let handler = HTTPServerUpgradeHandler(upgraders: [explodingUpgrader, successfulUpgrader]) { ctx in
// This is called before the upgrader gets called.
XCTAssertNil(upgradeRequest)
upgradeHandlerCbFired = true
// We're closing the connection now.
ctx.close(promise: nil)
}
let errorCatcher = ErrorSaver()
let (group, server, client) = setUpTest(withHandlers: [handler, errorCatcher])
defer {
try! group.syncShutdownGracefully()
}
let completePromise: Promise<Void> = group.next().newPromise()
let clientHandler = ArrayAccumulationHandler<ByteBuffer> { buffers in
let resultString = buffers.map { $0.string(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "")
assertResponseIs(response: resultString,
expectedResponseLine: "HTTP/1.1 101 Switching Protocols",
expectedResponseHeaders: ["x-upgrade-complete: true", "upgrade: myproto", "connection: upgrade"])
completePromise.succeed(result: ())
}
try! client.pipeline.add(handler: clientHandler).wait()
// This request is safe to upgrade.
let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: noproto,myproto\r\nKafkaesque: yup\r\nConnection: upgrade, kafkaesque\r\n\r\n"
try! client.writeAndFlush(data: IOData(ByteBuffer.forString(request))).wait()
// Let the machinery do its thing.
try! completePromise.futureResult.wait()
// At this time we want to assert that everything got called. Their own callbacks assert
// that the ordering was correct.
XCTAssert(upgradeHandlerCbFired)
XCTAssert(upgraderCbFired)
// We also want to confirm that the upgrade handler is no longer in the pipeline.
try client.pipeline.assertDoesNotContain(handler: handler)
// And we want to confirm we saved the error.
XCTAssertEqual(errorCatcher.errors.count, 1)
switch(errorCatcher.errors[0]) {
case UpgraderSaysNo.No.no:
break
default:
XCTFail("Unexpected error: \(errorCatcher.errors[0])")
}
}
}