diff --git a/Sources/NIOHTTP1/HTTPDecoder.swift b/Sources/NIOHTTP1/HTTPDecoder.swift index 21096119..9cbd9f90 100644 --- a/Sources/NIOHTTP1/HTTPDecoder.swift +++ b/Sources/NIOHTTP1/HTTPDecoder.swift @@ -293,16 +293,24 @@ private class BetterHTTPParser { // does not meet the requirement of RFC 7230. This is an outstanding http_parser issue: // https://github.com/nodejs/http-parser/issues/251. As a result, we check for these status // codes and override http_parser's handling as well. - guard let method = self.requestHeads.popFirst()?.method else { + guard !self.requestHeads.isEmpty else { self.richerError = NIOHTTPDecoderError.unsolicitedResponse return .error(HPE_UNKNOWN) } - - if method == .HEAD || method == .CONNECT { - skipBody = true - } else if statusCode / 100 == 1 || // 1XX codes - statusCode == 204 || statusCode == 304 { + + if 100 <= statusCode && statusCode < 200 && statusCode != 101 { + // if the response status is in the range of 100..<200 but not 101 we don't want to + // pop the request method. The actual request head is expected with the next HTTP + // head. skipBody = true + } else { + let method = self.requestHeads.removeFirst().method + if method == .HEAD || method == .CONNECT { + skipBody = true + } else if statusCode / 100 == 1 || // 1XX codes + statusCode == 204 || statusCode == 304 { + skipBody = true + } } } @@ -473,15 +481,28 @@ public final class HTTPDecoder: ByteToMessageDecoder, HTTPDecoderDelega // the actual state private let parser: BetterHTTPParser private let leftOverBytesStrategy: RemoveAfterUpgradeStrategy + private let informationalResponseStrategy: NIOInformationalResponseStrategy private let kind: HTTPDecoderKind private var stopParsing = false // set on upgrade or HTTP version error + private var lastResponseHeaderWasInformational = false /// Creates a new instance of `HTTPDecoder`. /// /// - parameters: /// - leftOverBytesStrategy: The strategy to use when removing the decoder from the pipeline and an upgrade was, /// detected. Note that this does not affect what happens on EOF. - public init(leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes) { + public convenience init(leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes) { + self.init(leftOverBytesStrategy: leftOverBytesStrategy, informationalResponseStrategy: .drop) + } + + /// Creates a new instance of `HTTPDecoder`. + /// + /// - parameters: + /// - leftOverBytesStrategy: The strategy to use when removing the decoder from the pipeline and an upgrade was, + /// detected. Note that this does not affect what happens on EOF. + /// - informationalResponseStrategy: Should informational responses (like http status 100) be forwarded or dropped. + /// Default is `.drop`. This property is only respected when decoding responses. + public init(leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, informationalResponseStrategy: NIOInformationalResponseStrategy = .drop) { self.headers.reserveCapacity(16) if In.self == HTTPServerRequestPart.self { self.kind = .request @@ -492,6 +513,7 @@ public final class HTTPDecoder: ByteToMessageDecoder, HTTPDecoderDelega } self.parser = BetterHTTPParser(kind: kind) self.leftOverBytesStrategy = leftOverBytesStrategy + self.informationalResponseStrategy = informationalResponseStrategy } func didReceiveBody(_ bytes: UnsafeRawBufferPointer) { @@ -545,7 +567,7 @@ public final class HTTPDecoder: ByteToMessageDecoder, HTTPDecoderDelega method: http_method, statusCode: Int, keepAliveState: KeepAliveState) -> Bool { - let message: NIOAny + let message: NIOAny? guard versionMajor == 1 else { self.stopParsing = true @@ -561,16 +583,39 @@ public final class HTTPDecoder: ByteToMessageDecoder, HTTPDecoderDelega headers: HTTPHeaders(self.headers, keepAliveState: keepAliveState)) message = NIOAny(HTTPServerRequestPart.head(reqHead)) + + case .response where (100..<200).contains(statusCode) && statusCode != 101: + self.lastResponseHeaderWasInformational = true + switch self.informationalResponseStrategy.base { + case .forward: + let resHeadPart = HTTPClientResponsePart.head( + versionMajor: versionMajor, + versionMinor: versionMinor, + statusCode: statusCode, + keepAliveState: keepAliveState, + headers: self.headers + ) + message = NIOAny(resHeadPart) + case .drop: + message = nil + } + case .response: - let resHead: HTTPResponseHead = HTTPResponseHead(version: .init(major: versionMajor, minor: versionMinor), - status: .init(statusCode: statusCode), - headers: HTTPHeaders(self.headers, - keepAliveState: keepAliveState)) - message = NIOAny(HTTPClientResponsePart.head(resHead)) + self.lastResponseHeaderWasInformational = false + let resHeadPart = HTTPClientResponsePart.head( + versionMajor: versionMajor, + versionMinor: versionMinor, + statusCode: statusCode, + keepAliveState: keepAliveState, + headers: self.headers + ) + message = NIOAny(resHeadPart) } self.url = nil self.headers.removeAll(keepingCapacity: true) - self.context!.fireChannelRead(message) + if let message = message { + self.context!.fireChannelRead(message) + } self.isUpgrade = isUpgrade return true } @@ -582,7 +627,9 @@ public final class HTTPDecoder: ByteToMessageDecoder, HTTPDecoderDelega case .request: self.context!.fireChannelRead(NIOAny(HTTPServerRequestPart.end(trailers.map(HTTPHeaders.init)))) case .response: - self.context!.fireChannelRead(NIOAny(HTTPClientResponsePart.end(trailers.map(HTTPHeaders.init)))) + if !self.lastResponseHeaderWasInformational { + self.context!.fireChannelRead(NIOAny(HTTPClientResponsePart.end(trailers.map(HTTPHeaders.init)))) + } } self.stopParsing = self.isUpgrade! self.isUpgrade = nil @@ -660,6 +707,25 @@ public enum RemoveAfterUpgradeStrategy { case dropBytes } +/// Strategy to use when a HTTPDecoder receives an informational HTTP response (1xx except 101) +public struct NIOInformationalResponseStrategy: Hashable { + enum Base { + case drop + case forward + } + + var base: Base + private init(_ base: Base) { + self.base = base + } + + /// Drop the informational response and only forward the "real" response + public static let drop = Self(.drop) + /// Forward the informational response and then forward the "real" response. This will result in + /// multiple `head` before an `end` is emitted. + public static let forward = Self(.forward) +} + extension HTTPParserError { /// Create a `HTTPParserError` from an error returned by `http_parser`. /// @@ -828,3 +894,19 @@ extension NIOHTTPDecoderError: CustomDebugStringConvertible { return String(describing: self.baseError) } } + +extension HTTPClientResponsePart { + fileprivate static func head( + versionMajor: Int, + versionMinor: Int, + statusCode: Int, + keepAliveState: KeepAliveState, + headers: [(String, String)] + ) -> HTTPClientResponsePart { + HTTPClientResponsePart.head(HTTPResponseHead( + version: .init(major: versionMajor, minor: versionMinor), + status: .init(statusCode: statusCode), + headers: HTTPHeaders(headers, keepAliveState: keepAliveState) + )) + } +} diff --git a/Tests/NIOHTTP1Tests/HTTPDecoderLengthTest.swift b/Tests/NIOHTTP1Tests/HTTPDecoderLengthTest.swift index b0efc36e..efab49dc 100644 --- a/Tests/NIOHTTP1Tests/HTTPDecoderLengthTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPDecoderLengthTest.swift @@ -184,7 +184,8 @@ class HTTPDecoderLengthTest: XCTestCase { responseStatus: HTTPResponseStatus, responseFramingField: FramingField) throws { XCTAssertNoThrow(try channel.pipeline.addHandler(HTTPRequestEncoder()).wait()) - XCTAssertNoThrow(try channel.pipeline.addHandler(ByteToMessageHandler(HTTPResponseDecoder())).wait()) + let decoder = HTTPResponseDecoder(leftOverBytesStrategy: .dropBytes, informationalResponseStrategy: .forward) + XCTAssertNoThrow(try channel.pipeline.addHandler(ByteToMessageHandler(decoder)).wait()) let handler = MessageEndHandler() XCTAssertNoThrow(try channel.pipeline.addHandler(handler).wait()) @@ -214,9 +215,18 @@ class HTTPDecoderLengthTest: XCTestCase { // We should have a response, no body, and immediately see EOF. XCTAssert(handler.seenHead) - XCTAssertFalse(handler.seenBody) - XCTAssert(handler.seenEnd) - + switch responseStatus.code { + case 100, 102..<200: + // If an informational response header is tested, we expect another "real" header to + // follow. For this reason, we don't expect an `.end` here. + XCTAssertFalse(handler.seenBody) + XCTAssertFalse(handler.seenEnd) + + default: + XCTAssertFalse(handler.seenBody) + XCTAssert(handler.seenEnd) + } + XCTAssertTrue(try channel.finish().isClean) } diff --git a/Tests/NIOHTTP1Tests/HTTPDecoderTest+XCTest.swift b/Tests/NIOHTTP1Tests/HTTPDecoderTest+XCTest.swift index 1ac3d544..a137abb0 100644 --- a/Tests/NIOHTTP1Tests/HTTPDecoderTest+XCTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPDecoderTest+XCTest.swift @@ -54,6 +54,8 @@ extension HTTPDecoderTest { ("testAppropriateErrorWhenReceivingUnsolicitedResponse", testAppropriateErrorWhenReceivingUnsolicitedResponse), ("testAppropriateErrorWhenReceivingUnsolicitedResponseDoesNotRecover", testAppropriateErrorWhenReceivingUnsolicitedResponseDoesNotRecover), ("testOneRequestTwoResponses", testOneRequestTwoResponses), + ("testForwardContinueThenResponse", testForwardContinueThenResponse), + ("testDropContinueThanForwardResponse", testDropContinueThanForwardResponse), ("testRefusesRequestSmugglingAttempt", testRefusesRequestSmugglingAttempt), ("testTrimsTrailingOWS", testTrimsTrailingOWS), ("testMassiveChunkDoesNotBufferAndGivesUsHoweverMuchIsAvailable", testMassiveChunkDoesNotBufferAndGivesUsHoweverMuchIsAvailable), diff --git a/Tests/NIOHTTP1Tests/HTTPDecoderTest.swift b/Tests/NIOHTTP1Tests/HTTPDecoderTest.swift index 386200f8..1c634b4d 100644 --- a/Tests/NIOHTTP1Tests/HTTPDecoderTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPDecoderTest.swift @@ -792,6 +792,60 @@ class HTTPDecoderTest: XCTestCase { XCTAssertEqual(["channelReadComplete", "write", "flush", "channelRead", "errorCaught"], eventCounter.allTriggeredEvents()) XCTAssertNoThrow(XCTAssertTrue(try channel.finish().isClean)) } + + func testForwardContinueThenResponse() { + let eventCounter = EventCounterHandler() + let decoder = HTTPResponseDecoder(leftOverBytesStrategy: .dropBytes, informationalResponseStrategy: .forward) + let responseDecoder = ByteToMessageHandler(decoder) + let channel = EmbeddedChannel(handler: responseDecoder) + XCTAssertNoThrow(try channel.pipeline.addHandler(eventCounter).wait()) + + let requestHead: HTTPClientRequestPart = .head(.init(version: .http1_1, method: .POST, uri: "/")) + XCTAssertNoThrow(try channel.writeOutbound(requestHead)) + var buffer = channel.allocator.buffer(capacity: 128) + buffer.writeString("HTTP/1.1 100 continue\r\n\r\nHTTP/1.1 200 ok\r\ncontent-length: 0\r\n\r\n") + XCTAssertNoThrow(try channel.writeInbound(buffer)) + + XCTAssertEqual(try channel.readInbound(as: HTTPClientResponsePart.self), .head(.init(version: .http1_1, status: .continue))) + XCTAssertEqual(try channel.readInbound(as: HTTPClientResponsePart.self), .head(.init(version: .http1_1, status: .ok, headers: ["content-length": "0"]))) + XCTAssertEqual(.end(nil), try channel.readInbound(as: HTTPClientResponsePart.self)) + XCTAssertNil(try channel.readInbound(as: HTTPClientResponsePart.self)) + XCTAssertNotNil(try channel.readOutbound()) + + XCTAssertEqual(1, eventCounter.writeCalls) + XCTAssertEqual(1, eventCounter.flushCalls) + XCTAssertEqual(3, eventCounter.channelReadCalls) // .head, .head & .end + XCTAssertEqual(1, eventCounter.channelReadCompleteCalls) + XCTAssertEqual(["channelReadComplete", "channelRead", "write", "flush"], eventCounter.allTriggeredEvents()) + XCTAssertNoThrow(XCTAssertTrue(try channel.finish().isClean)) + } + + func testDropContinueThanForwardResponse() { + let eventCounter = EventCounterHandler() + let decoder = HTTPResponseDecoder(leftOverBytesStrategy: .dropBytes, informationalResponseStrategy: .drop) + let responseDecoder = ByteToMessageHandler(decoder) + let channel = EmbeddedChannel(handler: responseDecoder) + XCTAssertNoThrow(try channel.pipeline.addHandler(eventCounter).wait()) + + let requestHead: HTTPClientRequestPart = .head(.init(version: .http1_1, method: .POST, uri: "/")) + XCTAssertNoThrow(try channel.writeOutbound(requestHead)) + var buffer = channel.allocator.buffer(capacity: 128) + buffer.writeString("HTTP/1.1 100 continue\r\n\r\nHTTP/1.1 200 ok\r\ncontent-length: 0\r\n\r\n") + XCTAssertNoThrow(try channel.writeInbound(buffer)) + + XCTAssertEqual(try channel.readInbound(as: HTTPClientResponsePart.self), .head(.init(version: .http1_1, status: .ok, headers: ["content-length": "0"]))) + XCTAssertEqual(.end(nil), try channel.readInbound(as: HTTPClientResponsePart.self)) + XCTAssertNil(try channel.readInbound(as: HTTPClientResponsePart.self)) + XCTAssertNotNil(try channel.readOutbound()) + + XCTAssertEqual(1, eventCounter.writeCalls) + XCTAssertEqual(1, eventCounter.flushCalls) + XCTAssertEqual(2, eventCounter.channelReadCalls) // .head & .end + XCTAssertEqual(1, eventCounter.channelReadCompleteCalls) + XCTAssertEqual(["channelReadComplete", "channelRead", "write", "flush"], eventCounter.allTriggeredEvents()) + XCTAssertNoThrow(XCTAssertTrue(try channel.finish().isClean)) + } + func testRefusesRequestSmugglingAttempt() throws { XCTAssertNoThrow(try channel.pipeline.addHandler(ByteToMessageHandler(HTTPRequestDecoder())).wait())