Split client/server processing in OpenSSL
This commit is contained in:
parent
e9c76088ac
commit
32c30074ea
|
@ -0,0 +1,30 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This source file is part of the SwiftNIO open source project
|
||||
//
|
||||
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
|
||||
// Licensed under Apache License v2.0
|
||||
//
|
||||
// See LICENSE.txt for license information
|
||||
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||
//
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
import NIO
|
||||
|
||||
/// A channel handler that wraps a channel in TLS using OpenSSL, or an
|
||||
/// OpenSSL-compatible library. This handler can be used in channels that
|
||||
/// are acting as the client in the TLS dialog. For server connections,
|
||||
/// use the OpenSSLServerHandler.
|
||||
public final class OpenSSLClientHandler: OpenSSLHandler {
|
||||
public init(context: SSLContext) throws {
|
||||
guard let connection = context.createConnection() else {
|
||||
throw NIOOpenSSLError.unableToAllocateOpenSSLObject
|
||||
}
|
||||
|
||||
connection.setConnectState()
|
||||
super.init(connection: connection)
|
||||
}
|
||||
}
|
|
@ -15,7 +15,15 @@
|
|||
import NIO
|
||||
import CNIOOpenSSL
|
||||
|
||||
public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandler {
|
||||
/// The base class for all OpenSSL handlers. This class cannot actually be instantiated by
|
||||
/// users directly: instead, users must select which mode they would like their handler to
|
||||
/// operate in, client or server.
|
||||
///
|
||||
/// This class exists to deal with the reality that for almost the entirety of the lifetime
|
||||
/// of a TLS connection there is no meaningful distinction between a server and a client.
|
||||
/// For this reason almost the entirety of the implementation for the channel and server
|
||||
/// handlers in OpenSSL is shared, in the form of this parent class.
|
||||
public class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandler {
|
||||
public typealias OutboundIn = ByteBuffer
|
||||
public typealias OutboundOut = ByteBuffer
|
||||
public typealias InboundIn = ByteBuffer
|
||||
|
@ -29,49 +37,26 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle
|
|||
case closing
|
||||
case closed
|
||||
}
|
||||
|
||||
private let context: SSLContext
|
||||
|
||||
private var state: ConnectionState = .idle
|
||||
private var connection: SSLConnection? = nil
|
||||
private var connection: SSLConnection
|
||||
private var bufferedWrites: MarkedCircularBuffer<BufferedEvent>
|
||||
private var closePromise: Promise<Void>?
|
||||
private var didDeliverData: Bool = false
|
||||
|
||||
public init (context: SSLContext) {
|
||||
self.context = context
|
||||
internal init (connection: SSLConnection) {
|
||||
self.connection = connection
|
||||
self.bufferedWrites = MarkedCircularBuffer(initialRingCapacity: 96) // 96 brings the total size of the buffer to just shy of one page
|
||||
}
|
||||
|
||||
public func connect(ctx: ChannelHandlerContext, to address: SocketAddress, promise: Promise<Void>?) {
|
||||
// This fires when we're asked to connect to a server. Necessarily if we're doing that then we're a
|
||||
// client, and so we should set up our underlying OpenSSL Connection object to be a client. We don't
|
||||
// bother starting the handshake now though: we have nowhere to write to!
|
||||
assert(connection == nil)
|
||||
self.connection = context.createConnection()
|
||||
guard let connection = self.connection else {
|
||||
promise?.fail(error: NIOOpenSSLError.unableToAllocateOpenSSLObject)
|
||||
return
|
||||
|
||||
public func handlerAdded(ctx: ChannelHandlerContext) {
|
||||
// If this channel is already active, immediately begin handshaking.
|
||||
if ctx.channel!.isActive {
|
||||
doHandshakeStep(ctx: ctx)
|
||||
}
|
||||
|
||||
connection.setConnectState()
|
||||
ctx.connect(to: address, promise: promise)
|
||||
}
|
||||
|
||||
|
||||
public func channelActive(ctx: ChannelHandlerContext) {
|
||||
// This fires when the TCP connection is established. If we don't have a Connection object yet
|
||||
// that means we're a server, so we should create it now and start the handshake.
|
||||
// If we already have a connection we are a client, so we can just start handshaking.
|
||||
if connection == nil {
|
||||
self.connection = context.createConnection()
|
||||
guard let connection = self.connection else {
|
||||
ctx.fireErrorCaught(error: NIOOpenSSLError.unableToAllocateOpenSSLObject)
|
||||
ctx.close(promise: nil)
|
||||
return
|
||||
}
|
||||
|
||||
connection.setAcceptState()
|
||||
}
|
||||
|
||||
// We fire this a bit early, entirely on purpose. This is because
|
||||
// in doHandshakeStep we may end up closing the channel again, and
|
||||
// if we do we want to make sure that the channelInactive message received
|
||||
|
@ -102,7 +87,7 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle
|
|||
var binaryData = unwrapInboundIn(data)
|
||||
|
||||
// The logic: feed the buffers, then take an action based on state.
|
||||
connection!.consumeDataFromNetwork(&binaryData)
|
||||
connection.consumeDataFromNetwork(&binaryData)
|
||||
|
||||
switch state {
|
||||
case .handshaking:
|
||||
|
@ -167,7 +152,7 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle
|
|||
}
|
||||
|
||||
private func doHandshakeStep(ctx: ChannelHandlerContext) {
|
||||
let result = connection!.doHandshake()
|
||||
let result = connection.doHandshake()
|
||||
|
||||
switch result {
|
||||
case .incomplete:
|
||||
|
@ -193,7 +178,7 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle
|
|||
}
|
||||
|
||||
private func doShutdownStep(ctx: ChannelHandlerContext) {
|
||||
let result = connection!.doShutdown()
|
||||
let result = connection.doShutdown()
|
||||
|
||||
switch result {
|
||||
case .incomplete:
|
||||
|
@ -215,7 +200,7 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle
|
|||
|
||||
private func doDecodeData(ctx: ChannelHandlerContext) {
|
||||
readLoop: while true {
|
||||
let result = connection!.readDataFromNetwork(allocator: ctx.channel!.allocator)
|
||||
let result = connection.readDataFromNetwork(allocator: ctx.channel!.allocator)
|
||||
|
||||
switch result {
|
||||
case .complete(let buf):
|
||||
|
@ -239,7 +224,7 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle
|
|||
|
||||
private func writeDataToNetwork(ctx: ChannelHandlerContext, promise: Promise<Void>?) {
|
||||
// There may be no data to write, in which case we can just exit early.
|
||||
guard let dataToWrite = connection!.getDataForNetwork(allocator: ctx.channel!.allocator) else {
|
||||
guard let dataToWrite = connection.getDataForNetwork(allocator: ctx.channel!.allocator) else {
|
||||
promise?.succeed(result: ())
|
||||
return
|
||||
}
|
||||
|
@ -305,7 +290,7 @@ extension OpenSSLHandler {
|
|||
|
||||
/// Given a byte buffer to encode, passes it to OpenSSL and handles the result.
|
||||
func encodeWrite(buf: inout ByteBuffer, promise: Promise<Void>?) throws -> Bool {
|
||||
let result = connection!.writeDataToNetwork(&buf)
|
||||
let result = connection.writeDataToNetwork(&buf)
|
||||
|
||||
switch result {
|
||||
case .complete:
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This source file is part of the SwiftNIO open source project
|
||||
//
|
||||
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
|
||||
// Licensed under Apache License v2.0
|
||||
//
|
||||
// See LICENSE.txt for license information
|
||||
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
|
||||
//
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
import NIO
|
||||
|
||||
/// A channel handler that wraps a channel in TLS using OpenSSL, or an
|
||||
/// OpenSSL-compatible library. This handler can be used in channels that
|
||||
/// are acting as the server in the TLS dialog. For client connections,
|
||||
/// use the OpenSSLClientHandler.
|
||||
public final class OpenSSLServerHandler: OpenSSLHandler {
|
||||
public init(context: SSLContext) throws {
|
||||
guard let connection = context.createConnection() else {
|
||||
throw NIOOpenSSLError.unableToAllocateOpenSSLObject
|
||||
}
|
||||
|
||||
connection.setAcceptState()
|
||||
super.init(connection: connection)
|
||||
}
|
||||
}
|
|
@ -40,7 +40,7 @@ let bootstrap = ServerBootstrap(group: group)
|
|||
|
||||
// Set the handlers that are applied to the accepted channels.
|
||||
.handler(childHandler: ChannelInitializer(initChannel: { channel in
|
||||
return channel.pipeline.add(handler: OpenSSLHandler(context: sslContext)).then(callback: { v2 in
|
||||
return channel.pipeline.add(handler: try! OpenSSLServerHandler(context: sslContext)).then(callback: { v2 in
|
||||
return channel.pipeline.add(handler: EchoHandler())
|
||||
})
|
||||
}))
|
||||
|
|
|
@ -33,6 +33,7 @@ extension OpenSSLIntegrationTest {
|
|||
("testCoalescedWrites", testCoalescedWrites),
|
||||
("testCoalescedWritesWithFutures", testCoalescedWritesWithFutures),
|
||||
("testImmediateCloseSatisfiesPromises", testImmediateCloseSatisfiesPromises),
|
||||
("testAddingTlsToActiveChannelStillHandshakes", testAddingTlsToActiveChannelStillHandshakes),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
|
@ -150,11 +150,28 @@ public final class EventRecorderHandler<UserEventType>: ChannelInboundHandler wh
|
|||
}
|
||||
}
|
||||
|
||||
private class ChannelActiveWaiter: ChannelInboundHandler {
|
||||
public typealias InboundIn = Any
|
||||
private var activePromise: Promise<Void>
|
||||
|
||||
public init(promise: Promise<Void>) {
|
||||
activePromise = promise
|
||||
}
|
||||
|
||||
public func channelActive(ctx: ChannelHandlerContext) {
|
||||
activePromise.succeed(result: ())
|
||||
}
|
||||
|
||||
public func waitForChannelActive() throws {
|
||||
try activePromise.futureResult.wait()
|
||||
}
|
||||
}
|
||||
|
||||
internal func serverTLSChannel(withContext: NIOOpenSSL.SSLContext, andHandlers: [ChannelHandler], onGroup: EventLoopGroup) throws -> Channel {
|
||||
return try ServerBootstrap(group: onGroup)
|
||||
.option(option: ChannelOptions.Socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
|
||||
.handler(childHandler: ChannelInitializer(initChannel: { channel in
|
||||
return channel.pipeline.add(handler: OpenSSLHandler(context: withContext)).then(callback: { v2 in
|
||||
return channel.pipeline.add(handler: try! OpenSSLServerHandler(context: withContext)).then(callback: { v2 in
|
||||
let results = andHandlers.map { channel.pipeline.add(handler: $0) }
|
||||
|
||||
// NB: This assumes that the futures will always fire in order. This is not necessarily guaranteed
|
||||
|
@ -175,7 +192,7 @@ internal func clientTLSChannel(withContext: NIOOpenSSL.SSLContext,
|
|||
let results = preHandlers.map { channel.pipeline.add(handler: $0) }
|
||||
|
||||
return (results.last ?? channel.eventLoop.newSucceedFuture(result: ())).then(callback: { v2 in
|
||||
return channel.pipeline.add(handler: OpenSSLHandler(context: withContext)).then(callback: { v2 in
|
||||
return channel.pipeline.add(handler: try! OpenSSLClientHandler(context: withContext)).then(callback: { v2 in
|
||||
let results = postHandlers.map { channel.pipeline.add(handler: $0) }
|
||||
|
||||
// NB: This assumes that the futures will always fire in order. This is not necessarily guaranteed
|
||||
|
@ -491,7 +508,7 @@ class OpenSSLIntegrationTest: XCTestCase {
|
|||
func testImmediateCloseSatisfiesPromises() throws {
|
||||
let ctx = try configuredSSLContext()
|
||||
let channel = EmbeddedChannel()
|
||||
try channel.pipeline.add(handler: OpenSSLHandler(context: ctx)).wait()
|
||||
try channel.pipeline.add(handler: OpenSSLClientHandler(context: ctx)).wait()
|
||||
|
||||
// Start by initiating the handshake.
|
||||
try channel.connect(to: SocketAddress.unixDomainSocketAddress(path: "/tmp/doesntmatter")).wait()
|
||||
|
@ -502,4 +519,52 @@ class OpenSSLIntegrationTest: XCTestCase {
|
|||
|
||||
XCTAssertTrue(closePromise.futureResult.fulfilled)
|
||||
}
|
||||
|
||||
func testAddingTlsToActiveChannelStillHandshakes() throws {
|
||||
let ctx = try configuredSSLContext()
|
||||
let group = try MultiThreadedEventLoopGroup(numThreads: 1)
|
||||
defer {
|
||||
try! group.syncShutdownGracefully()
|
||||
}
|
||||
|
||||
let recorderHandler: EventRecorderHandler<TLSUserEvent> = EventRecorderHandler()
|
||||
let channelActiveWaiter = ChannelActiveWaiter(promise: group.next().newPromise())
|
||||
let serverChannel = try serverTLSChannel(withContext: ctx,
|
||||
andHandlers: [recorderHandler, SimpleEchoServer(), channelActiveWaiter],
|
||||
onGroup: group)
|
||||
defer {
|
||||
_ = try! serverChannel.close().wait()
|
||||
}
|
||||
|
||||
// Create a client channel without TLS in it, and connect it.
|
||||
let readPromise: Promise<ByteBuffer> = group.next().newPromise()
|
||||
let promiseOnReadHandler = PromiseOnReadHandler(promise: readPromise)
|
||||
let clientChannel = try ClientBootstrap(group: group)
|
||||
.handler(handler: promiseOnReadHandler)
|
||||
.connect(to: serverChannel.localAddress!).wait()
|
||||
defer {
|
||||
_ = try! clientChannel.close().wait()
|
||||
}
|
||||
|
||||
// Wait until the channel comes up, then confirm that no handshake has been
|
||||
// received. This hardly proves much, but it's enough.
|
||||
try channelActiveWaiter.waitForChannelActive()
|
||||
try group.next().submit {
|
||||
XCTAssertEqual(recorderHandler.events, [.Registered, .Active])
|
||||
}.wait()
|
||||
|
||||
// Now, add the TLS handler to the pipeline.
|
||||
try clientChannel.pipeline.add(name: nil, handler: OpenSSLClientHandler(context: ctx), first: true).wait()
|
||||
var data = clientChannel.allocator.buffer(capacity: 1)
|
||||
data.write(staticString: "x")
|
||||
try clientChannel.writeAndFlush(data: IOData(data)).wait()
|
||||
|
||||
// The echo should come back without error.
|
||||
_ = try readPromise.futureResult.wait()
|
||||
|
||||
// At this point the handshake should be complete.
|
||||
try group.next().submit {
|
||||
XCTAssertEqual(recorderHandler.events[..<3], [.Registered, .Active, .UserEvent(.handshakeCompleted)])
|
||||
}.wait()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue