[OpenSSLHandler] Flush less frequently for outgoing application data.

This commit is contained in:
Cory Benfield 2017-10-04 17:29:38 +01:00
parent 55ae8a8606
commit c2f461e4b7
3 changed files with 265 additions and 84 deletions

View File

@ -21,9 +21,7 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle
public typealias InboundIn = ByteBuffer
public typealias InboundOut = ByteBuffer
public typealias InboundUserEventOut = TLSUserEvent
private typealias BufferedWrite = (data: ByteBuffer, promise: Promise<Void>?)
private enum ConnectionState {
case idle
case handshaking
@ -35,12 +33,13 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle
private let context: SSLContext
private var state: ConnectionState = .idle
private var connection: SSLConnection? = nil
private var bufferedWrites: [BufferedWrite] = []
private var bufferedWrites: MarkedCircularBuffer<BufferedEvent>
private var closePromise: Promise<Void>?
private var didDeliverData: Bool = false
public init (context: SSLContext) {
self.context = context
self.bufferedWrites = MarkedCircularBuffer(initialRingCapacity: 96) // 96 brings the total size of the buffer to just shy of one page
}
public func connect(ctx: ChannelHandlerContext, to address: SocketAddress, promise: Promise<Void>?) {
@ -136,8 +135,12 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle
}
public func write(ctx: ChannelHandlerContext, data: IOData, promise: Promise<Void>?) {
var binaryData = unwrapOutboundIn(data)
doEncodeData(data: &binaryData, ctx: ctx, promise: promise)
bufferWrite(data: unwrapOutboundIn(data), promise: promise)
}
public func flush(ctx: ChannelHandlerContext, promise: Promise<Void>?) {
bufferFlush(promise: promise)
doUnbufferWrites(ctx: ctx)
}
public func close(ctx: ChannelHandlerContext, promise: Promise<Void>?) {
@ -234,82 +237,13 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle
}
}
private func doEncodeData(data: inout ByteBuffer, ctx: ChannelHandlerContext, promise: Promise<Void>?) {
if state == .closing || state == .closed {
// We're either shutting down or shut down: no further encoded data is allowed.
promise?.fail(error: NIOOpenSSLError.writeDuringTLSShutdown)
return
}
let result = connection!.writeDataToNetwork(&data)
switch result {
case .complete:
writeDataToNetwork(ctx: ctx, promise: promise)
case .incomplete:
// We need to buffer this write and retry it.
bufferedWrites.append((data: data, promise: promise))
case .failed(let err):
// TODO(cory): This is too aggressive.
channelClose(ctx: ctx)
promise?.fail(error: err)
}
}
private func doUnbufferWrites(ctx: ChannelHandlerContext) {
// Early exit if there are no buffered writes.
if bufferedWrites.count == 0 {
return
}
var originalError: OpenSSLError? = nil
var newBuffer: [BufferedWrite] = []
for bufferedWrite in bufferedWrites {
let promise = bufferedWrite.promise
var data = bufferedWrite.data
if let err = originalError {
promise?.fail(error: err)
continue
} else if newBuffer.count > 0 {
newBuffer.append((data: data, promise: promise))
continue
}
let result = connection!.writeDataToNetwork(&data)
switch result {
case .complete:
writeDataToNetwork(ctx: ctx, promise: promise)
case .incomplete:
// We need to start a new buffer. At this point, all further writes
// must be buffered.
newBuffer.append((data: data, promise: promise))
case .failed(let err):
// Once a write fails, all subsequent writes must fail.
channelClose(ctx: ctx)
promise?.fail(error: err)
originalError = err
}
}
bufferedWrites = newBuffer
}
private func discardBufferedWrites(reason: Error) {
for (_, promise) in bufferedWrites {
promise?.fail(error:reason)
}
bufferedWrites = []
}
private func writeDataToNetwork(ctx: ChannelHandlerContext, promise: Promise<Void>?) {
// There may be no data to write, in which case we can just exit early.
guard let dataToWrite = connection!.getDataForNetwork(allocator: ctx.channel!.allocator) else {
assert(promise == nil, "Promise present for nonexistent write.")
promise?.succeed(result: ())
return
}
ctx.writeAndFlush(data: self.wrapInboundOut(dataToWrite), promise: promise)
}
@ -325,3 +259,116 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle
ctx.close(promise: closePromise)
}
}
// MARK: Code that handles buffering/unbuffering writes.
extension OpenSSLHandler {
private enum BufferedEvent {
case write(BufferedWrite)
case flush(Promise<Void>?)
}
private typealias BufferedWrite = (data: ByteBuffer, promise: Promise<Void>?)
private func bufferWrite(data: ByteBuffer, promise: Promise<Void>?) {
bufferedWrites.append(.write((data: data, promise: promise)))
}
private func bufferFlush(promise: Promise<Void>?) {
bufferedWrites.append(.flush(promise))
bufferedWrites.mark()
}
private func discardBufferedWrites(reason: Error) {
while bufferedWrites.count > 0 {
let promise: Promise<Void>?
switch bufferedWrites.removeFirst() {
case .write(_, let p):
promise = p
case .flush(let p):
promise = p
}
promise?.fail(error: reason)
}
}
private func doUnbufferWrites(ctx: ChannelHandlerContext) {
// Return early if the user hasn't called flush.
guard bufferedWrites.hasMark() else {
return
}
// These are some annoying variables we use to persist state across invocations of
// our closures. A better version of this code might be able to simplify this somewhat.
var writeCount = 0
var promises: [Promise<Void>] = []
/// Given a byte buffer to encode, passes it to OpenSSL and handles the result.
func encodeWrite(buf: inout ByteBuffer, promise: Promise<Void>?) throws -> Bool {
let result = connection!.writeDataToNetwork(&buf)
switch result {
case .complete:
if let promise = promise { promises.append(promise) }
writeCount += 1
return true
case .incomplete:
// Ok, we can't write. Let's stop.
// We believe this can only ever happen on the first attempt to write.
precondition(writeCount == 0, "Unexpected change in OpenSSL state during write unbuffering: write count \(writeCount)")
return false
case .failed(let err):
// Once a write fails, all writes must fail. This includes prior writes
// that successfully made it through OpenSSL.
throw err
}
}
/// Given a flush request, grabs the data from OpenSSL and flushes it to the network.
func flushData(userFlushPromise: Promise<Void>?) throws -> Bool {
// This is a flush. We can go ahead and flush now.
if let promise = userFlushPromise { promises.append(promise) }
let ourPromise: Promise<Void> = ctx.eventLoop.newPromise()
promises.forEach { ourPromise.futureResult.cascade(promise: $0) }
writeDataToNetwork(ctx: ctx, promise: ourPromise)
return true
}
do {
try bufferedWrites.forEachElementUntilMark { element in
switch element {
case .write(var d, let p):
return try encodeWrite(buf: &d, promise: p)
case .flush(let p):
return try flushData(userFlushPromise: p)
}
}
} catch {
// We encountered an error, it's cleanup time. Close ourselves down.
channelClose(ctx: ctx)
// Fail any writes we've previously encoded but not flushed.
promises.forEach { $0.fail(error: error) }
// Fail everything else.
bufferedWrites.forEachRemoving {
switch $0 {
case .write(_, let p), .flush(let p):
p?.fail(error: error)
}
}
}
}
}
fileprivate extension MarkedCircularBuffer {
fileprivate mutating func forEachElementUntilMark(callback: (E) throws -> Bool) rethrows {
while try self.hasMark() && callback(self.first!) {
_ = self.removeFirst()
}
}
fileprivate mutating func forEachRemoving(callback: (E) -> Void) {
while self.count > 0 {
callback(self.removeFirst())
}
}
}

View File

@ -30,6 +30,8 @@ extension OpenSSLIntegrationTest {
("testHandshakeEventSequencing", testHandshakeEventSequencing),
("testShutdownEventSequencing", testShutdownEventSequencing),
("testMultipleClose", testMultipleClose),
("testCoalescedWrites", testCoalescedWrites),
("testCoalescedWritesWithFutures", testCoalescedWritesWithFutures),
]
}
}

View File

@ -25,10 +25,12 @@ private final class SimpleEchoServer: ChannelInboundHandler {
public func channelRead(ctx: ChannelHandlerContext, data: IOData) {
ctx.write(data: data, promise: nil)
ctx.fireChannelRead(data: data)
}
public func channelReadComplete(ctx: ChannelHandlerContext) {
ctx.flush(promise: nil)
ctx.fireChannelReadComplete()
}
}
@ -46,10 +48,24 @@ private final class PromiseOnReadHandler: ChannelInboundHandler {
public func channelRead(ctx: ChannelHandlerContext, data: IOData) {
self.data = data
ctx.fireChannelRead(data: data)
}
public func channelReadComplete(ctx: ChannelHandlerContext) {
promise.succeed(result: unwrapInboundIn(data!))
ctx.fireChannelReadComplete()
}
}
private final class WriteCountingHandler: ChannelOutboundHandler {
public typealias OutboundIn = Any
public typealias OutboundOut = Any
public var writeCount = 0
public func write(ctx: ChannelHandlerContext, data: IOData, promise: Promise<Void>?) {
writeCount += 1
ctx.write(data: data, promise: promise)
}
}
@ -150,13 +166,23 @@ private func serverTLSChannel(withContext: NIOOpenSSL.SSLContext, andHandlers: [
}
private func clientTLSChannel(withContext: NIOOpenSSL.SSLContext,
andHandler: ChannelHandler,
preHandlers: [ChannelHandler],
postHandlers: [ChannelHandler],
onGroup: EventLoopGroup,
connectingTo: SocketAddress) throws -> Channel {
return try ClientBootstrap(group: onGroup)
.handler(handler: ChannelInitializer(initChannel: { channel in
return channel.pipeline.add(handler: OpenSSLHandler(context: withContext)).then(callback: { v2 in
return channel.pipeline.add(handler: andHandler)
let results = preHandlers.map { channel.pipeline.add(handler: $0) }
return (results.last ?? channel.eventLoop.newSucceedFuture(result: ())).then(callback: { v2 in
return channel.pipeline.add(handler: OpenSSLHandler(context: withContext)).then(callback: { v2 in
let results = postHandlers.map { channel.pipeline.add(handler: $0) }
// NB: This assumes that the futures will always fire in order. This is not necessarily guaranteed
// but in the absence of a way to gather a complete set of Future results, there is no other
// option.
return results.last ?? channel.eventLoop.newSucceedFuture(result: ())
})
})
})).connect(to: connectingTo).wait()
}
@ -196,7 +222,8 @@ class OpenSSLIntegrationTest: XCTestCase {
}
let clientChannel = try clientTLSChannel(withContext: ctx,
andHandler: PromiseOnReadHandler(promise: completionPromise),
preHandlers: [],
postHandlers: [PromiseOnReadHandler(promise: completionPromise)],
onGroup: group,
connectingTo: serverChannel.localAddress!)
defer {
@ -229,7 +256,8 @@ class OpenSSLIntegrationTest: XCTestCase {
}
let clientChannel = try clientTLSChannel(withContext: ctx,
andHandler: SimpleEchoServer(),
preHandlers: [],
postHandlers: [SimpleEchoServer()],
onGroup: group,
connectingTo: serverChannel.localAddress!)
defer {
@ -272,7 +300,8 @@ class OpenSSLIntegrationTest: XCTestCase {
onGroup: group)
let clientChannel = try clientTLSChannel(withContext: ctx,
andHandler: SimpleEchoServer(),
preHandlers: [],
postHandlers: [SimpleEchoServer()],
onGroup: group,
connectingTo: serverChannel.localAddress!)
@ -320,7 +349,8 @@ class OpenSSLIntegrationTest: XCTestCase {
}
let clientChannel = try clientTLSChannel(withContext: ctx,
andHandler: PromiseOnReadHandler(promise: completionPromise),
preHandlers: [],
postHandlers: [PromiseOnReadHandler(promise: completionPromise)],
onGroup: group,
connectingTo: serverChannel.localAddress!)
defer {
@ -355,4 +385,106 @@ class OpenSSLIntegrationTest: XCTestCase {
XCTAssert(promise.futureResult.fulfilled)
}
}
func testCoalescedWrites() throws {
let ctx = try configuredSSLContext()
let group = try MultiThreadedEventLoopGroup(numThreads: 1)
defer {
try! group.syncShutdownGracefully()
}
let recorderHandler = EventRecorderHandler<TLSUserEvent>()
let serverChannel = try serverTLSChannel(withContext: ctx, andHandlers: [SimpleEchoServer()], onGroup: group)
defer {
_ = try! serverChannel.close().wait()
}
let writeCounter = WriteCountingHandler()
let readPromise: Promise<ByteBuffer> = group.next().newPromise()
let clientChannel = try clientTLSChannel(withContext: ctx,
preHandlers: [writeCounter],
postHandlers: [PromiseOnReadHandler(promise: readPromise)],
onGroup: group,
connectingTo: serverChannel.localAddress!)
defer {
_ = try! clientChannel.close().wait()
}
// We're going to issue a number of small writes. Each of these should be coalesced together
// such that the underlying layer sees only one write for them. The total number of
// writes should be (after we flush) 3: one for Client Hello, one for Finished, and one
// for the coalesced writes.
var originalBuffer = clientChannel.allocator.buffer(capacity: 1)
originalBuffer.write(string: "A")
for _ in 0..<5 {
clientChannel.write(data: IOData(originalBuffer), promise: nil)
}
try clientChannel.flush().wait()
let writeCount = try readPromise.futureResult.then { _ in
// Here we're in the I/O loop, so we know that no further channel action will happen
// while we dispatch this callback. This is the perfect time to check how many writes
// happened.
return writeCounter.writeCount
}.wait()
XCTAssertEqual(writeCount, 3)
}
func testCoalescedWritesWithFutures() throws {
let ctx = try configuredSSLContext()
let group = try MultiThreadedEventLoopGroup(numThreads: 1)
defer {
try! group.syncShutdownGracefully()
}
let recorderHandler = EventRecorderHandler<TLSUserEvent>()
let serverChannel = try serverTLSChannel(withContext: ctx, andHandlers: [SimpleEchoServer()], onGroup: group)
defer {
_ = try! serverChannel.close().wait()
}
let clientChannel = try clientTLSChannel(withContext: ctx,
preHandlers: [],
postHandlers: [],
onGroup: group,
connectingTo: serverChannel.localAddress!)
defer {
_ = try! clientChannel.close().wait()
}
// We're going to issue a number of small writes. Each of these should be coalesced together
// and all their futures (along with the one for the flush) should fire, in order, with nothing
// missed.
var firedFutures: Array<Int> = []
var originalBuffer = clientChannel.allocator.buffer(capacity: 1)
originalBuffer.write(string: "A")
for index in 0..<5 {
let promise: Promise<Void> = group.next().newPromise()
promise.futureResult.whenComplete { result in
switch result {
case .success:
XCTAssertEqual(firedFutures.count, index)
firedFutures.append(index)
case .failure:
XCTFail("Write promise failed: \(result)")
}
}
clientChannel.write(data: IOData(originalBuffer), promise: promise)
}
let flushPromise: Promise<Void> = group.next().newPromise()
flushPromise.futureResult.whenComplete { result in
switch result {
case .success:
XCTAssertEqual(firedFutures, [0, 1, 2, 3, 4])
case .failure:
XCTFail("Write promised failed: \(result)")
}
}
clientChannel.flush(promise: flushPromise)
try flushPromise.futureResult.wait()
}
}