From ec30e5cc5a90a7da59a284971127a822a0aa9663 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20Wei=C3=9F?= Date: Tue, 24 Apr 2018 08:18:21 +0100 Subject: [PATCH] fix recursive channelReads in pipelining handler (#348) Motivation: The pipelining handler made the assumption that `channelRead` is never called recursively. That's mostly true but there is at least one situation where that's not true: - pipelining handler seen a response .end and delivers a .head (which is done in `channelRead`) - a handler further down stream writes and flushes some response data - the flushes fail which leads to us draining the receive buffer - if the receive buffer contained more requests, the pipelining handler's `channelRead` is called again (recursively) The net result of that was that the new request parts from the receive buffer would now jump the queue and go through the channel pipeline next, before other already buffered messages. Modifications: made the pipelining handler buffer if a `channelRead` comes in from the pipeline and there is already at least one message buffered. Result: the ordering of the incoming messages should now be respected which is very important... --- .../NIOHTTP1/HTTPServerPipelineHandler.swift | 37 ++++++- ...HTTPServerPipelineHandlerTest+XCTest.swift | 1 + .../HTTPServerPipelineHandlerTest.swift | 100 ++++++++++++++++++ 3 files changed, 133 insertions(+), 5 deletions(-) diff --git a/Sources/NIOHTTP1/HTTPServerPipelineHandler.swift b/Sources/NIOHTTP1/HTTPServerPipelineHandler.swift index cf5ac800..3feafde6 100644 --- a/Sources/NIOHTTP1/HTTPServerPipelineHandler.swift +++ b/Sources/NIOHTTP1/HTTPServerPipelineHandler.swift @@ -63,7 +63,12 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler { public typealias OutboundIn = HTTPServerResponsePart public typealias OutboundOut = HTTPServerResponsePart - public init() { } + public init() { + debugOnly { + self.nextExpectedInboundMessage = .head + self.nextExpectedOutboundMessage = .head + } + } /// The state of the HTTP connection. private enum ConnectionState { @@ -140,15 +145,37 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler { case bodyOrEnd } - private var nextExpectedOutboundMessage = NextExpectedMessageType.head + // always `nil` in release builds, never `nil` in debug builds + private var nextExpectedInboundMessage: NextExpectedMessageType? + // always `nil` in release builds, never `nil` in debug builds + private var nextExpectedOutboundMessage: NextExpectedMessageType? public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) { - if case .responseEndPending = self.state { + if self.eventBuffer.count != 0 || self.state == .responseEndPending { self.eventBuffer.append(.channelRead(data)) return + } else { + self.deliverOneMessage(ctx: ctx, data: data) + } + } + + private func deliverOneMessage(ctx: ChannelHandlerContext, data: NIOAny) { + let msg = self.unwrapInboundIn(data) + + debugOnly { + switch msg { + case .head: + assert(self.nextExpectedInboundMessage == .head) + self.nextExpectedInboundMessage = .bodyOrEnd + case .body: + assert(self.nextExpectedInboundMessage == .bodyOrEnd) + case .end: + assert(self.nextExpectedInboundMessage == .bodyOrEnd) + self.nextExpectedInboundMessage = .head + } } - if case .end = self.unwrapInboundIn(data) { + if case .end = msg { // New request is complete. We don't want any more data from now on. self.state.requestEndReceived() } @@ -241,7 +268,7 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler { switch event { case .channelRead(let read): - self.channelRead(ctx: ctx, data: read) + self.deliverOneMessage(ctx: ctx, data: read) deliveredRead = true case .error(let error): ctx.fireErrorCaught(error) diff --git a/Tests/NIOHTTP1Tests/HTTPServerPipelineHandlerTest+XCTest.swift b/Tests/NIOHTTP1Tests/HTTPServerPipelineHandlerTest+XCTest.swift index 69074dd1..2bbb1177 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerPipelineHandlerTest+XCTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerPipelineHandlerTest+XCTest.swift @@ -34,6 +34,7 @@ extension HTTPServerPipelineHandlerTest { ("testPipelineHandlerWillDeliverHalfCloseEarly", testPipelineHandlerWillDeliverHalfCloseEarly), ("testAReadIsNotIssuedWhenUnbufferingAHalfCloseAfterRequestComplete", testAReadIsNotIssuedWhenUnbufferingAHalfCloseAfterRequestComplete), ("testHalfCloseWhileWaitingForResponseIsPassedAlongIfNothingElseBuffered", testHalfCloseWhileWaitingForResponseIsPassedAlongIfNothingElseBuffered), + ("testRecursiveChannelReadInvocationsDoNotCauseIssues", testRecursiveChannelReadInvocationsDoNotCauseIssues), ] } } diff --git a/Tests/NIOHTTP1Tests/HTTPServerPipelineHandlerTest.swift b/Tests/NIOHTTP1Tests/HTTPServerPipelineHandlerTest.swift index e001f714..ce3475d1 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerPipelineHandlerTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerPipelineHandlerTest.swift @@ -45,6 +45,7 @@ private final class ReadRecorder: ChannelInboundHandler { func channelRead(ctx: ChannelHandlerContext, data: NIOAny) { self.reads.append(.channelRead(self.unwrapInboundIn(data))) + ctx.fireChannelRead(data) } func userInboundEventTriggered(ctx: ChannelHandlerContext, event: Any) { @@ -356,4 +357,103 @@ class HTTPServerPipelineHandlerTest: XCTestCase { .channelRead(HTTPServerRequestPart.end(nil)), .halfClose]) } + + func testRecursiveChannelReadInvocationsDoNotCauseIssues() throws { + func makeRequestHead(uri: String) -> HTTPRequestHead { + var requestHead = HTTPRequestHead(version: .init(major: 1, minor: 1), method: .GET, uri: uri) + requestHead.headers.add(name: "Host", value: "example.com") + return requestHead + } + + class VerifyOrderHandler: ChannelInboundHandler { + typealias InboundIn = HTTPServerRequestPart + typealias OutboundOut = HTTPServerResponsePart + + enum NextExpectedMessageType { + case head + case end + } + enum State { + case req1HeadExpected + case req1EndExpected + case req2HeadExpected + case req2EndExpected + case req3HeadExpected + case req3EndExpected + case reqBoomHeadExpected + case reqBoomEndExpected + + case done + } + var state: State = .req1HeadExpected + + func channelRead(ctx: ChannelHandlerContext, data: NIOAny) { + let req = self.unwrapInboundIn(data) + switch req { + case .head(let head): + // except for "req_1", we always send the .end straight away + var sendEnd = true + switch head.uri { + case "/req_1": + XCTAssertEqual(.req1HeadExpected, self.state) + self.state = .req1EndExpected + // for req_1, we don't send the end straight away to force the others to be buffered + sendEnd = false + case "/req_2": + XCTAssertEqual(.req2HeadExpected, self.state) + self.state = .req2EndExpected + case "/req_3": + XCTAssertEqual(.req3HeadExpected, self.state) + self.state = .req3EndExpected + case "/req_boom": + XCTAssertEqual(.reqBoomHeadExpected, self.state) + self.state = .reqBoomEndExpected + default: + XCTFail("didn't expect \(head)") + } + ctx.write(self.wrapOutboundOut(.head(HTTPResponseHead(version: .init(major: 1, minor: 1), status: .ok))), promise: nil) + if sendEnd { + ctx.write(self.wrapOutboundOut(.end(nil)), promise: nil) + } + ctx.flush() + case .end: + switch self.state { + case .req1EndExpected: + self.state = .req2HeadExpected + case .req2EndExpected: + self.state = .req3HeadExpected + + // this will cause `channelRead` to be recursively called and we need to make sure everything then still works + try! (ctx.channel as! EmbeddedChannel).writeInbound(HTTPServerRequestPart.head(HTTPRequestHead(version: .init(major: 1, minor: 1), method: .GET, uri: "/req_boom"))) + try! (ctx.channel as! EmbeddedChannel).writeInbound(HTTPServerRequestPart.end(nil)) + case .req3EndExpected: + self.state = .reqBoomHeadExpected + case .reqBoomEndExpected: + self.state = .done + default: + XCTFail("illegal state for end: \(self.state)") + } + case .body: + XCTFail("we don't send any bodies") + } + } + } + + let handler = VerifyOrderHandler() + XCTAssertNoThrow(try channel.pipeline.add(handler: handler).wait()) + + for f in 1...3 { + XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(makeRequestHead(uri: "/req_\(f)")))) + XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) + } + + // now we should have delivered the first request, with the second and third buffered because req_1's .end + // doesn't get sent by the handler (instead we'll do that below) + XCTAssertEqual(.req2HeadExpected, handler.state) + + // finish 1st request, that will send through the 2nd one which will then write the 'req_boom' request + XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.end(nil)).wait()) + + XCTAssertEqual(.done, handler.state) + } }