swift-nio/Sources/NIOTestUtils/NIOHTTP1TestServer.swift

426 lines
16 KiB
Swift

//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2019-2021 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import NIOCore
import NIOPosix
import NIOHTTP1
import NIOConcurrencyHelpers
private final class BlockingQueue<Element> {
private let condition = ConditionLock(value: false)
private var buffer = CircularBuffer<Result<Element, Error>>()
public struct TimeoutError: Error {}
internal func append(_ element: Result<Element, Error>) {
self.condition.lock()
self.buffer.append(element)
self.condition.unlock(withValue: true)
}
internal var isEmpty: Bool {
self.condition.lock()
defer { self.condition.unlock() }
return self.buffer.isEmpty
}
internal func popFirst(deadline: NIODeadline) throws -> Element {
let secondsUntilDeath = deadline - NIODeadline.now()
guard self.condition.lock(whenValue: true,
timeoutSeconds: .init(secondsUntilDeath.nanoseconds / 1_000_000_000)) else {
throw TimeoutError()
}
let first = self.buffer.removeFirst()
self.condition.unlock(withValue: !self.buffer.isEmpty)
return try first.get()
}
}
extension BlockingQueue: @unchecked Sendable where Element: Sendable {}
private final class WebServerHandler: ChannelDuplexHandler {
typealias InboundIn = HTTPServerRequestPart
typealias OutboundIn = HTTPServerResponsePart
typealias OutboundOut = HTTPServerResponsePart
private let webServer: NIOHTTP1TestServer
init(webServer: NIOHTTP1TestServer) {
self.webServer = webServer
}
func errorCaught(context: ChannelHandlerContext, error: Error) {
self.webServer.pushError(error)
context.close(promise: nil)
}
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
self.webServer.pushChannelRead(self.unwrapInboundIn(data))
}
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
switch self.unwrapOutboundIn(data) {
case .head(var head):
head.headers.replaceOrAdd(name: "connection", value: "close")
head.headers.remove(name: "keep-alive")
context.write(self.wrapOutboundOut(.head(head)), promise: promise)
case .body:
context.write(data, promise: promise)
case .end:
context.write(data).map {
context.close(promise: nil)
}.cascade(to: promise)
}
}
}
private final class AggregateBodyHandler: ChannelInboundHandler {
typealias InboundIn = HTTPServerRequestPart
typealias InboundOut = HTTPServerRequestPart
var receivedSoFar: ByteBuffer? = nil
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
switch self.unwrapInboundIn(data) {
case .head:
context.fireChannelRead(data)
case .body(var buffer):
self.receivedSoFar.setOrWriteBuffer(&buffer)
case .end:
if let receivedSoFar = self.receivedSoFar {
context.fireChannelRead(self.wrapInboundOut(.body(receivedSoFar)))
}
context.fireChannelRead(data)
}
}
}
/// HTTP1 server that accepts and process only one request at a time.
/// This helps writing tests against a real server while keeping the ability to
/// write tests and assertions the same way we would if we were testing a
/// `ChannelHandler` in isolation.
/// `NIOHTTP1TestServer` enables writing test cases for HTTP1 clients that have
/// complex behaviours like client implementing a protocol where an high level
/// operation translates into several, possibly parallel, HTTP requests.
///
/// With `NIOHTTP1TestServer` we have:
/// - visibility on the `HTTPServerRequestPart`s received by the server;
/// - control over the `HTTPServerResponsePart`s send by the server.
///
/// The following code snippet shows an example test case where the client
/// under test sends a request to the server.
///
/// // Setup the test environment.
/// let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
/// let allocator = ByteBufferAllocator()
/// let testServer = NIOHTTP1TestServer(group: group)
/// defer {
/// XCTAssertNoThrow(try testServer.stop())
/// XCTAssertNoThrow(try group.syncShutdownGracefully())
/// }
///
/// // Use your library to send a request to the server.
/// let requestBody = "ping"
/// var requestComplete: EventLoopFuture<String>!
/// XCTAssertNoThrow(requestComplete = try sendRequestTo(
/// URL(string: "http://127.0.0.1:\(testServer.serverPort)/some-route")!,
/// body: requestBody))
///
/// // Assert the server received the expected request.
/// XCTAssertNoThrow(try testServer.receiveHeadAndVerify { head in
/// XCTAssertEqual(head, .init(version: .http1_1,
/// method: .GET,
/// uri: "/some-route",
/// headers: .init([
/// ("Content-Type", "text/plain; charset=utf-8"),
/// ("Content-Length", "4")])))
/// })
/// var requestBuffer = allocator.buffer(capacity: 128)
/// requestBuffer.writeString(requestBody)
/// XCTAssertNoThrow(try testServer.receiveBodyAndVerify { body in
/// XCTAssertEqual(body, requestBuffer)
/// })
/// XCTAssertNoThrow(try testServer.receiveEndAndVerify { trailers in
/// XCTAssertNil(trailers)
/// })
///
/// // Make the server send a response to the client.
/// let responseBody = "pong"
/// var responseBuffer = allocator.buffer(capacity: 128)
/// responseBuffer.writeString(responseBody)
/// XCTAssertNoThrow(try testServer.writeOutbound(.head(.init(version: .http1_1, status: .ok))))
/// XCTAssertNoThrow(try testServer.writeOutbound(.body(.byteBuffer(responseBuffer))))
/// XCTAssertNoThrow(try testServer.writeOutbound(.end(nil)))
///
/// // Assert that the client received the response from the server.
/// XCTAssertNoThrow(XCTAssertEqual(responseBody, try requestComplete.wait()))
public final class NIOHTTP1TestServer {
private let eventLoop: EventLoop
// all protected by eventLoop
private let inboundBuffer: BlockingQueue<HTTPServerRequestPart> = .init()
private var currentClientChannel: Channel? = nil
private var serverChannel: Channel! = nil
enum State {
case channelsAvailable(CircularBuffer<Channel>)
case waitingForChannel(EventLoopPromise<Void>)
case idle
case stopped
}
private var state: State = .idle
func handleChannels() {
self.eventLoop.assertInEventLoop()
let channel: Channel
switch self.state {
case .channelsAvailable(var channels):
channel = channels.removeFirst()
if channels.isEmpty {
self.state = .idle
} else {
self.state = .channelsAvailable(channels)
}
case .idle:
let promise = self.eventLoop.makePromise(of: Void.self)
promise.futureResult.whenSuccess {
self.handleChannels()
}
self.state = .waitingForChannel(promise)
return
case .waitingForChannel:
preconditionFailure("illegal state \(self.state)")
case .stopped:
return
}
assert(self.currentClientChannel == nil)
self.currentClientChannel = channel
channel.closeFuture.whenSuccess {
self.currentClientChannel = nil
self.handleChannels()
return
}
channel.pipeline.configureHTTPServerPipeline().flatMap {
channel.pipeline.addHandler(AggregateBodyHandler())
}.flatMap {
channel.pipeline.addHandler(WebServerHandler(webServer: self))
}.whenSuccess {
_ = channel.setOption(ChannelOptions.autoRead, value: true)
}
}
public init(group: EventLoopGroup) {
self.eventLoop = group.next()
self.serverChannel = try! ServerBootstrap(group: self.eventLoop)
.childChannelOption(ChannelOptions.autoRead, value: false)
.childChannelInitializer { channel in
switch self.state {
case .channelsAvailable(var channels):
channels.append(channel)
self.state = .channelsAvailable(channels)
case .waitingForChannel(let promise):
self.state = .channelsAvailable([channel])
promise.succeed(())
case .idle:
self.state = .channelsAvailable([channel])
case .stopped:
channel.close(promise: nil)
}
return channel.eventLoop.makeSucceededFuture(())
}
.bind(host: "127.0.0.1", port: 0)
.map { channel in
self.handleChannels()
return channel
}
.wait()
}
}
// MARK: - Public API for test driver
extension NIOHTTP1TestServer {
struct NonEmptyInboundBufferOnStop: Error {}
public func stop() throws {
assert(!self.eventLoop.inEventLoop)
try self.eventLoop.flatSubmit { () -> EventLoopFuture<Void> in
switch self.state {
case .channelsAvailable(let channels):
self.state = .stopped
channels.forEach {
$0.close(promise: nil)
}
case .waitingForChannel(let promise):
self.state = .stopped
promise.fail(ChannelError.ioOnClosedChannel)
case .idle:
self.state = .stopped
case .stopped:
preconditionFailure("double stopped NIOHTTP1TestServer")
}
return self.serverChannel.close().flatMapThrowing {
self.serverChannel = nil
guard self.inboundBuffer.isEmpty else {
throw NonEmptyInboundBufferOnStop()
}
}.always { _ in
self.currentClientChannel?.close(promise: nil)
}
}.wait()
}
public func readInbound(deadline: NIODeadline = .now() + .seconds(10)) throws -> HTTPServerRequestPart {
self.eventLoop.assertNotInEventLoop()
return try self.eventLoop.submit { () -> BlockingQueue<HTTPServerRequestPart> in
self.inboundBuffer
}.wait().popFirst(deadline: deadline)
}
public func writeOutbound(_ data: HTTPServerResponsePart) throws {
self.eventLoop.assertNotInEventLoop()
try self.eventLoop.flatSubmit { () -> EventLoopFuture<Void> in
if let channel = self.currentClientChannel {
return channel.writeAndFlush(data)
} else {
return self.eventLoop.makeFailedFuture(ChannelError.ioOnClosedChannel)
}
}.wait()
}
public var serverPort: Int {
self.eventLoop.assertNotInEventLoop()
return self.serverChannel!.localAddress!.port!
}
}
// All mutable state is protected through `eventLoop`
extension NIOHTTP1TestServer: @unchecked Sendable {}
// MARK: - API for HTTP server
extension NIOHTTP1TestServer {
fileprivate func pushChannelRead(_ state: HTTPServerRequestPart) {
self.eventLoop.assertInEventLoop()
self.inboundBuffer.append(.success(state))
}
fileprivate func pushError(_ error: Error) {
self.eventLoop.assertInEventLoop()
self.inboundBuffer.append(.failure(error))
}
}
extension NIOHTTP1TestServer {
/// Waits for a message part to be received and checks that it was a `.head` before returning
/// the `HTTPRequestHead` it contained.
///
/// - Parameters:
/// - deadline: The deadline by which a part must have been received.
/// - Throws: If the part was not a `.head` or nothing was read before the deadline.
/// - Returns: The `HTTPRequestHead` from the `.head`.
public func receiveHead(deadline: NIODeadline = .now() + .seconds(10)) throws -> HTTPRequestHead {
let part = try self.readInbound(deadline: deadline)
switch part {
case .head(let head):
return head
default:
throw NIOHTTP1TestServerError(reason: "Expected .head but got '\(part)'")
}
}
/// Waits for a message part to be received and checks that it was a `.head` before passing
/// it to the `verify` block.
///
/// - Parameters:
/// - deadline: The deadline by which a part must have been received.
/// - verify: A closure which can be used to verify the contents of the `HTTPRequestHead`.
/// - Throws: If the part was not a `.head` or nothing was read before the deadline.
public func receiveHeadAndVerify(deadline: NIODeadline = .now() + .seconds(10),
_ verify: (HTTPRequestHead) throws -> () = { _ in }) throws {
try verify(self.receiveHead(deadline: deadline))
}
/// Waits for a message part to be received and checks that it was a `.body` before returning
/// the `ByteBuffer` it contained.
///
/// - Parameters:
/// - deadline: The deadline by which a part must have been received.
/// - Throws: If the part was not a `.body` or nothing was read before the deadline.
/// - Returns: The `ByteBuffer` from the `.body`.
public func receiveBody(deadline: NIODeadline = .now() + .seconds(10)) throws -> ByteBuffer {
let part = try self.readInbound(deadline: deadline)
switch part {
case .body(let buffer):
return buffer
default:
throw NIOHTTP1TestServerError(reason: "Expected .body but got '\(part)'")
}
}
/// Waits for a message part to be received and checks that it was a `.body` before passing
/// it to the `verify` block.
///
/// - Parameters:
/// - deadline: The deadline by which a part must have been received.
/// - verify: A closure which can be used to verify the contents of the `ByteBuffer`.
/// - Throws: If the part was not a `.body` or nothing was read before the deadline.
public func receiveBodyAndVerify(deadline: NIODeadline = .now() + .seconds(10),
_ verify: (ByteBuffer) throws -> () = { _ in }) throws {
try verify(self.receiveBody(deadline: deadline))
}
/// Waits for a message part to be received and checks that it was a `.end` before returning
/// the `HTTPHeaders?` it contained.
///
/// - Parameters:
/// - deadline: The deadline by which a part must have been received.
/// - Throws: If the part was not a `.end` or nothing was read before the deadline.
/// - Returns: The `HTTPHeaders?` from the `.end`.
public func receiveEnd(deadline: NIODeadline = .now() + .seconds(10)) throws -> HTTPHeaders? {
let part = try self.readInbound(deadline: deadline)
switch part {
case .end(let trailers):
return trailers
default:
throw NIOHTTP1TestServerError(reason: "Expected .end but got '\(part)'")
}
}
/// Waits for a message part to be received and checks that it was a `.end` before passing
/// it to the `verify` block.
///
/// - Parameters:
/// - deadline: The deadline by which a part must have been received.
/// - verify: A closure which can be used to verify the contents of the `HTTPHeaders?`.
/// - Throws: If the part was not a `.end` or nothing was read before the deadline.
public func receiveEndAndVerify(deadline: NIODeadline = .now() + .seconds(10),
_ verify: (HTTPHeaders?) throws -> () = { _ in }) throws {
try verify(self.receiveEnd())
}
}
public struct NIOHTTP1TestServerError: Error, Hashable, CustomStringConvertible {
public var reason: String
public init(reason: String) {
self.reason = reason
}
public var description: String {
return self.reason
}
}