HTTPServerUpgradeHandler: Tolerate futures from other ELs (#1134)
This commit is contained in:
parent
fbb977d5c2
commit
8aa84453d4
|
@ -169,12 +169,15 @@ public final class HTTPServerUpgradeHandler: ChannelInboundHandler, RemovableCha
|
|||
// We'll attempt to upgrade. This may take a while, so while we're waiting more data can come in.
|
||||
self.upgradeState = .awaitingUpgrader
|
||||
|
||||
self.handleUpgrade(context: context, request: request, requestedProtocols: requestedProtocols).whenSuccess { callback in
|
||||
if let callback = callback {
|
||||
self.gotUpgrader(upgrader: callback)
|
||||
} else {
|
||||
self.notUpgrading(context: context, data: requestPart)
|
||||
}
|
||||
self.handleUpgrade(context: context, request: request, requestedProtocols: requestedProtocols)
|
||||
.hop(to: context.eventLoop) // the user might return a future from another EventLoop.
|
||||
.whenSuccess { callback in
|
||||
assert(context.eventLoop.inEventLoop)
|
||||
if let callback = callback {
|
||||
self.gotUpgrader(upgrader: callback)
|
||||
} else {
|
||||
self.notUpgrading(context: context, data: requestPart)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -47,6 +47,7 @@ extension HTTPServerUpgradeTestCase {
|
|||
("testUpgradeWithUpgradePayloadInlineWithRequestWorks", testUpgradeWithUpgradePayloadInlineWithRequestWorks),
|
||||
("testDeliversBytesWhenRemovedDuringPartialUpgrade", testDeliversBytesWhenRemovedDuringPartialUpgrade),
|
||||
("testDeliversBytesWhenReentrantlyCalledInChannelReadCompleteOnRemoval", testDeliversBytesWhenReentrantlyCalledInChannelReadCompleteOnRemoval),
|
||||
("testWeTolerateUpgradeFuturesFromWrongEventLoops", testWeTolerateUpgradeFuturesFromWrongEventLoops),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
|
@ -225,17 +225,31 @@ private class SuccessfulUpgrader: HTTPServerProtocolUpgrader {
|
|||
let supportedProtocol: String
|
||||
let requiredUpgradeHeaders: [String]
|
||||
private let onUpgradeComplete: (HTTPRequestHead) -> ()
|
||||
private let buildUpgradeResponseFuture: (Channel, HTTPHeaders) -> EventLoopFuture<HTTPHeaders>
|
||||
|
||||
public init(forProtocol `protocol`: String, requiringHeaders headers: [String], onUpgradeComplete: @escaping (HTTPRequestHead) -> ()) {
|
||||
public init(forProtocol `protocol`: String,
|
||||
requiringHeaders headers: [String],
|
||||
buildUpgradeResponseFuture: @escaping (Channel, HTTPHeaders) -> EventLoopFuture<HTTPHeaders>,
|
||||
onUpgradeComplete: @escaping (HTTPRequestHead) -> ()) {
|
||||
self.supportedProtocol = `protocol`
|
||||
self.requiredUpgradeHeaders = headers
|
||||
self.onUpgradeComplete = onUpgradeComplete
|
||||
self.buildUpgradeResponseFuture = buildUpgradeResponseFuture
|
||||
}
|
||||
|
||||
public convenience init(forProtocol `protocol`: String,
|
||||
requiringHeaders headers: [String],
|
||||
onUpgradeComplete: @escaping (HTTPRequestHead) -> ()) {
|
||||
self.init(forProtocol: `protocol`,
|
||||
requiringHeaders: headers,
|
||||
buildUpgradeResponseFuture: { $0.eventLoop.makeSucceededFuture($1) },
|
||||
onUpgradeComplete: onUpgradeComplete)
|
||||
}
|
||||
|
||||
public func buildUpgradeResponse(channel: Channel, upgradeRequest: HTTPRequestHead, initialResponseHeaders: HTTPHeaders) -> EventLoopFuture<HTTPHeaders> {
|
||||
var headers = initialResponseHeaders
|
||||
headers.add(name: "X-Upgrade-Complete", value: "true")
|
||||
return channel.eventLoop.makeSucceededFuture(headers)
|
||||
return self.buildUpgradeResponseFuture(channel, headers)
|
||||
}
|
||||
|
||||
public func upgrade(context: ChannelHandlerContext, upgradeRequest: HTTPRequestHead) -> EventLoopFuture<Void> {
|
||||
|
@ -1349,4 +1363,63 @@ class HTTPServerUpgradeTestCase: XCTestCase {
|
|||
XCTAssertNoThrow(try channel.pipeline.assertDoesNotContainUpgrader())
|
||||
XCTAssertNoThrow(try XCTAssertNil(channel.readOutbound(as: ByteBuffer.self)))
|
||||
}
|
||||
|
||||
func testWeTolerateUpgradeFuturesFromWrongEventLoops() throws {
|
||||
var upgradeRequest: HTTPRequestHead? = nil
|
||||
var upgradeHandlerCbFired = false
|
||||
var upgraderCbFired = false
|
||||
let otherELG = MultiThreadedEventLoopGroup(numberOfThreads: 1)
|
||||
defer {
|
||||
XCTAssertNoThrow(try otherELG.syncShutdownGracefully())
|
||||
}
|
||||
|
||||
let upgrader = SuccessfulUpgrader(forProtocol: "myproto",
|
||||
requiringHeaders: ["kafkaesque"],
|
||||
buildUpgradeResponseFuture: {
|
||||
// this is the wrong EL
|
||||
otherELG.next().makeSucceededFuture($1)
|
||||
}) { req in
|
||||
upgradeRequest = req
|
||||
XCTAssert(upgradeHandlerCbFired)
|
||||
upgraderCbFired = true
|
||||
}
|
||||
|
||||
let (group, server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader],
|
||||
extraHandlers: []) { (context) in
|
||||
// This is called before the upgrader gets called.
|
||||
XCTAssertNil(upgradeRequest)
|
||||
upgradeHandlerCbFired = true
|
||||
|
||||
// We're closing the connection now.
|
||||
context.close(promise: nil)
|
||||
}
|
||||
defer {
|
||||
XCTAssertNoThrow(try group.syncShutdownGracefully())
|
||||
}
|
||||
|
||||
let completePromise = group.next().makePromise(of: Void.self)
|
||||
let clientHandler = ArrayAccumulationHandler<ByteBuffer> { buffers in
|
||||
let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "")
|
||||
assertResponseIs(response: resultString,
|
||||
expectedResponseLine: "HTTP/1.1 101 Switching Protocols",
|
||||
expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"])
|
||||
completePromise.succeed(())
|
||||
}
|
||||
XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait())
|
||||
|
||||
// This request is safe to upgrade.
|
||||
let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n"
|
||||
XCTAssertNoThrow(try client.writeAndFlush(NIOAny(ByteBuffer.forString(request))).wait())
|
||||
|
||||
// Let the machinery do its thing.
|
||||
XCTAssertNoThrow(try completePromise.futureResult.wait())
|
||||
|
||||
// At this time we want to assert that everything got called. Their own callbacks assert
|
||||
// that the ordering was correct.
|
||||
XCTAssert(upgradeHandlerCbFired)
|
||||
XCTAssert(upgraderCbFired)
|
||||
|
||||
// We also want to confirm that the upgrade handler is no longer in the pipeline.
|
||||
try connectedServer.pipeline.assertDoesNotContainUpgrader()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue