271 lines
10 KiB
Swift
271 lines
10 KiB
Swift
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2021 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
|
|
|
|
import XCTest
|
|
@testable import NIOCore
|
|
import NIOEmbedded
|
|
@testable import NIOHTTP1
|
|
|
|
/// Inbound handler used by the tests in this file: every decoded
/// `HTTPServerRequestPart` is handed to a caller-supplied closure, and whatever
/// that closure returns is forwarded to the next handler in the pipeline.
private final class TestChannelInboundHandler: ChannelInboundHandler {
    public typealias InboundIn = HTTPServerRequestPart
    public typealias InboundOut = HTTPServerRequestPart

    /// Closure applied to each request part before it is forwarded.
    private let transform: (HTTPServerRequestPart) -> HTTPServerRequestPart

    /// - Parameter transform: Invoked once per inbound request part; its
    ///   return value is what gets fired down the pipeline.
    init(_ transform: @escaping (HTTPServerRequestPart) -> HTTPServerRequestPart) {
        self.transform = transform
    }

    /// Unwraps the incoming part, runs it through the closure and forwards the result.
    public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        let received = self.unwrapInboundIn(data)
        let forwarded = self.transform(received)
        context.fireChannelRead(self.wrapInboundOut(forwarded))
    }
}
|
|
|
|
|
|
/// Tests that feed serialized HTTP/1.x requests through `HTTPRequestDecoder`
/// and check the decoded parts, plus value-semantics checks for
/// `HTTPRequestHead` and `HTTPResponseHead`.
class HTTPTest: XCTestCase {

    /// Convenience wrapper around `checkHTTPRequests` for a single request.
    ///
    /// - Parameters:
    ///   - expected: The request head that should come out of the decoder.
    ///   - body: Optional request body to send along with the head.
    ///   - trailers: Optional trailers; when non-nil the request is sent chunked.
    func checkHTTPRequest(_ expected: HTTPRequestHead, body: String? = nil, trailers: HTTPHeaders? = nil) throws {
        try checkHTTPRequests([expected], body: body, trailers: trailers)
    }

    /// Serializes each head in `expecteds` to its wire format, pushes the
    /// result through an `EmbeddedChannel` containing an `HTTPRequestDecoder`
    /// and asserts that the decoded head, body and trailers match.
    ///
    /// Every request is sent twice: once as a single buffer and once one byte
    /// at a time. The body decoded by both runs must be equal to `body`.
    ///
    /// - Parameters:
    ///   - expecteds: The request heads to send, in order (pipelined on one channel).
    ///   - body: Optional body attached to *every* request.
    ///   - trailers: Optional trailers attached to every request; when non-nil
    ///     the requests use chunked transfer encoding.
    /// - Throws: Any error raised while setting up the pipeline, or an
    ///   `XCTUnwrap` failure if a body was expected but none was decoded.
    func checkHTTPRequests(_ expecteds: [HTTPRequestHead], body: String? = nil, trailers: HTTPHeaders? = nil) throws {
        /// Renders `req` (plus the enclosing `body`/`trailers`) as an HTTP/1.x request string.
        func httpRequestStrForRequest(_ req: HTTPRequestHead) -> String {
            var s = "\(req.method) \(req.uri) HTTP/\(req.version.major).\(req.version.minor)\r\n"
            for (k, v) in req.headers {
                s += "\(k): \(v)\r\n"
            }
            if trailers != nil {
                // Trailers require chunked encoding: one chunk for the body
                // (if any), then the terminating zero-length chunk, then the trailers.
                s += "Transfer-Encoding: chunked\r\n"
                s += "\r\n"
                if let body = body {
                    s += String(body.utf8.count, radix: 16)
                    s += "\r\n"
                    s += body
                    s += "\r\n"
                }
                s += "0\r\n"
                if let trailers = trailers {
                    for (k, v) in trailers {
                        s += "\(k): \(v)\r\n"
                    }
                }
                s += "\r\n"
            } else if let body = body {
                // Content-Length counts bytes, so measure the UTF-8 view
                // (same as the chunk size above).
                s += "Content-Length: \(body.utf8.count)\r\n"
                s += "\r\n"
                s += body
            } else {
                s += "\r\n"
            }
            return s
        }

        /// Sends every request via `sendStrategy` and checks the decoded parts.
        ///
        /// - Returns: The decoded body as a string, or `nil` when `body` is `nil`.
        func sendAndCheckRequests(_ expecteds: [HTTPRequestHead], body: String?, trailers: HTTPHeaders?, sendStrategy: (String, EmbeddedChannel) -> EventLoopFuture<Void>) throws -> String? {
            // `step` advances once for each `.head` and once for each `.end`,
            // so request number `index` starts at step `index * 2`.
            var step = 0
            var index = 0
            let channel = EmbeddedChannel()
            defer {
                XCTAssertNoThrow(try channel.finish())
            }
            try channel.pipeline.addHandler(ByteToMessageHandler(HTTPRequestDecoder())).wait()
            var bodyData: [UInt8]? = nil
            var allBodyDatas: [[UInt8]] = []
            try channel.pipeline.addHandler(TestChannelInboundHandler { reqPart in
                switch reqPart {
                case .head(var req):
                    XCTAssertEqual((index * 2), step)
                    // These two headers are added by `httpRequestStrForRequest`
                    // and are not part of the expected head.
                    req.headers.remove(name: "Content-Length")
                    req.headers.remove(name: "Transfer-Encoding")
                    XCTAssertEqual(expecteds[index], req)
                    step += 1
                case .body(var buffer):
                    // The body may arrive in several pieces; accumulate them.
                    // Reading exactly `readableBytes` is what the force-unwrap relies on.
                    let bytes = buffer.readBytes(length: buffer.readableBytes)!
                    if bodyData == nil {
                        bodyData = bytes
                    } else {
                        bodyData!.append(contentsOf: bytes)
                    }
                case .end(let receivedTrailers):
                    XCTAssertEqual(trailers, receivedTrailers)
                    step += 1
                    XCTAssertEqual(((index + 1) * 2), step)
                }
                return reqPart
            }).wait()

            var writeFutures: [EventLoopFuture<Void>] = []
            for expected in expecteds {
                writeFutures.append(sendStrategy(httpRequestStrForRequest(expected), channel))
                index += 1
                if let bodyData = bodyData {
                    allBodyDatas.append(bodyData)
                }
                bodyData = nil
            }
            channel.pipeline.flush()
            XCTAssertNoThrow(try EventLoopFuture.andAllSucceed(writeFutures, on: channel.eventLoop).wait())
            XCTAssertEqual(2 * expecteds.count, step)

            if body != nil {
                // Fail the test (rather than trap on an out-of-bounds
                // subscript) when no body was decoded at all.
                let firstBodyData = try XCTUnwrap(allBodyDatas.first, "expected a body but none was decoded")
                for bodyData in allBodyDatas {
                    XCTAssertEqual(firstBodyData, bodyData)
                }
                return String(decoding: firstBodyData, as: Unicode.UTF8.self)
            } else {
                XCTAssertEqual(0, allBodyDatas.count, "left with \(allBodyDatas)")
                return nil
            }
        }

        /* send all bytes in one go */
        let bd1 = try sendAndCheckRequests(expecteds, body: body, trailers: trailers, sendStrategy: { (reqString, chan) in
            var buf = chan.allocator.buffer(capacity: 1024)
            buf.writeString(reqString)
            return chan.eventLoop.makeSucceededFuture(()).flatMapThrowing {
                try chan.writeInbound(buf)
            }
        })

        /* send the bytes one by one */
        let bd2 = try sendAndCheckRequests(expecteds, body: body, trailers: trailers, sendStrategy: { (reqString, chan) in
            var writeFutures: [EventLoopFuture<Void>] = []
            // Iterate the UTF-8 view rather than the `Character` view: "\r\n"
            // is a single `Character`, so iterating characters would never
            // split a CRLF (or a multi-byte character) across two reads.
            for byte in reqString.utf8 {
                var buf = chan.allocator.buffer(capacity: 1)
                buf.writeInteger(byte)
                writeFutures.append(chan.eventLoop.makeSucceededFuture(()).flatMapThrowing { [buf] in
                    try chan.writeInbound(buf)
                })
            }
            return EventLoopFuture.andAllSucceed(writeFutures, on: chan.eventLoop)
        })

        XCTAssertEqual(bd1, bd2)
        XCTAssertEqual(body, bd1)
    }

    /// A bare `GET /` with no headers.
    func testHTTPSimpleNoHeaders() throws {
        try checkHTTPRequest(HTTPRequestHead(version: .http1_1, method: .GET, uri: "/"))
    }

    /// A request carrying a single header.
    func testHTTPSimple1Header() throws {
        var req = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/hello/world")
        req.headers.add(name: "foo", value: "bar")
        try checkHTTPRequest(req)
    }

    /// A request with a query string and two headers.
    func testHTTPSimpleSomeHeader() throws {
        var req = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/foo/bar/buz?qux=quux")
        req.headers.add(name: "foo", value: "bar")
        req.headers.add(name: "qux", value: "quuux")
        try checkHTTPRequest(req)
    }

    /// Several requests sent back to back on the same channel.
    func testHTTPPipelining() throws {
        var req1 = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/foo/bar/buz?qux=quux")
        req1.headers.add(name: "foo", value: "bar")
        req1.headers.add(name: "qux", value: "quuux")
        var req2 = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/")
        req2.headers.add(name: "a", value: "b")
        req2.headers.add(name: "C", value: "D")

        try checkHTTPRequests([req1, req2])
        try checkHTTPRequests(Array(repeating: req1, count: 10))
    }

    /// A request with a `Content-Length` body.
    func testHTTPBody() throws {
        try checkHTTPRequest(HTTPRequestHead(version: .http1_1, method: .GET, uri: "/"),
                             body: "hello world")
    }

    /// A request whose body is a single byte.
    func test1ByteHTTPBody() throws {
        try checkHTTPRequest(HTTPRequestHead(version: .http1_1, method: .GET, uri: "/"),
                             body: "1")
    }

    /// Twenty pipelined requests, each with a one-byte body.
    func testHTTPPipeliningWithBody() throws {
        try checkHTTPRequests(Array(repeating: HTTPRequestHead(version: .http1_1,
                                                               method: .GET, uri: "/"),
                                    count: 20),
                              body: "1")
    }

    /// A chunked request with a body and trailers.
    func testChunkedBody() throws {
        var trailers = HTTPHeaders()
        trailers.add(name: "X-Key", value: "X-Value")
        trailers.add(name: "Something", value: "Else")
        try checkHTTPRequest(HTTPRequestHead(version: .http1_1, method: .POST, uri: "/"), body: "100", trailers: trailers)
    }

    /// Mutating a copy of an `HTTPRequestHead` must leave the original untouched.
    func testHTTPRequestHeadCoWWorks() throws {
        let headers = HTTPHeaders([("foo", "bar")])
        var httpReq = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/uri")
        httpReq.headers = headers

        var modVersion = httpReq
        modVersion.version = .http2
        XCTAssertEqual(.http1_1, httpReq.version)
        XCTAssertEqual(.http2, modVersion.version)

        var modMethod = httpReq
        modMethod.method = .POST
        XCTAssertEqual(.GET, httpReq.method)
        XCTAssertEqual(.POST, modMethod.method)

        var modURI = httpReq
        modURI.uri = "/changed"
        XCTAssertEqual("/uri", httpReq.uri)
        XCTAssertEqual("/changed", modURI.uri)

        var modHeaders = httpReq
        modHeaders.headers.add(name: "qux", value: "quux")
        XCTAssertEqual(httpReq.headers, headers)
        XCTAssertNotEqual(httpReq, modHeaders)
        modHeaders.headers.remove(name: "foo")
        XCTAssertEqual(httpReq.headers, headers)
        XCTAssertNotEqual(httpReq, modHeaders)
        // Undo both header changes: the copy must compare equal again.
        modHeaders.headers.remove(name: "qux")
        modHeaders.headers.add(name: "foo", value: "bar")
        XCTAssertEqual(httpReq, modHeaders)
    }

    /// Mutating a copy of an `HTTPResponseHead` must leave the original untouched.
    func testHTTPResponseHeadCoWWorks() throws {
        let headers = HTTPHeaders([("foo", "bar")])
        let httpRes = HTTPResponseHead(version: .http1_1, status: .ok, headers: headers)

        var modVersion = httpRes
        modVersion.version = .http2
        XCTAssertEqual(.http1_1, httpRes.version)
        XCTAssertEqual(.http2, modVersion.version)

        var modStatus = httpRes
        modStatus.status = .notFound
        XCTAssertEqual(.ok, httpRes.status)
        XCTAssertEqual(.notFound, modStatus.status)

        var modHeaders = httpRes
        modHeaders.headers.add(name: "qux", value: "quux")
        XCTAssertEqual(httpRes.headers, headers)
        XCTAssertNotEqual(httpRes, modHeaders)
        modHeaders.headers.remove(name: "foo")
        XCTAssertEqual(httpRes.headers, headers)
        XCTAssertNotEqual(httpRes, modHeaders)
        // Undo both header changes: the copy must compare equal again.
        modHeaders.headers.remove(name: "qux")
        modHeaders.headers.add(name: "foo", value: "bar")
        XCTAssertEqual(httpRes, modHeaders)
    }
}
|