// swift-nio/Sources/NIOHTTP1/HTTPServerUpgradeHandler.swift
//
// 348 lines
// 18 KiB
// Swift

//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2021 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import NIOCore
/// Errors that may be raised by `HTTPServerUpgradeHandler` while it processes the first request on a
/// connection. They are delivered through the pipeline via `fireErrorCaught`.
public enum HTTPServerUpgradeErrors: Error {
/// The first request part that reached the upgrade handler was not a request head. The handler fires
/// this error, forwards the offending part and gives up on upgrading.
case invalidHTTPOrdering
}
/// User events that may be fired by `HTTPServerUpgradeHandler`.
public enum HTTPServerUpgradeEvents: Sendable {
/// Fired as an inbound user event when HTTP upgrade has completed successfully, immediately before
/// `HTTPServerUpgradeHandler` removes itself from the `ChannelPipeline`.
///
/// - `toProtocol`: the protocol that was upgraded to, spelled the way it appeared in the request's
///   `Upgrade` header field.
/// - `upgradeRequest`: the head of the request that asked for the upgrade.
case upgradeComplete(toProtocol: String, upgradeRequest: HTTPRequestHead)
}
/// An object that implements `HTTPServerProtocolUpgrader` knows how to handle HTTP upgrade to
/// a protocol on a server-side channel.
///
/// Upgraders are handed to `HTTPServerUpgradeHandler`, which picks one based on the `Upgrade` header
/// field of the first request on the connection.
public protocol HTTPServerProtocolUpgrader {
/// The protocol this upgrader knows how to support.
///
/// `HTTPServerUpgradeHandler` compares this value case-insensitively with the protocols listed in the
/// request's `Upgrade` header field.
var supportedProtocol: String { get }
/// All the header fields the protocol needs in the request to successfully upgrade. These header fields
/// will be provided to the handler when it is asked to handle the upgrade.
///
/// `HTTPServerUpgradeHandler` only selects this upgrader when every name listed here (compared
/// case-insensitively) is present as a header field of the request and is also listed in the request's
/// `Connection` header field.
var requiredUpgradeHeaders: [String] { get }
/// Builds the upgrade response headers. Should return any headers that need to be supplied to the client
/// in the 101 Switching Protocols response. If upgrade cannot proceed for any reason, this function should
/// fail the future.
///
/// When the future fails, `HTTPServerUpgradeHandler` fires the error through the pipeline and moves on
/// to the next protocol requested by the client, if there is one.
///
/// - Parameters:
///   - channel: The channel on which the upgrade request arrived.
///   - upgradeRequest: The head of the request that asked for the upgrade.
///   - initialResponseHeaders: The mandatory `connection` and `upgrade` response headers, already
///     filled in by the upgrade handler.
/// - Returns: A future holding the complete set of headers to send in the 101 response.
func buildUpgradeResponse(channel: Channel, upgradeRequest: HTTPRequestHead, initialResponseHeaders: HTTPHeaders) -> EventLoopFuture<HTTPHeaders>
/// Called when the upgrade response has been flushed. At this time it is safe to mutate the channel pipeline
/// to add whatever channel handlers are required. Until the returned `EventLoopFuture` succeeds, all received
/// data will be buffered.
///
/// If the returned future fails, `HTTPServerUpgradeHandler` fires the error through the pipeline and
/// stays in the pipeline.
///
/// - Parameters:
///   - context: The `ChannelHandlerContext` of the upgrade handler.
///   - upgradeRequest: The head of the request that asked for the upgrade.
/// - Returns: A future that completes when the upgrader has finished setting up the pipeline.
func upgrade(context: ChannelHandlerContext, upgradeRequest: HTTPRequestHead) -> EventLoopFuture<Void>
}
/// A server-side channel handler that inspects the first HTTP request on a connection and, if that request
/// asks for a protocol upgrade that one of the configured upgraders accepts, carries the upgrade out.
///
/// The handler takes itself out of the channel pipeline once it has dealt with the first inbound request:
/// either after a successful upgrade or as soon as it has decided that no upgrade will happen. The one
/// exception is an upgrade that fails after it has started: the error is fired through the pipeline and
/// the handler stays where it is.
///
/// This handler behaves a bit differently from its Netty counterpart because it does not allow upgrade
/// on any request but the first on a connection. This is primarily to handle clients that pipeline:
/// ensuring that the upgrade happens at a safe time while dealing with pipelined requests is difficult
/// enough that it is not supported at all. In practice this is mostly fine, as the odds of someone
/// needing to upgrade midway through the lifetime of a connection are very low.
public final class HTTPServerUpgradeHandler: ChannelInboundHandler, RemovableChannelHandler {
    public typealias InboundIn = HTTPServerRequestPart
    public typealias InboundOut = HTTPServerRequestPart
    public typealias OutboundOut = HTTPServerResponsePart

    /// The configured upgraders, keyed by their lowercased `supportedProtocol`.
    private let upgraders: [String: HTTPServerProtocolUpgrader]

    /// Invoked during an upgrade after the HTTP encoder has been removed, right before the chosen
    /// upgrader's `upgrade` method is called.
    private let upgradeCompletionHandler: (ChannelHandlerContext) -> Void

    /// The response encoder. It is removed from the pipeline once the upgrade response has been sent.
    private let httpEncoder: HTTPResponseEncoder

    /// Other HTTP-specific handlers. They are removed from the pipeline before the upgrade response is sent.
    private let extraHTTPHandlers: [RemovableChannelHandler]

    /// Becomes `true` when the `.end` of the first request arrives while an upgrade is pending. From then
    /// on every inbound message is buffered instead of being inspected.
    private var seenFirstRequest = false

    /// Where this handler currently is in the upgrade process.
    private var upgradeState: UpgradeState = .idle

    /// Inbound messages held back once `seenFirstRequest` is set. They are replayed when the handler is
    /// removed from the pipeline.
    private var receivedMessages: CircularBuffer<NIOAny> = CircularBuffer()

    /// Create a `HTTPServerUpgradeHandler`.
    ///
    /// - Parameters:
    ///   - upgraders: All `HTTPServerProtocolUpgrader` objects that this pipeline will be able to use to
    ///     handle HTTP upgrade. If two upgraders support the same protocol (ignoring case), the later one wins.
    ///   - httpEncoder: The `HTTPResponseEncoder` encoding responses from this handler and which will be
    ///     removed from the pipeline once the upgrade response is sent. This is used to ensure that the
    ///     pipeline will be in a clean state after upgrade.
    ///   - extraHTTPHandlers: Any other handlers that are directly related to handling HTTP. At the very
    ///     least this should include the `HTTPDecoder`, but should also include any other handler that
    ///     cannot tolerate receiving non-HTTP data.
    ///   - upgradeCompletionHandler: A block that will be fired when HTTP upgrade is complete.
    public init(
        upgraders: [HTTPServerProtocolUpgrader],
        httpEncoder: HTTPResponseEncoder,
        extraHTTPHandlers: [RemovableChannelHandler],
        upgradeCompletionHandler: @escaping (ChannelHandlerContext) -> Void
    ) {
        self.upgraders = upgraders.reduce(into: [String: HTTPServerProtocolUpgrader]()) { byProtocol, candidate in
            byProtocol[candidate.supportedProtocol.lowercased()] = candidate
        }
        self.httpEncoder = httpEncoder
        self.extraHTTPHandlers = extraHTTPHandlers
        self.upgradeCompletionHandler = upgradeCompletionHandler
    }

    /// Routes an inbound request part according to the current upgrade state.
    ///
    /// Once the first request has ended, everything is buffered. Before that, the request head starts the
    /// upgrade decision, and the request's `.end` either marks the request as finished or kicks off an
    /// upgrade that is already prepared.
    public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        if self.seenFirstRequest {
            // The first request is over and the upgrade is still in flight: hold on to this data.
            self.receivedMessages.append(data)
            return
        }

        let part = self.unwrapInboundIn(data)

        switch self.upgradeState {
        case .idle:
            self.firstRequestHeadReceived(context: context, requestPart: part)

        case .awaitingUpgrader:
            // Only the end of the first request matters here; other parts are not forwarded.
            guard case .end = part else { return }
            // Swallow the .end: from now on we buffer everything that follows.
            self.seenFirstRequest = true

        case .upgraderReady(let startUpgrade):
            guard case .end = part else { return }
            // The first request is complete and an upgrader is waiting, so the upgrade can begin.
            self.seenFirstRequest = true
            startUpgrade()

        case .upgradeFailed:
            // We were called re-entrantly while delivering the request head. Just pass the data along.
            context.fireChannelRead(data)

        case .upgradeComplete:
            preconditionFailure("Upgrade has completed but we have not seen a whole request and still got re-entrantly called.")

        case .upgrading:
            preconditionFailure("We think we saw .end before and began upgrading, but somehow we have not set seenFirstRequest")
        }
    }

    /// Replays all buffered inbound messages and then leaves the pipeline.
    public func removeHandler(context: ChannelHandlerContext, removalToken: ChannelHandlerContext.RemovalToken) {
        // The outer loop exists because firing channelReadComplete may re-entrantly cause more data to be
        // buffered; we keep draining until nothing is left.
        while !self.receivedMessages.isEmpty {
            repeat {
                let pending = self.receivedMessages.removeFirst()
                context.fireChannelRead(pending)
            } while !self.receivedMessages.isEmpty

            context.fireChannelReadComplete()
        }

        context.leavePipeline(removalToken: removalToken)
    }

    /// Handles the very first request part and decides whether an upgrade should be attempted.
    private func firstRequestHeadReceived(context: ChannelHandlerContext, requestPart: HTTPServerRequestPart) {
        // The upgrade decision is made from the request head alone. If we are not upgrading we leave the
        // pipeline before the body arrives, so the first thing we see must be a head; anything else is an error.
        guard case .head(let head) = requestPart else {
            context.fireErrorCaught(HTTPServerUpgradeErrors.invalidHTTPOrdering)
            self.notUpgrading(context: context, data: requestPart)
            return
        }

        // No Upgrade header means this is an ordinary request: pass it on and get out of the way.
        let offeredProtocols = head.headers[canonicalForm: "upgrade"].map { String($0) }
        if offeredProtocols.isEmpty {
            self.notUpgrading(context: context, data: requestPart)
            return
        }

        // This is an upgrade request. Finding an upgrader may take a while, and more data can arrive in
        // the meantime.
        self.upgradeState = .awaitingUpgrader

        let upgraderSearch = self.handleUpgrade(context: context, request: head, requestedProtocols: offeredProtocols)
            .hop(to: context.eventLoop)  // the user might return a future from another EventLoop.
        upgraderSearch.whenSuccess { performUpgrade in
            context.eventLoop.assertInEventLoop()
            guard let performUpgrade = performUpgrade else {
                self.notUpgrading(context: context, data: requestPart)
                return
            }
            self.gotUpgrader(upgrader: performUpgrade)
        }
    }

    /// The core of the upgrade handling logic.
    ///
    /// - returns: An `EventLoopFuture` that will contain a callback to invoke if upgrade is requested, or nil if upgrade has failed. Never returns a failed future.
    private func handleUpgrade(context: ChannelHandlerContext, request: HTTPRequestHead, requestedProtocols: [String]) -> EventLoopFuture<(() -> Void)?> {
        let headerNames = Set(request.headers.map { $0.name.lowercased() })
        let connectionTokens = Set(request.headers[canonicalForm: "connection"].map { $0.lowercased() })

        // Walk the requested protocols one by one until an upgrader accepts. While this chain of futures
        // is running, inbound data has to be buffered.
        return self.handleUpgradeForProtocol(
            context: context,
            protocolIterator: requestedProtocols.makeIterator(),
            request: request,
            allHeaderNames: headerNames,
            connectionHeader: connectionTokens
        )
    }

    /// Attempt to upgrade a single protocol.
    ///
    /// Takes the next protocol from `protocolIterator` and recurses with the remainder if there is no
    /// matching upgrader, the request lacks the headers the upgrader requires, or the upgrader fails to
    /// build its response.
    private func handleUpgradeForProtocol(context: ChannelHandlerContext, protocolIterator: Array<String>.Iterator, request: HTTPRequestHead, allHeaderNames: Set<String>, connectionHeader: Set<String>) -> EventLoopFuture<(() -> Void)?> {
        var cursor = protocolIterator
        guard let proto = cursor.next() else {
            // Every requested protocol has been tried without success.
            return context.eventLoop.makeSucceededFuture(nil)
        }
        // The protocols that are still untried, handed to the next invocation if this one does not pan out.
        let untriedProtocols = cursor

        guard let upgrader = self.upgraders[proto.lowercased()] else {
            return self.handleUpgradeForProtocol(context: context, protocolIterator: untriedProtocols, request: request, allHeaderNames: allHeaderNames, connectionHeader: connectionHeader)
        }

        let requiredHeaders = Set(upgrader.requiredUpgradeHeaders.map { $0.lowercased() })
        let requestQualifies = requiredHeaders.isSubset(of: allHeaderNames) && requiredHeaders.isSubset(of: connectionHeader)
        if !requestQualifies {
            return self.handleUpgradeForProtocol(context: context, protocolIterator: untriedProtocols, request: request, allHeaderNames: allHeaderNames, connectionHeader: connectionHeader)
        }

        let initialHeaders = self.buildUpgradeHeaders(protocol: proto)
        return upgrader.buildUpgradeResponse(
            channel: context.channel,
            upgradeRequest: request,
            initialResponseHeaders: initialHeaders
        ).map { finalResponseHeaders -> (() -> Void)? in
            return {
                self.performUpgrade(context: context, upgrader: upgrader, proto: proto, request: request, responseHeaders: finalResponseHeaders)
            }
        }.flatMapError { error in
            // This upgrader declined. Report the error down the pipeline and try the next protocol.
            context.fireErrorCaught(error)
            return self.handleUpgradeForProtocol(context: context, protocolIterator: untriedProtocols, request: request, allHeaderNames: allHeaderNames, connectionHeader: connectionHeader)
        }
    }

    /// Carries out the upgrade once an upgrader has produced its response headers.
    ///
    /// The steps run in this order: remove the extra HTTP handlers so that they parse no more data, send
    /// the upgrade response, remove the HTTP encoder, call the completion handler and finally the
    /// upgrader itself. All of this may well complete synchronously, but it is written so that it does
    /// not have to. On success the completion event is fired and the handler removes itself, which also
    /// replays the data buffered in the meantime.
    private func performUpgrade(context: ChannelHandlerContext, upgrader: HTTPServerProtocolUpgrader, proto: String, request: HTTPRequestHead, responseHeaders: HTTPHeaders) {
        self.upgradeState = .upgrading

        self.removeExtraHandlers(context: context).flatMap {
            self.sendUpgradeResponse(context: context, upgradeRequest: request, responseHeaders: responseHeaders)
        }.flatMap {
            context.pipeline.removeHandler(self.httpEncoder)
        }.flatMap { () -> EventLoopFuture<Void> in
            self.upgradeCompletionHandler(context)
            return upgrader.upgrade(context: context, upgradeRequest: request)
        }.whenComplete { outcome in
            switch outcome {
            case .failure(let error):
                // We deliberately stay in the '.upgrading' state.
                context.fireErrorCaught(error)

            case .success:
                context.fireUserInboundEventTriggered(HTTPServerUpgradeEvents.upgradeComplete(toProtocol: proto, upgradeRequest: request))
                self.upgradeState = .upgradeComplete
                // Removing ourselves delivers any buffered data.
                context.pipeline.removeHandler(context: context, promise: nil)
            }
        }
    }

    /// Records that an upgrader is ready and starts the upgrade straight away if the first request has
    /// already ended.
    private func gotUpgrader(upgrader: @escaping (() -> Void)) {
        switch self.upgradeState {
        case .idle, .upgraderReady, .upgrading, .upgradeFailed, .upgradeComplete:
            preconditionFailure("Unexpected upgrader state: \(self.upgradeState)")

        case .awaitingUpgrader:
            self.upgradeState = .upgraderReady(upgrader)
            // If .end has not arrived yet, its arrival will trigger the upgrade instead.
            guard self.seenFirstRequest else { return }
            upgrader()
        }
    }

    /// Sends the 101 Switching Protocols response for the pipeline.
    private func sendUpgradeResponse(context: ChannelHandlerContext, upgradeRequest: HTTPRequestHead, responseHeaders: HTTPHeaders) -> EventLoopFuture<Void> {
        let switchingProtocols = HTTPResponseHead(version: .http1_1, status: .switchingProtocols, headers: responseHeaders)
        return context.writeAndFlush(self.wrapOutboundOut(HTTPServerResponsePart.head(switchingProtocols)))
    }

    /// Called when we know we're not upgrading. Passes the data on and then removes this object from the pipeline.
    private func notUpgrading(context: ChannelHandlerContext, data: HTTPServerRequestPart) {
        self.upgradeState = .upgradeFailed

        let firstRequestAlreadyEnded = self.seenFirstRequest
        if !firstRequestAlreadyEnded {
            // The first request's .end has not been seen, so nothing can have been buffered yet.
            assert(self.receivedMessages.isEmpty)
        }

        context.fireChannelRead(self.wrapInboundOut(data))
        if firstRequestAlreadyEnded {
            // We swallowed the first request's .end earlier, so deliver one now. Parts buffered after it
            // (including any that arrive re-entrantly while we do this) are replayed when we are removed.
            context.fireChannelRead(self.wrapInboundOut(.end(nil)))
        }
        context.fireChannelReadComplete()

        // Everything we hold directly has been delivered. Remove ourselves, which should happen synchronously.
        context.pipeline.removeHandler(context: context, promise: nil)
    }

    /// Builds the initial mandatory HTTP headers for HTTP upgrade responses.
    private func buildUpgradeHeaders(`protocol`: String) -> HTTPHeaders {
        let mandatoryHeaders = [("connection", "upgrade"), ("upgrade", `protocol`)]
        return HTTPHeaders(mandatoryHeaders)
    }

    /// Removes any extra HTTP-related handlers from the channel pipeline.
    ///
    /// - returns: A future that succeeds once every handler in `extraHTTPHandlers` has been removed.
    private func removeExtraHandlers(context: ChannelHandlerContext) -> EventLoopFuture<Void> {
        if self.extraHTTPHandlers.isEmpty {
            return context.eventLoop.makeSucceededFuture(())
        }

        let removals: [EventLoopFuture<Void>] = self.extraHTTPHandlers.map { context.pipeline.removeHandler($0) }
        return EventLoopFuture<Void>.andAllSucceed(removals, on: context.eventLoop)
    }
}
// NOTE(review): this conformance is unchecked. The handler's mutable state (`seenFirstRequest`,
// `upgradeState`, `receivedMessages`) is not protected by any lock, so it is presumably only safe when
// all access happens on the channel's event loop. Confirm before sharing an instance across threads.
extension HTTPServerUpgradeHandler: @unchecked Sendable {}
extension HTTPServerUpgradeHandler {
/// The state of the upgrade handler.
private enum UpgradeState {
/// Awaiting some activity. This is the initial state; no request part has been handled yet.
case idle
/// The request head has been received and it asked for an upgrade. We're currently running the
/// future chain awaiting an upgrader.
case awaitingUpgrader
/// We have an upgrader, which means we can begin upgrade. The associated closure starts the upgrade
/// when invoked; it is called once the end of the first request has been received.
case upgraderReady(() -> Void)
/// The upgrade is in process. The handler also stays in this state if the upgrade fails part-way.
case upgrading
/// The upgrade has failed, and we are being removed from the pipeline. Also entered when the first
/// request did not ask for an upgrade or no upgrader accepted it.
case upgradeFailed
/// The upgrade has succeeded, and we are being removed from the pipeline.
case upgradeComplete
}
}