// This source file is part of the SwiftNIO open source project
// Copyright (c) 2019-2021 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
// SPDX-License-Identifier: Apache-2.0
import XCTest
import NIOCore
import NIOEmbedded
import NIOHTTP1
@testable import NIOWebSocket
extension EmbeddedChannel {
    /// Reads the next outbound `ByteBuffer` from this channel and decodes its
    /// readable bytes as a string.
    ///
    /// - Returns: The decoded text, or `nil` when no outbound buffer is pending
    ///   (or when `getString` itself returns `nil`).
    /// - Throws: Any error thrown by `readOutbound`.
    fileprivate func readByteBufferOutputAsString() throws -> String? {
        guard let requestBuffer = try self.readOutbound(as: ByteBuffer.self) else {
            return nil
        }
        // Decode from the reader index rather than a hard-coded 0 so that a
        // buffer whose leading bytes were already consumed is still read correctly.
        return requestBuffer.getString(at: requestBuffer.readerIndex,
                                       length: requestBuffer.readableBytes)
    }
}
extension ChannelPipeline {
    /// Asserts that no handler of type `handlerType` is installed in this pipeline.
    ///
    /// The lookup is expected to fail with `ChannelPipelineError.notFound`; a
    /// successful lookup or any other error is recorded as a test failure.
    ///
    /// - Parameters:
    ///   - handlerType: The handler type that must be absent.
    ///   - file: Source file to attribute a failure to (defaults to the call site).
    ///   - line: Source line to attribute a failure to (defaults to the call site).
    fileprivate func assertDoesNotContain<Handler: ChannelHandler>(handlerType: Handler.Type,
                                                                   file: StaticString = #filePath,
                                                                   line: UInt = #line) throws {
        XCTAssertThrowsError(try self.context(handlerType: handlerType).wait(), file: (file), line: line) { error in
            // Forward file/line so a wrong error type is reported at the call site too.
            XCTAssertEqual(.notFound, error as? ChannelPipelineError, file: (file), line: line)
        }
    }

    /// Asserts that a handler of type `handlerType` is installed in this pipeline.
    ///
    /// - Parameters:
    ///   - handlerType: The handler type that must be present.
    ///   - file: Source file to attribute a failure to (defaults to the call site).
    ///   - line: Source line to attribute a failure to (defaults to the call site).
    /// - Throws: Any lookup error other than `ChannelPipelineError.notFound`.
    fileprivate func assertContains<Handler: ChannelHandler>(handlerType: Handler.Type,
                                                             file: StaticString = #filePath,
                                                             line: UInt = #line) throws {
        do {
            _ = try self.context(handlerType: handlerType).wait()
        } catch ChannelPipelineError.notFound {
            XCTFail("Did not find handler", file: (file), line: line)
        }
    }
}
/// Builds an `EmbeddedChannel` configured as an HTTP client with upgrade support
/// and connects it.
///
/// - Parameters:
///   - clientHTTPHandler: The HTTP handler that sits behind the upgrade handler.
///     It is removed from the pipeline when the upgrade completes.
///   - clientUpgraders: The protocol upgraders offered to the server.
///   - upgradeCompletionHandler: Invoked once the upgrade has completed, after
///     `clientHTTPHandler` has been scheduled for removal.
/// - Returns: The connected channel.
/// - Throws: Any error raised while configuring the pipeline or connecting.
private func setUpClientChannel(clientHTTPHandler: RemovableChannelHandler,
                                clientUpgraders: [NIOHTTPClientProtocolUpgrader],
                                _ upgradeCompletionHandler: @escaping (ChannelHandlerContext) -> Void) throws -> EmbeddedChannel {
    let channel = EmbeddedChannel()

    let config: NIOHTTPClientUpgradeConfiguration = (
        upgraders: clientUpgraders,
        completionHandler: { context in
            channel.pipeline.removeHandler(clientHTTPHandler, promise: nil)
            // Hand control to the caller; without this the parameter is dead.
            upgradeCompletionHandler(context)
        }
    )

    // The HTTP handler has to be installed before it can be removed on upgrade.
    // Wait on the future so that configuration errors surface here.
    try channel.pipeline.addHTTPClientHandlers(withClientUpgrade: config).flatMap {
        channel.pipeline.addHandler(clientHTTPHandler)
    }.wait()

    // An empty string is not a parseable IP address; use the loopback literal.
    try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)).wait()

    return channel
}
/// An HTTP handler that sends an initial request which the upgrade handler can augment.
private final class BasicHTTPHandler: ChannelInboundHandler, RemovableChannelHandler {
    fileprivate typealias InboundIn = HTTPClientResponsePart
    fileprivate typealias OutboundOut = HTTPClientRequestPart

    /// Sends the initial request as soon as the channel becomes active.
    fileprivate func channelActive(context: ChannelHandlerContext) {
        // We are connected. It's time to send the message to the server to initialise the upgrade dance.
        self.fireSendRequest(context: context)
    }
}
/// An HTTP handler that sends a request and then fails the test if it receives a response or an error.
/// It can be used when there is a successful upgrade, as the handler should be removed by the upgrader.
private final class ExplodingHTTPHandler: ChannelInboundHandler, RemovableChannelHandler {
    fileprivate typealias InboundIn = HTTPClientResponsePart
    fileprivate typealias OutboundOut = HTTPClientRequestPart

    /// Sends the initial request as soon as the channel becomes active.
    fileprivate func channelActive(context: ChannelHandlerContext) {
        // We are connected. It's time to send the message to the server to initialise the upgrade dance.
        self.fireSendRequest(context: context)
    }

    /// Any inbound read reaching this handler is a test failure.
    fileprivate func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        XCTFail("Received unexpected read")
    }

    /// Any error reaching this handler is a test failure.
    fileprivate func errorCaught(context: ChannelHandlerContext, error: Error) {
        // Include the error so that a failure is diagnosable from the test log.
        XCTFail("Received unexpected error: \(error)")
    }
}
extension ChannelInboundHandler where OutboundOut == HTTPClientRequestPart {
    /// Writes a `GET /` request with a zero-length body (head, empty body part,
    /// end) and flushes it.
    ///
    /// - Parameter context: The context of the handler issuing the request.
    fileprivate func fireSendRequest(context: ChannelHandlerContext) {
        var headers = HTTPHeaders()
        headers.add(name: "Content-Type", value: "text/plain; charset=utf-8")
        headers.add(name: "Content-Length", value: "0")

        let requestHead = HTTPRequestHead(version: .http1_1,
                                          method: .GET,
                                          uri: "/",
                                          headers: headers)
        context.write(self.wrapOutboundOut(.head(requestHead)), promise: nil)

        let emptyBuffer = context.channel.allocator.buffer(capacity: 0)
        let body = HTTPClientRequestPart.body(.byteBuffer(emptyBuffer))
        context.write(self.wrapOutboundOut(body), promise: nil)

        // Only the final part flushes, so the whole request leaves in one go.
        context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
    }
}
/// Records every inbound WebSocket frame and every error it observes so that
/// tests can assert on them afterwards.
private class WebSocketRecorderHandler: ChannelInboundHandler, ChannelOutboundHandler {
    typealias OutboundIn = WebSocketFrame
    typealias InboundIn = WebSocketFrame
    typealias OutboundOut = WebSocketFrame

    /// Frames received so far, in arrival order.
    public var frames: [WebSocketFrame] = []
    /// Errors observed so far, in arrival order.
    public var errors: [Error] = []

    /// Stores the inbound frame in `frames`.
    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        let frame = self.unwrapInboundIn(data)
        self.frames.append(frame)
    }

    /// Stores the error in `errors` and forwards it down the pipeline.
    func errorCaught(context: ChannelHandlerContext, error: Error) {
        self.errors.append(error)
        // Keep propagating so that recording an error does not also swallow it.
        context.fireErrorCaught(error)
    }
}
/// Tests that drive a client-side WebSocket upgrade over an `EmbeddedChannel`.
class WebSocketClientEndToEndTests: XCTestCase {
/// Returns the request line and the first two headers of the request sent by the HTTP handlers.
///
/// The result has no trailing CRLF; callers append the remaining headers and the
/// terminating blank line themselves.
///
/// - Parameter path: The request target placed in the request line (defaults to "/").
private func basicRequest(path: String = "/") -> String {
return "GET \(path) HTTP/1.1\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: 0"
/// Feeds a `101 Switching Protocols` response to an upgrading client, then checks
/// that the HTTP handlers have left the pipeline and the WebSocket handlers are present.
func testSimpleUpgradeSucceeds() throws {
var upgradeHandlerCallbackFired = false
// Fixed request key, and the `Sec-WebSocket-Accept` value the canned response carries.
let requestKey = "OfS0wDaT5NoxF2gqm7Zj2YtetzM="
let responseKey = "yKEqitDFPE81FyIhKTm+ojBqigk="
// NOTE(review): the body of the `upgradePipelineHandler` closure is not visible in this
// listing; presumably it installs the `WebSocketRecorderHandler` asserted on below — confirm.
let basicUpgrader = NIOWebSocketClientUpgrader(requestKey: requestKey,
upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in
// The process should kick-off independently by sending the upgrade request to the server.
let clientChannel = try setUpClientChannel(clientHTTPHandler: ExplodingHTTPHandler(),
clientUpgraders: [basicUpgrader]) { _ in
// This is called before the upgrader gets called.
upgradeHandlerCallbackFired = true
// Read the request the client wrote to the outbound side and compare it with the
// basic request plus the upgrade headers.
if let requestString = try clientChannel.readByteBufferOutputAsString() {
XCTAssertEqual(requestString, self.basicRequest() + "\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Key: \(requestKey)\r\nSec-WebSocket-Version: 13\r\n\r\n")
// NOTE(review): the body of this `else` branch is not visible here; presumably it fails
// the test when no request was written — confirm.
} else {
// Push the successful server response.
let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Accept:\(responseKey)\r\n\r\n"
XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response)))
// Once upgraded, validate the http pipeline has been removed.
XCTAssertNoThrow(try clientChannel.pipeline
.assertDoesNotContain(handlerType: HTTPRequestEncoder.self))
XCTAssertNoThrow(try clientChannel.pipeline
.assertDoesNotContain(handlerType: ByteToMessageHandler<HTTPResponseDecoder>.self))
XCTAssertNoThrow(try clientChannel.pipeline
.assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self))
// Check that the pipeline now has the correct websocket handlers added.
XCTAssertNoThrow(try clientChannel.pipeline
.assertContains(handlerType: WebSocketFrameEncoder.self))
XCTAssertNoThrow(try clientChannel.pipeline
.assertContains(handlerType: ByteToMessageHandler<WebSocketFrameDecoder>.self))
XCTAssertNoThrow(try clientChannel.pipeline
.assertContains(handlerType: WebSocketRecorderHandler.self))
// Close the pipeline.
XCTAssertNoThrow(try clientChannel.close().wait())
/// Feeds a `101` response that has no `Sec-WebSocket-Accept` header and expects the
/// inbound write to fail with `NIOHTTPClientUpgradeError.upgraderDeniedUpgrade`.
func testRejectUpgradeIfMissingAcceptKey() throws {
let requestKey = "OfS0wDaT5NoxF2gqm7Zj2YtetzM="
// NOTE(review): the body of the `upgradePipelineHandler` closure is not visible in this listing.
let basicUpgrader = NIOWebSocketClientUpgrader(requestKey: requestKey,
upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in
// The process should kick-off independently by sending the upgrade request to the server.
let clientChannel = try setUpClientChannel(clientHTTPHandler: BasicHTTPHandler(),
clientUpgraders: [basicUpgrader]) { _ in
// Push the successful server response but with a missing accept key.
let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: websocket\r\n\r\n"
XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in
XCTAssertEqual(.upgraderDeniedUpgrade, error as? NIOHTTPClientUpgradeError)
// Close the pipeline.
XCTAssertNoThrow(try clientChannel.close().wait())
/// Feeds a `101` response whose `Sec-WebSocket-Accept` value is deliberately wrong and
/// expects the inbound write to fail with `NIOHTTPClientUpgradeError.upgraderDeniedUpgrade`.
func testRejectUpgradeIfIncorrectAcceptKey() throws {
let requestKey = "OfS0wDaT5NoxF2gqm7Zj2YtetzM="
// Intentionally not the accept value that belongs to `requestKey`.
let responseKey = "notACorrectKeyL1am=F1y=nn="
// NOTE(review): the body of the `upgradePipelineHandler` closure is not visible in this listing.
let basicUpgrader = NIOWebSocketClientUpgrader(requestKey: requestKey,
upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in
// The process should kick-off independently by sending the upgrade request to the server.
let clientChannel = try setUpClientChannel(clientHTTPHandler: BasicHTTPHandler(),
clientUpgraders: [basicUpgrader]) { _ in
// Push the successful server response but with an incorrect response key.
let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Accept:\(responseKey)\r\n\r\n"
XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in
XCTAssertEqual(.upgraderDeniedUpgrade, error as? NIOHTTPClientUpgradeError)
// Close the pipeline.
XCTAssertNoThrow(try clientChannel.close().wait())
/// Feeds a `101` response that names `myProtocol` in its `Upgrade` header and expects
/// the inbound write to fail with `NIOHTTPClientUpgradeError.responseProtocolNotFound`.
func testRejectUpgradeIfNotWebsocket() throws {
let requestKey = "OfS0wDaT5NoxF2gqm7Zj2YtetzM="
let responseKey = "yKEqitDFPE81FyIhKTm+ojBqigk="
// NOTE(review): the body of the `upgradePipelineHandler` closure is not visible in this listing.
let basicUpgrader = NIOWebSocketClientUpgrader(requestKey: requestKey,
upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in
// The process should kick-off independently by sending the upgrade request to the server.
let clientChannel = try setUpClientChannel(clientHTTPHandler: BasicHTTPHandler(),
clientUpgraders: [basicUpgrader]) { _ in
// Push the successful server response with an incorrect protocol.
let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: myProtocol\r\nSec-WebSocket-Accept:\(responseKey)\r\n\r\n"
XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in
XCTAssertEqual(.responseProtocolNotFound, error as? NIOHTTPClientUpgradeError)
// Close the pipeline.
XCTAssertNoThrow(try clientChannel.close().wait())
/// Builds a client channel, feeds it a canned successful `101` response and discards
/// the outbound upgrade request, so that callers start from an upgraded channel.
///
/// - Returns: The client channel together with the `WebSocketRecorderHandler` created here.
/// - Throws: Any error thrown while setting up the client channel.
private func runSuccessfulUpgrade() throws -> (EmbeddedChannel, WebSocketRecorderHandler) {
let handler = WebSocketRecorderHandler()
// NOTE(review): the body of the `upgradePipelineHandler` closure is not visible in this
// listing; presumably it adds `handler` to the pipeline — confirm.
let basicUpgrader = NIOWebSocketClientUpgrader(requestKey: "OfS0wDaT5NoxF2gqm7Zj2YtetzM=",
upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in
let clientChannel = try setUpClientChannel(clientHTTPHandler: ExplodingHTTPHandler(),
clientUpgraders: [basicUpgrader]) { _ in
// Push the successful server response.
let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Accept:yKEqitDFPE81FyIhKTm+ojBqigk=\r\n\r\n"
XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response)))
// The upgrade has succeeded; drop the buffered outbound upgrade request so that
// later outbound reads only see WebSocket frames.
XCTAssertNoThrow(try clientChannel.readOutbound(as: ByteBuffer.self))
return (clientChannel, handler)
/// Writes two frames through an upgraded client channel and reads back what was
/// encoded onto the outbound side.
func testSendAFewFrames() throws {
let (clientChannel, _) = try self.runSuccessfulUpgrade()
// Send a frame or two, to confirm that the Websocket pipeline works.
let stringSentInFrame = "hello, world"
// NOTE(review): no write of `stringSentInFrame` into `data` is visible in this listing;
// presumably the payload is written before the frame is built — confirm.
var data = clientChannel.allocator.buffer(capacity: 12)
let dataFrame = WebSocketFrame(fin: true, opcode: .binary, data: data)
XCTAssertNoThrow(try clientChannel.writeOutbound(dataFrame))
// NOTE(review): no assertion on `outboundAsString` is visible in this listing.
let outboundAsString = try clientChannel.readAllOutboundBuffers().allAsString()
// Frame number two coming up.
let stringSentInSecondFrame = "two"
// NOTE(review): as above, no write of `stringSentInSecondFrame` into `dataTwo` is visible.
var dataTwo = clientChannel.allocator.buffer(capacity: 3)
// Despite its name, this frame carries the `.text` opcode, not a ping.
let pingFrame = WebSocketFrame(fin: true, opcode: .text, data: dataTwo)
XCTAssertNoThrow(try clientChannel.writeAndFlush(pingFrame).wait())
// NOTE(review): no assertion on `pingAsString` is visible in this listing.
let pingAsString = try clientChannel.readAllOutboundBuffers().allAsString()
// Close the pipeline.
XCTAssertNoThrow(try clientChannel.close().wait())
/// Encodes a single final WebSocket frame using a throwaway `EmbeddedChannel` that
/// contains only a `WebSocketFrameEncoder`.
///
/// - Parameters:
///   - dataString: The text placed in the frame's payload.
///   - opcode: The opcode the frame is created with.
/// - Returns: The frame that was written, and the bytes the encoder produced for it.
/// - Throws: Any error raised while installing the encoder, writing the frame, or
///   reading the encoded bytes.
private func encodeFrame(dataString: String, opcode: WebSocketOpcode) throws -> (WebSocketFrame, [UInt8]) {
    let serverChannel = EmbeddedChannel()
    defer {
        XCTAssertNoThrow(try serverChannel.finish())
    }

    // Size the buffer from the payload instead of a fixed capacity, and actually
    // carry the payload in it.
    let buffer = serverChannel.allocator.buffer(string: dataString)
    // Honour the requested opcode rather than always producing a binary frame.
    let dataFrame = WebSocketFrame(fin: true, opcode: opcode, data: buffer)

    try serverChannel.pipeline.addHandler(WebSocketFrameEncoder()).wait()
    // Wait on the write so that an encoding failure surfaces here.
    try serverChannel.writeAndFlush(dataFrame).wait()

    return (dataFrame, try serverChannel.readAllOutboundBytes())
}
/// Feeds two encoded frames into an upgraded client channel and checks that the
/// recorder handler saw them in order and recorded no errors.
func testReceiveAFewFrames() throws {
let (clientChannel, recorder) = try self.runSuccessfulUpgrade()
// The channel is closed explicitly at the end of the test, so accept an
// already-closed channel when finishing.
defer {
XCTAssertNoThrow(try clientChannel.finish(acceptAlreadyClosed: true))
// Listen out for a frame or two, to confirm that the Websocket pipeline works.
let (binaryFrame, binaryFrameAsBytes) = try self.encodeFrame(dataString: "hello, back", opcode: .binary)
// NOTE(review): no write of `binaryFrameAsBytes` into `data` is visible in this listing;
// presumably the encoded bytes are written before the inbound write — confirm.
var data = clientChannel.allocator.buffer(capacity: binaryFrameAsBytes.count)
XCTAssertNoThrow(try clientChannel.writeInbound(data))
XCTAssertEqual(recorder.frames, [binaryFrame])
// Frame number two coming up.
let (textFrame, textFrameAsBytes) = try self.encodeFrame(dataString: "two", opcode: .text)
// NOTE(review): as above, no write of `textFrameAsBytes` into `dataTwo` is visible.
var dataTwo = clientChannel.allocator.buffer(capacity: textFrameAsBytes.count)
XCTAssertNoThrow(try clientChannel.writeInbound(dataTwo))
XCTAssertEqual([binaryFrame, textFrame], recorder.frames)
XCTAssertEqual(0, recorder.errors.count)
// Close the pipeline.
XCTAssertNoThrow(try clientChannel.close().wait())