diff --git a/Sources/NIOOpenSSL/OpenSSLHandler.swift b/Sources/NIOOpenSSL/OpenSSLHandler.swift index e95a546f..cfaf53b5 100644 --- a/Sources/NIOOpenSSL/OpenSSLHandler.swift +++ b/Sources/NIOOpenSSL/OpenSSLHandler.swift @@ -21,9 +21,7 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle public typealias InboundIn = ByteBuffer public typealias InboundOut = ByteBuffer public typealias InboundUserEventOut = TLSUserEvent - - private typealias BufferedWrite = (data: ByteBuffer, promise: Promise?) - + private enum ConnectionState { case idle case handshaking @@ -35,12 +33,13 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle private let context: SSLContext private var state: ConnectionState = .idle private var connection: SSLConnection? = nil - private var bufferedWrites: [BufferedWrite] = [] + private var bufferedWrites: MarkedCircularBuffer private var closePromise: Promise? private var didDeliverData: Bool = false public init (context: SSLContext) { self.context = context + self.bufferedWrites = MarkedCircularBuffer(initialRingCapacity: 96) // 96 brings the total size of the buffer to just shy of one page } public func connect(ctx: ChannelHandlerContext, to address: SocketAddress, promise: Promise?) { @@ -136,8 +135,12 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle } public func write(ctx: ChannelHandlerContext, data: IOData, promise: Promise?) { - var binaryData = unwrapOutboundIn(data) - doEncodeData(data: &binaryData, ctx: ctx, promise: promise) + bufferWrite(data: unwrapOutboundIn(data), promise: promise) + } + + public func flush(ctx: ChannelHandlerContext, promise: Promise?) { + bufferFlush(promise: promise) + doUnbufferWrites(ctx: ctx) } public func close(ctx: ChannelHandlerContext, promise: Promise?) { @@ -234,82 +237,13 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle } } - private func doEncodeData(data: inout ByteBuffer, ctx: ChannelHandlerContext, promise: Promise?) { - if state == .closing || state == .closed { - // We're either shutting down or shut down: no further encoded data is allowed. - promise?.fail(error: NIOOpenSSLError.writeDuringTLSShutdown) - return - } - - let result = connection!.writeDataToNetwork(&data) - switch result { - case .complete: - writeDataToNetwork(ctx: ctx, promise: promise) - case .incomplete: - // We need to buffer this write and retry it. - bufferedWrites.append((data: data, promise: promise)) - case .failed(let err): - // TODO(cory): This is too aggressive. - channelClose(ctx: ctx) - promise?.fail(error: err) - } - } - - private func doUnbufferWrites(ctx: ChannelHandlerContext) { - // Early exit if there are no buffered writes. - if bufferedWrites.count == 0 { - return - } - - var originalError: OpenSSLError? = nil - var newBuffer: [BufferedWrite] = [] - - for bufferedWrite in bufferedWrites { - let promise = bufferedWrite.promise - var data = bufferedWrite.data - if let err = originalError { - promise?.fail(error: err) - continue - } else if newBuffer.count > 0 { - newBuffer.append((data: data, promise: promise)) - continue - } - - let result = connection!.writeDataToNetwork(&data) - - switch result { - case .complete: - writeDataToNetwork(ctx: ctx, promise: promise) - case .incomplete: - // We need to start a new buffer. At this point, all further writes - // must be buffered. - newBuffer.append((data: data, promise: promise)) - case .failed(let err): - // Once a write fails, all subsequent writes must fail. - channelClose(ctx: ctx) - promise?.fail(error: err) - originalError = err - } - } - - bufferedWrites = newBuffer - } - - private func discardBufferedWrites(reason: Error) { - for (_, promise) in bufferedWrites { - promise?.fail(error:reason) - } - - bufferedWrites = [] - } - private func writeDataToNetwork(ctx: ChannelHandlerContext, promise: Promise?) { // There may be no data to write, in which case we can just exit early. guard let dataToWrite = connection!.getDataForNetwork(allocator: ctx.channel!.allocator) else { - assert(promise == nil, "Promise present for nonexistent write.") + promise?.succeed(result: ()) return } - + ctx.writeAndFlush(data: self.wrapInboundOut(dataToWrite), promise: promise) } @@ -325,3 +259,116 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle ctx.close(promise: closePromise) } } + + +// MARK: Code that handles buffering/unbuffering writes. +extension OpenSSLHandler { + private enum BufferedEvent { + case write(BufferedWrite) + case flush(Promise?) + } + private typealias BufferedWrite = (data: ByteBuffer, promise: Promise?) + + private func bufferWrite(data: ByteBuffer, promise: Promise?) { + bufferedWrites.append(.write((data: data, promise: promise))) + } + + private func bufferFlush(promise: Promise?) { + bufferedWrites.append(.flush(promise)) + bufferedWrites.mark() + } + + private func discardBufferedWrites(reason: Error) { + while bufferedWrites.count > 0 { + let promise: Promise? + switch bufferedWrites.removeFirst() { + case .write(_, let p): + promise = p + case .flush(let p): + promise = p + } + + promise?.fail(error: reason) + } + } + + private func doUnbufferWrites(ctx: ChannelHandlerContext) { + // Return early if the user hasn't called flush. + guard bufferedWrites.hasMark() else { + return + } + + // These are some annoying variables we use to persist state across invocations of + // our closures. A better version of this code might be able to simplify this somewhat. + var writeCount = 0 + var promises: [Promise] = [] + + /// Given a byte buffer to encode, passes it to OpenSSL and handles the result. + func encodeWrite(buf: inout ByteBuffer, promise: Promise?) throws -> Bool { + let result = connection!.writeDataToNetwork(&buf) + + switch result { + case .complete: + if let promise = promise { promises.append(promise) } + writeCount += 1 + return true + case .incomplete: + // Ok, we can't write. Let's stop. + // We believe this can only ever happen on the first attempt to write. + precondition(writeCount == 0, "Unexpected change in OpenSSL state during write unbuffering: write count \(writeCount)") + return false + case .failed(let err): + // Once a write fails, all writes must fail. This includes prior writes + // that successfully made it through OpenSSL. + throw err + } + } + + /// Given a flush request, grabs the data from OpenSSL and flushes it to the network. + func flushData(userFlushPromise: Promise?) throws -> Bool { + // This is a flush. We can go ahead and flush now. + if let promise = userFlushPromise { promises.append(promise) } + let ourPromise: Promise = ctx.eventLoop.newPromise() + promises.forEach { ourPromise.futureResult.cascade(promise: $0) } + writeDataToNetwork(ctx: ctx, promise: ourPromise) + return true + } + + do { + try bufferedWrites.forEachElementUntilMark { element in + switch element { + case .write(var d, let p): + return try encodeWrite(buf: &d, promise: p) + case .flush(let p): + return try flushData(userFlushPromise: p) + } + } + } catch { + // We encountered an error, it's cleanup time. Close ourselves down. + channelClose(ctx: ctx) + // Fail any writes we've previously encoded but not flushed. + promises.forEach { $0.fail(error: error) } + // Fail everything else. + bufferedWrites.forEachRemoving { + switch $0 { + case .write(_, let p), .flush(let p): + p?.fail(error: error) + } + } + } + } +} + +fileprivate extension MarkedCircularBuffer { + fileprivate mutating func forEachElementUntilMark(callback: (E) throws -> Bool) rethrows { + while try self.hasMark() && callback(self.first!) { + _ = self.removeFirst() + } + } + + fileprivate mutating func forEachRemoving(callback: (E) -> Void) { + while self.count > 0 { + callback(self.removeFirst()) + } + } +} diff --git a/Tests/NIOOpenSSLTests/OpenSSLIntegrationTest+XCTest.swift b/Tests/NIOOpenSSLTests/OpenSSLIntegrationTest+XCTest.swift index 22ed86b2..32a292f8 100644 --- a/Tests/NIOOpenSSLTests/OpenSSLIntegrationTest+XCTest.swift +++ b/Tests/NIOOpenSSLTests/OpenSSLIntegrationTest+XCTest.swift @@ -30,6 +30,8 @@ extension OpenSSLIntegrationTest { ("testHandshakeEventSequencing", testHandshakeEventSequencing), ("testShutdownEventSequencing", testShutdownEventSequencing), ("testMultipleClose", testMultipleClose), + ("testCoalescedWrites", testCoalescedWrites), + ("testCoalescedWritesWithFutures", testCoalescedWritesWithFutures), ] } } diff --git a/Tests/NIOOpenSSLTests/OpenSSLIntegrationTest.swift b/Tests/NIOOpenSSLTests/OpenSSLIntegrationTest.swift index e93d60dc..b39beeb0 100644 --- a/Tests/NIOOpenSSLTests/OpenSSLIntegrationTest.swift +++ b/Tests/NIOOpenSSLTests/OpenSSLIntegrationTest.swift @@ -25,10 +25,12 @@ private final class SimpleEchoServer: ChannelInboundHandler { public func channelRead(ctx: ChannelHandlerContext, data: IOData) { ctx.write(data: data, promise: nil) + ctx.fireChannelRead(data: data) } public func channelReadComplete(ctx: ChannelHandlerContext) { ctx.flush(promise: nil) + ctx.fireChannelReadComplete() } } @@ -46,10 +48,24 @@ private final class PromiseOnReadHandler: ChannelInboundHandler { public func channelRead(ctx: ChannelHandlerContext, data: IOData) { self.data = data + ctx.fireChannelRead(data: data) } public func channelReadComplete(ctx: ChannelHandlerContext) { promise.succeed(result: unwrapInboundIn(data!)) + ctx.fireChannelReadComplete() + } +} + +private final class WriteCountingHandler: ChannelOutboundHandler { + public typealias OutboundIn = Any + public typealias OutboundOut = Any + + public var writeCount = 0 + + public func write(ctx: ChannelHandlerContext, data: IOData, promise: Promise?) { + writeCount += 1 + ctx.write(data: data, promise: promise) } } @@ -150,13 +166,23 @@ private func serverTLSChannel(withContext: NIOOpenSSL.SSLContext, andHandlers: [ } private func clientTLSChannel(withContext: NIOOpenSSL.SSLContext, - andHandler: ChannelHandler, + preHandlers: [ChannelHandler], + postHandlers: [ChannelHandler], onGroup: EventLoopGroup, connectingTo: SocketAddress) throws -> Channel { return try ClientBootstrap(group: onGroup) .handler(handler: ChannelInitializer(initChannel: { channel in - return channel.pipeline.add(handler: OpenSSLHandler(context: withContext)).then(callback: { v2 in - return channel.pipeline.add(handler: andHandler) + let results = preHandlers.map { channel.pipeline.add(handler: $0) } + + return (results.last ?? channel.eventLoop.newSucceedFuture(result: ())).then(callback: { v2 in + return channel.pipeline.add(handler: OpenSSLHandler(context: withContext)).then(callback: { v2 in + let results = postHandlers.map { channel.pipeline.add(handler: $0) } + + // NB: This assumes that the futures will always fire in order. This is not necessarily guaranteed + // but in the absence of a way to gather a complete set of Future results, there is no other + // option. + return results.last ?? channel.eventLoop.newSucceedFuture(result: ()) + }) }) })).connect(to: connectingTo).wait() } @@ -196,7 +222,8 @@ class OpenSSLIntegrationTest: XCTestCase { } let clientChannel = try clientTLSChannel(withContext: ctx, - andHandler: PromiseOnReadHandler(promise: completionPromise), + preHandlers: [], + postHandlers: [PromiseOnReadHandler(promise: completionPromise)], onGroup: group, connectingTo: serverChannel.localAddress!) defer { @@ -229,7 +256,8 @@ class OpenSSLIntegrationTest: XCTestCase { } let clientChannel = try clientTLSChannel(withContext: ctx, - andHandler: SimpleEchoServer(), + preHandlers: [], + postHandlers: [SimpleEchoServer()], onGroup: group, connectingTo: serverChannel.localAddress!) defer { @@ -272,7 +300,8 @@ class OpenSSLIntegrationTest: XCTestCase { onGroup: group) let clientChannel = try clientTLSChannel(withContext: ctx, - andHandler: SimpleEchoServer(), + preHandlers: [], + postHandlers: [SimpleEchoServer()], onGroup: group, connectingTo: serverChannel.localAddress!) @@ -320,7 +349,8 @@ class OpenSSLIntegrationTest: XCTestCase { } let clientChannel = try clientTLSChannel(withContext: ctx, - andHandler: PromiseOnReadHandler(promise: completionPromise), + preHandlers: [], + postHandlers: [PromiseOnReadHandler(promise: completionPromise)], onGroup: group, connectingTo: serverChannel.localAddress!) defer { @@ -355,4 +385,106 @@ class OpenSSLIntegrationTest: XCTestCase { XCTAssert(promise.futureResult.fulfilled) } } + + func testCoalescedWrites() throws { + let ctx = try configuredSSLContext() + + let group = try MultiThreadedEventLoopGroup(numThreads: 1) + defer { + try! group.syncShutdownGracefully() + } + + let recorderHandler = EventRecorderHandler() + let serverChannel = try serverTLSChannel(withContext: ctx, andHandlers: [SimpleEchoServer()], onGroup: group) + defer { + _ = try! serverChannel.close().wait() + } + + let writeCounter = WriteCountingHandler() + let readPromise: Promise = group.next().newPromise() + let clientChannel = try clientTLSChannel(withContext: ctx, + preHandlers: [writeCounter], + postHandlers: [PromiseOnReadHandler(promise: readPromise)], + onGroup: group, + connectingTo: serverChannel.localAddress!) + defer { + _ = try! clientChannel.close().wait() + } + + // We're going to issue a number of small writes. Each of these should be coalesced together + // such that the underlying layer sees only one write for them. The total number of + // writes should be (after we flush) 3: one for Client Hello, one for Finished, and one + // for the coalesced writes. + var originalBuffer = clientChannel.allocator.buffer(capacity: 1) + originalBuffer.write(string: "A") + for _ in 0..<5 { + clientChannel.write(data: IOData(originalBuffer), promise: nil) + } + + try clientChannel.flush().wait() + let writeCount = try readPromise.futureResult.then { _ in + // Here we're in the I/O loop, so we know that no further channel action will happen + // while we dispatch this callback. This is the perfect time to check how many writes + // happened. + return writeCounter.writeCount + }.wait() + XCTAssertEqual(writeCount, 3) + } + + func testCoalescedWritesWithFutures() throws { + let ctx = try configuredSSLContext() + + let group = try MultiThreadedEventLoopGroup(numThreads: 1) + defer { + try! group.syncShutdownGracefully() + } + + let recorderHandler = EventRecorderHandler() + let serverChannel = try serverTLSChannel(withContext: ctx, andHandlers: [SimpleEchoServer()], onGroup: group) + defer { + _ = try! serverChannel.close().wait() + } + + let clientChannel = try clientTLSChannel(withContext: ctx, + preHandlers: [], + postHandlers: [], + onGroup: group, + connectingTo: serverChannel.localAddress!) + defer { + _ = try! clientChannel.close().wait() + } + + // We're going to issue a number of small writes. Each of these should be coalesced together + // and all their futures (along with the one for the flush) should fire, in order, with nothing + // missed. + var firedFutures: Array = [] + var originalBuffer = clientChannel.allocator.buffer(capacity: 1) + originalBuffer.write(string: "A") + for index in 0..<5 { + let promise: Promise = group.next().newPromise() + promise.futureResult.whenComplete { result in + switch result { + case .success: + XCTAssertEqual(firedFutures.count, index) + firedFutures.append(index) + case .failure: + XCTFail("Write promise failed: \(result)") + } + } + clientChannel.write(data: IOData(originalBuffer), promise: promise) + } + + let flushPromise: Promise = group.next().newPromise() + flushPromise.futureResult.whenComplete { result in + switch result { + case .success: + XCTAssertEqual(firedFutures, [0, 1, 2, 3, 4]) + case .failure: + XCTFail("Write promised failed: \(result)") + } + } + clientChannel.flush(promise: flushPromise) + + try flushPromise.futureResult.wait() + } }