Split client/server processing in OpenSSL

This commit is contained in:
Cory Benfield 2017-10-24 16:41:23 +01:00
parent e9c76088ac
commit 32c30074ea
6 changed files with 155 additions and 44 deletions

View File

@ -0,0 +1,30 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import NIO
/// A channel handler that wraps a channel in TLS using OpenSSL, or an
/// OpenSSL-compatible library. This handler can be used in channels that
/// are acting as the client in the TLS dialog. For server connections,
/// use the OpenSSLServerHandler.
public final class OpenSSLClientHandler: OpenSSLHandler {
public init(context: SSLContext) throws {
guard let connection = context.createConnection() else {
throw NIOOpenSSLError.unableToAllocateOpenSSLObject
}
connection.setConnectState()
super.init(connection: connection)
}
}

View File

@ -15,7 +15,15 @@
import NIO
import CNIOOpenSSL
public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandler {
/// The base class for all OpenSSL handlers. This class cannot actually be instantiated by
/// users directly: instead, users must select which mode they would like their handler to
/// operate in, client or server.
///
/// This class exists to deal with the reality that for almost the entirety of the lifetime
/// of a TLS connection there is no meaningful distinction between a server and a client.
/// For this reason almost the entirety of the implementation for the channel and server
/// handlers in OpenSSL is shared, in the form of this parent class.
public class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandler {
public typealias OutboundIn = ByteBuffer
public typealias OutboundOut = ByteBuffer
public typealias InboundIn = ByteBuffer
@ -29,49 +37,26 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle
case closing
case closed
}
private let context: SSLContext
private var state: ConnectionState = .idle
private var connection: SSLConnection? = nil
private var connection: SSLConnection
private var bufferedWrites: MarkedCircularBuffer<BufferedEvent>
private var closePromise: Promise<Void>?
private var didDeliverData: Bool = false
public init (context: SSLContext) {
self.context = context
internal init (connection: SSLConnection) {
self.connection = connection
self.bufferedWrites = MarkedCircularBuffer(initialRingCapacity: 96) // 96 brings the total size of the buffer to just shy of one page
}
public func connect(ctx: ChannelHandlerContext, to address: SocketAddress, promise: Promise<Void>?) {
// This fires when we're asked to connect to a server. Necessarily if we're doing that then we're a
// client, and so we should set up our underlying OpenSSL Connection object to be a client. We don't
// bother starting the handshake now though: we have nowhere to write to!
assert(connection == nil)
self.connection = context.createConnection()
guard let connection = self.connection else {
promise?.fail(error: NIOOpenSSLError.unableToAllocateOpenSSLObject)
return
public func handlerAdded(ctx: ChannelHandlerContext) {
// If this channel is already active, immediately begin handshaking.
if ctx.channel!.isActive {
doHandshakeStep(ctx: ctx)
}
connection.setConnectState()
ctx.connect(to: address, promise: promise)
}
public func channelActive(ctx: ChannelHandlerContext) {
// This fires when the TCP connection is established. If we don't have a Connection object yet
// that means we're a server, so we should create it now and start the handshake.
// If we already have a connection we are a client, so we can just start handshaking.
if connection == nil {
self.connection = context.createConnection()
guard let connection = self.connection else {
ctx.fireErrorCaught(error: NIOOpenSSLError.unableToAllocateOpenSSLObject)
ctx.close(promise: nil)
return
}
connection.setAcceptState()
}
// We fire this a bit early, entirely on purpose. This is because
// in doHandshakeStep we may end up closing the channel again, and
// if we do we want to make sure that the channelInactive message received
@ -102,7 +87,7 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle
var binaryData = unwrapInboundIn(data)
// The logic: feed the buffers, then take an action based on state.
connection!.consumeDataFromNetwork(&binaryData)
connection.consumeDataFromNetwork(&binaryData)
switch state {
case .handshaking:
@ -167,7 +152,7 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle
}
private func doHandshakeStep(ctx: ChannelHandlerContext) {
let result = connection!.doHandshake()
let result = connection.doHandshake()
switch result {
case .incomplete:
@ -193,7 +178,7 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle
}
private func doShutdownStep(ctx: ChannelHandlerContext) {
let result = connection!.doShutdown()
let result = connection.doShutdown()
switch result {
case .incomplete:
@ -215,7 +200,7 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle
private func doDecodeData(ctx: ChannelHandlerContext) {
readLoop: while true {
let result = connection!.readDataFromNetwork(allocator: ctx.channel!.allocator)
let result = connection.readDataFromNetwork(allocator: ctx.channel!.allocator)
switch result {
case .complete(let buf):
@ -239,7 +224,7 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle
private func writeDataToNetwork(ctx: ChannelHandlerContext, promise: Promise<Void>?) {
// There may be no data to write, in which case we can just exit early.
guard let dataToWrite = connection!.getDataForNetwork(allocator: ctx.channel!.allocator) else {
guard let dataToWrite = connection.getDataForNetwork(allocator: ctx.channel!.allocator) else {
promise?.succeed(result: ())
return
}
@ -305,7 +290,7 @@ extension OpenSSLHandler {
/// Given a byte buffer to encode, passes it to OpenSSL and handles the result.
func encodeWrite(buf: inout ByteBuffer, promise: Promise<Void>?) throws -> Bool {
let result = connection!.writeDataToNetwork(&buf)
let result = connection.writeDataToNetwork(&buf)
switch result {
case .complete:

View File

@ -0,0 +1,30 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import NIO
/// A channel handler that wraps a channel in TLS using OpenSSL, or an
/// OpenSSL-compatible library. This handler can be used in channels that
/// are acting as the server in the TLS dialog. For client connections,
/// use the OpenSSLClientHandler.
public final class OpenSSLServerHandler: OpenSSLHandler {
public init(context: SSLContext) throws {
guard let connection = context.createConnection() else {
throw NIOOpenSSLError.unableToAllocateOpenSSLObject
}
connection.setAcceptState()
super.init(connection: connection)
}
}

View File

@ -40,7 +40,7 @@ let bootstrap = ServerBootstrap(group: group)
// Set the handlers that are applied to the accepted channels.
.handler(childHandler: ChannelInitializer(initChannel: { channel in
return channel.pipeline.add(handler: OpenSSLHandler(context: sslContext)).then(callback: { v2 in
return channel.pipeline.add(handler: try! OpenSSLServerHandler(context: sslContext)).then(callback: { v2 in
return channel.pipeline.add(handler: EchoHandler())
})
}))

View File

@ -33,6 +33,7 @@ extension OpenSSLIntegrationTest {
("testCoalescedWrites", testCoalescedWrites),
("testCoalescedWritesWithFutures", testCoalescedWritesWithFutures),
("testImmediateCloseSatisfiesPromises", testImmediateCloseSatisfiesPromises),
("testAddingTlsToActiveChannelStillHandshakes", testAddingTlsToActiveChannelStillHandshakes),
]
}
}

View File

@ -150,11 +150,28 @@ public final class EventRecorderHandler<UserEventType>: ChannelInboundHandler wh
}
}
private class ChannelActiveWaiter: ChannelInboundHandler {
public typealias InboundIn = Any
private var activePromise: Promise<Void>
public init(promise: Promise<Void>) {
activePromise = promise
}
public func channelActive(ctx: ChannelHandlerContext) {
activePromise.succeed(result: ())
}
public func waitForChannelActive() throws {
try activePromise.futureResult.wait()
}
}
internal func serverTLSChannel(withContext: NIOOpenSSL.SSLContext, andHandlers: [ChannelHandler], onGroup: EventLoopGroup) throws -> Channel {
return try ServerBootstrap(group: onGroup)
.option(option: ChannelOptions.Socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.handler(childHandler: ChannelInitializer(initChannel: { channel in
return channel.pipeline.add(handler: OpenSSLHandler(context: withContext)).then(callback: { v2 in
return channel.pipeline.add(handler: try! OpenSSLServerHandler(context: withContext)).then(callback: { v2 in
let results = andHandlers.map { channel.pipeline.add(handler: $0) }
// NB: This assumes that the futures will always fire in order. This is not necessarily guaranteed
@ -175,7 +192,7 @@ internal func clientTLSChannel(withContext: NIOOpenSSL.SSLContext,
let results = preHandlers.map { channel.pipeline.add(handler: $0) }
return (results.last ?? channel.eventLoop.newSucceedFuture(result: ())).then(callback: { v2 in
return channel.pipeline.add(handler: OpenSSLHandler(context: withContext)).then(callback: { v2 in
return channel.pipeline.add(handler: try! OpenSSLClientHandler(context: withContext)).then(callback: { v2 in
let results = postHandlers.map { channel.pipeline.add(handler: $0) }
// NB: This assumes that the futures will always fire in order. This is not necessarily guaranteed
@ -491,7 +508,7 @@ class OpenSSLIntegrationTest: XCTestCase {
func testImmediateCloseSatisfiesPromises() throws {
let ctx = try configuredSSLContext()
let channel = EmbeddedChannel()
try channel.pipeline.add(handler: OpenSSLHandler(context: ctx)).wait()
try channel.pipeline.add(handler: OpenSSLClientHandler(context: ctx)).wait()
// Start by initiating the handshake.
try channel.connect(to: SocketAddress.unixDomainSocketAddress(path: "/tmp/doesntmatter")).wait()
@ -502,4 +519,52 @@ class OpenSSLIntegrationTest: XCTestCase {
XCTAssertTrue(closePromise.futureResult.fulfilled)
}
func testAddingTlsToActiveChannelStillHandshakes() throws {
let ctx = try configuredSSLContext()
let group = try MultiThreadedEventLoopGroup(numThreads: 1)
defer {
try! group.syncShutdownGracefully()
}
let recorderHandler: EventRecorderHandler<TLSUserEvent> = EventRecorderHandler()
let channelActiveWaiter = ChannelActiveWaiter(promise: group.next().newPromise())
let serverChannel = try serverTLSChannel(withContext: ctx,
andHandlers: [recorderHandler, SimpleEchoServer(), channelActiveWaiter],
onGroup: group)
defer {
_ = try! serverChannel.close().wait()
}
// Create a client channel without TLS in it, and connect it.
let readPromise: Promise<ByteBuffer> = group.next().newPromise()
let promiseOnReadHandler = PromiseOnReadHandler(promise: readPromise)
let clientChannel = try ClientBootstrap(group: group)
.handler(handler: promiseOnReadHandler)
.connect(to: serverChannel.localAddress!).wait()
defer {
_ = try! clientChannel.close().wait()
}
// Wait until the channel comes up, then confirm that no handshake has been
// received. This hardly proves much, but it's enough.
try channelActiveWaiter.waitForChannelActive()
try group.next().submit {
XCTAssertEqual(recorderHandler.events, [.Registered, .Active])
}.wait()
// Now, add the TLS handler to the pipeline.
try clientChannel.pipeline.add(name: nil, handler: OpenSSLClientHandler(context: ctx), first: true).wait()
var data = clientChannel.allocator.buffer(capacity: 1)
data.write(staticString: "x")
try clientChannel.writeAndFlush(data: IOData(data)).wait()
// The echo should come back without error.
_ = try readPromise.futureResult.wait()
// At this point the handshake should be complete.
try group.next().submit {
XCTAssertEqual(recorderHandler.events[..<3], [.Registered, .Active, .UserEvent(.handshakeCompleted)])
}.wait()
}
}