diff --git a/Sources/NIOOpenSSL/OpenSSLClientHandler.swift b/Sources/NIOOpenSSL/OpenSSLClientHandler.swift new file mode 100644 index 00000000..34f9da29 --- /dev/null +++ b/Sources/NIOOpenSSL/OpenSSLClientHandler.swift @@ -0,0 +1,30 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIO + +/// A channel handler that wraps a channel in TLS using OpenSSL, or an +/// OpenSSL-compatible library. This handler can be used in channels that +/// are acting as the client in the TLS dialog. For server connections, +/// use the OpenSSLServerHandler. +public final class OpenSSLClientHandler: OpenSSLHandler { + public init(context: SSLContext) throws { + guard let connection = context.createConnection() else { + throw NIOOpenSSLError.unableToAllocateOpenSSLObject + } + + connection.setConnectState() + super.init(connection: connection) + } +} diff --git a/Sources/NIOOpenSSL/OpenSSLHandler.swift b/Sources/NIOOpenSSL/OpenSSLHandler.swift index 9d7508f9..42ce8de1 100644 --- a/Sources/NIOOpenSSL/OpenSSLHandler.swift +++ b/Sources/NIOOpenSSL/OpenSSLHandler.swift @@ -15,7 +15,15 @@ import NIO import CNIOOpenSSL -public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandler { +/// The base class for all OpenSSL handlers. This class cannot actually be instantiated by +/// users directly: instead, users must select which mode they would like their handler to +/// operate in, client or server. +/// +/// This class exists to deal with the reality that for almost the entirety of the lifetime +/// of a TLS connection there is no meaningful distinction between a server and a client. +/// For this reason almost the entirety of the implementation for the channel and server +/// handlers in OpenSSL is shared, in the form of this parent class. +public class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandler { public typealias OutboundIn = ByteBuffer public typealias OutboundOut = ByteBuffer public typealias InboundIn = ByteBuffer @@ -29,49 +37,26 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle case closing case closed } - - private let context: SSLContext + private var state: ConnectionState = .idle - private var connection: SSLConnection? = nil + private var connection: SSLConnection private var bufferedWrites: MarkedCircularBuffer private var closePromise: Promise? private var didDeliverData: Bool = false - public init (context: SSLContext) { - self.context = context + internal init (connection: SSLConnection) { + self.connection = connection self.bufferedWrites = MarkedCircularBuffer(initialRingCapacity: 96) // 96 brings the total size of the buffer to just shy of one page } - - public func connect(ctx: ChannelHandlerContext, to address: SocketAddress, promise: Promise?) { - // This fires when we're asked to connect to a server. Necessarily if we're doing that then we're a - // client, and so we should set up our underlying OpenSSL Connection object to be a client. We don't - // bother starting the handshake now though: we have nowhere to write to! - assert(connection == nil) - self.connection = context.createConnection() - guard let connection = self.connection else { - promise?.fail(error: NIOOpenSSLError.unableToAllocateOpenSSLObject) - return + + public func handlerAdded(ctx: ChannelHandlerContext) { + // If this channel is already active, immediately begin handshaking. + if ctx.channel!.isActive { + doHandshakeStep(ctx: ctx) } - - connection.setConnectState() - ctx.connect(to: address, promise: promise) } - + public func channelActive(ctx: ChannelHandlerContext) { - // This fires when the TCP connection is established. If we don't have a Connection object yet - // that means we're a server, so we should create it now and start the handshake. - // If we already have a connection we are a client, so we can just start handshaking. - if connection == nil { - self.connection = context.createConnection() - guard let connection = self.connection else { - ctx.fireErrorCaught(error: NIOOpenSSLError.unableToAllocateOpenSSLObject) - ctx.close(promise: nil) - return - } - - connection.setAcceptState() - } - // We fire this a bit early, entirely on purpose. This is because // in doHandshakeStep we may end up closing the channel again, and // if we do we want to make sure that the channelInactive message received @@ -102,7 +87,7 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle var binaryData = unwrapInboundIn(data) // The logic: feed the buffers, then take an action based on state. - connection!.consumeDataFromNetwork(&binaryData) + connection.consumeDataFromNetwork(&binaryData) switch state { case .handshaking: @@ -167,7 +152,7 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle } private func doHandshakeStep(ctx: ChannelHandlerContext) { - let result = connection!.doHandshake() + let result = connection.doHandshake() switch result { case .incomplete: @@ -193,7 +178,7 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle } private func doShutdownStep(ctx: ChannelHandlerContext) { - let result = connection!.doShutdown() + let result = connection.doShutdown() switch result { case .incomplete: @@ -215,7 +200,7 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle private func doDecodeData(ctx: ChannelHandlerContext) { readLoop: while true { - let result = connection!.readDataFromNetwork(allocator: ctx.channel!.allocator) + let result = connection.readDataFromNetwork(allocator: ctx.channel!.allocator) switch result { case .complete(let buf): @@ -239,7 +224,7 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle private func writeDataToNetwork(ctx: ChannelHandlerContext, promise: Promise?) { // There may be no data to write, in which case we can just exit early. - guard let dataToWrite = connection!.getDataForNetwork(allocator: ctx.channel!.allocator) else { + guard let dataToWrite = connection.getDataForNetwork(allocator: ctx.channel!.allocator) else { promise?.succeed(result: ()) return } @@ -305,7 +290,7 @@ extension OpenSSLHandler { /// Given a byte buffer to encode, passes it to OpenSSL and handles the result. func encodeWrite(buf: inout ByteBuffer, promise: Promise?) throws -> Bool { - let result = connection!.writeDataToNetwork(&buf) + let result = connection.writeDataToNetwork(&buf) switch result { case .complete: diff --git a/Sources/NIOOpenSSL/OpenSSLServerHandler.swift b/Sources/NIOOpenSSL/OpenSSLServerHandler.swift new file mode 100644 index 00000000..e4d930e7 --- /dev/null +++ b/Sources/NIOOpenSSL/OpenSSLServerHandler.swift @@ -0,0 +1,30 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIO + +/// A channel handler that wraps a channel in TLS using OpenSSL, or an +/// OpenSSL-compatible library. This handler can be used in channels that +/// are acting as the server in the TLS dialog. For client connections, +/// use the OpenSSLClientHandler. +public final class OpenSSLServerHandler: OpenSSLHandler { + public init(context: SSLContext) throws { + guard let connection = context.createConnection() else { + throw NIOOpenSSLError.unableToAllocateOpenSSLObject + } + + connection.setAcceptState() + super.init(connection: connection) + } +} diff --git a/Sources/NIOTLSServer/main.swift b/Sources/NIOTLSServer/main.swift index 8863d318..cef35994 100644 --- a/Sources/NIOTLSServer/main.swift +++ b/Sources/NIOTLSServer/main.swift @@ -40,7 +40,7 @@ let bootstrap = ServerBootstrap(group: group) // Set the handlers that are applied to the accepted channels. .handler(childHandler: ChannelInitializer(initChannel: { channel in - return channel.pipeline.add(handler: OpenSSLHandler(context: sslContext)).then(callback: { v2 in + return channel.pipeline.add(handler: try! OpenSSLServerHandler(context: sslContext)).then(callback: { v2 in return channel.pipeline.add(handler: EchoHandler()) }) })) diff --git a/Tests/NIOOpenSSLTests/OpenSSLIntegrationTest+XCTest.swift b/Tests/NIOOpenSSLTests/OpenSSLIntegrationTest+XCTest.swift index c2240f86..1ffb59db 100644 --- a/Tests/NIOOpenSSLTests/OpenSSLIntegrationTest+XCTest.swift +++ b/Tests/NIOOpenSSLTests/OpenSSLIntegrationTest+XCTest.swift @@ -33,6 +33,7 @@ extension OpenSSLIntegrationTest { ("testCoalescedWrites", testCoalescedWrites), ("testCoalescedWritesWithFutures", testCoalescedWritesWithFutures), ("testImmediateCloseSatisfiesPromises", testImmediateCloseSatisfiesPromises), + ("testAddingTlsToActiveChannelStillHandshakes", testAddingTlsToActiveChannelStillHandshakes), ] } } diff --git a/Tests/NIOOpenSSLTests/OpenSSLIntegrationTest.swift b/Tests/NIOOpenSSLTests/OpenSSLIntegrationTest.swift index c8d16b1c..4869ecc5 100644 --- a/Tests/NIOOpenSSLTests/OpenSSLIntegrationTest.swift +++ b/Tests/NIOOpenSSLTests/OpenSSLIntegrationTest.swift @@ -150,11 +150,28 @@ public final class EventRecorderHandler: ChannelInboundHandler wh } } +private class ChannelActiveWaiter: ChannelInboundHandler { + public typealias InboundIn = Any + private var activePromise: Promise + + public init(promise: Promise) { + activePromise = promise + } + + public func channelActive(ctx: ChannelHandlerContext) { + activePromise.succeed(result: ()) + } + + public func waitForChannelActive() throws { + try activePromise.futureResult.wait() + } +} + internal func serverTLSChannel(withContext: NIOOpenSSL.SSLContext, andHandlers: [ChannelHandler], onGroup: EventLoopGroup) throws -> Channel { return try ServerBootstrap(group: onGroup) .option(option: ChannelOptions.Socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) .handler(childHandler: ChannelInitializer(initChannel: { channel in - return channel.pipeline.add(handler: OpenSSLHandler(context: withContext)).then(callback: { v2 in + return channel.pipeline.add(handler: try! OpenSSLServerHandler(context: withContext)).then(callback: { v2 in let results = andHandlers.map { channel.pipeline.add(handler: $0) } // NB: This assumes that the futures will always fire in order. This is not necessarily guaranteed @@ -175,7 +192,7 @@ internal func clientTLSChannel(withContext: NIOOpenSSL.SSLContext, let results = preHandlers.map { channel.pipeline.add(handler: $0) } return (results.last ?? channel.eventLoop.newSucceedFuture(result: ())).then(callback: { v2 in - return channel.pipeline.add(handler: OpenSSLHandler(context: withContext)).then(callback: { v2 in + return channel.pipeline.add(handler: try! OpenSSLClientHandler(context: withContext)).then(callback: { v2 in let results = postHandlers.map { channel.pipeline.add(handler: $0) } // NB: This assumes that the futures will always fire in order. This is not necessarily guaranteed @@ -491,7 +508,7 @@ class OpenSSLIntegrationTest: XCTestCase { func testImmediateCloseSatisfiesPromises() throws { let ctx = try configuredSSLContext() let channel = EmbeddedChannel() - try channel.pipeline.add(handler: OpenSSLHandler(context: ctx)).wait() + try channel.pipeline.add(handler: OpenSSLClientHandler(context: ctx)).wait() // Start by initiating the handshake. try channel.connect(to: SocketAddress.unixDomainSocketAddress(path: "/tmp/doesntmatter")).wait() @@ -502,4 +519,52 @@ class OpenSSLIntegrationTest: XCTestCase { XCTAssertTrue(closePromise.futureResult.fulfilled) } + + func testAddingTlsToActiveChannelStillHandshakes() throws { + let ctx = try configuredSSLContext() + let group = try MultiThreadedEventLoopGroup(numThreads: 1) + defer { + try! group.syncShutdownGracefully() + } + + let recorderHandler: EventRecorderHandler = EventRecorderHandler() + let channelActiveWaiter = ChannelActiveWaiter(promise: group.next().newPromise()) + let serverChannel = try serverTLSChannel(withContext: ctx, + andHandlers: [recorderHandler, SimpleEchoServer(), channelActiveWaiter], + onGroup: group) + defer { + _ = try! serverChannel.close().wait() + } + + // Create a client channel without TLS in it, and connect it. + let readPromise: Promise = group.next().newPromise() + let promiseOnReadHandler = PromiseOnReadHandler(promise: readPromise) + let clientChannel = try ClientBootstrap(group: group) + .handler(handler: promiseOnReadHandler) + .connect(to: serverChannel.localAddress!).wait() + defer { + _ = try! clientChannel.close().wait() + } + + // Wait until the channel comes up, then confirm that no handshake has been + // received. This hardly proves much, but it's enough. + try channelActiveWaiter.waitForChannelActive() + try group.next().submit { + XCTAssertEqual(recorderHandler.events, [.Registered, .Active]) + }.wait() + + // Now, add the TLS handler to the pipeline. + try clientChannel.pipeline.add(name: nil, handler: OpenSSLClientHandler(context: ctx), first: true).wait() + var data = clientChannel.allocator.buffer(capacity: 1) + data.write(staticString: "x") + try clientChannel.writeAndFlush(data: IOData(data)).wait() + + // The echo should come back without error. + _ = try readPromise.futureResult.wait() + + // At this point the handshake should be complete. + try group.next().submit { + XCTAssertEqual(recorderHandler.events[..<3], [.Registered, .Active, .UserEvent(.handshakeCompleted)]) + }.wait() + } }