[OpenSSLHandler] Flush less frequently for outgoing application data.
This commit is contained in:
parent
55ae8a8606
commit
c2f461e4b7
|
@ -22,8 +22,6 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle
|
|||
public typealias InboundOut = ByteBuffer
|
||||
public typealias InboundUserEventOut = TLSUserEvent
|
||||
|
||||
private typealias BufferedWrite = (data: ByteBuffer, promise: Promise<Void>?)
|
||||
|
||||
private enum ConnectionState {
|
||||
case idle
|
||||
case handshaking
|
||||
|
@ -35,12 +33,13 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle
|
|||
private let context: SSLContext
|
||||
private var state: ConnectionState = .idle
|
||||
private var connection: SSLConnection? = nil
|
||||
private var bufferedWrites: [BufferedWrite] = []
|
||||
private var bufferedWrites: MarkedCircularBuffer<BufferedEvent>
|
||||
private var closePromise: Promise<Void>?
|
||||
private var didDeliverData: Bool = false
|
||||
|
||||
public init (context: SSLContext) {
|
||||
self.context = context
|
||||
self.bufferedWrites = MarkedCircularBuffer(initialRingCapacity: 96) // 96 brings the total size of the buffer to just shy of one page
|
||||
}
|
||||
|
||||
public func connect(ctx: ChannelHandlerContext, to address: SocketAddress, promise: Promise<Void>?) {
|
||||
|
@ -136,8 +135,12 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle
|
|||
}
|
||||
|
||||
public func write(ctx: ChannelHandlerContext, data: IOData, promise: Promise<Void>?) {
|
||||
var binaryData = unwrapOutboundIn(data)
|
||||
doEncodeData(data: &binaryData, ctx: ctx, promise: promise)
|
||||
bufferWrite(data: unwrapOutboundIn(data), promise: promise)
|
||||
}
|
||||
|
||||
public func flush(ctx: ChannelHandlerContext, promise: Promise<Void>?) {
|
||||
bufferFlush(promise: promise)
|
||||
doUnbufferWrites(ctx: ctx)
|
||||
}
|
||||
|
||||
public func close(ctx: ChannelHandlerContext, promise: Promise<Void>?) {
|
||||
|
@ -234,79 +237,10 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle
|
|||
}
|
||||
}
|
||||
|
||||
private func doEncodeData(data: inout ByteBuffer, ctx: ChannelHandlerContext, promise: Promise<Void>?) {
|
||||
if state == .closing || state == .closed {
|
||||
// We're either shutting down or shut down: no further encoded data is allowed.
|
||||
promise?.fail(error: NIOOpenSSLError.writeDuringTLSShutdown)
|
||||
return
|
||||
}
|
||||
|
||||
let result = connection!.writeDataToNetwork(&data)
|
||||
switch result {
|
||||
case .complete:
|
||||
writeDataToNetwork(ctx: ctx, promise: promise)
|
||||
case .incomplete:
|
||||
// We need to buffer this write and retry it.
|
||||
bufferedWrites.append((data: data, promise: promise))
|
||||
case .failed(let err):
|
||||
// TODO(cory): This is too aggressive.
|
||||
channelClose(ctx: ctx)
|
||||
promise?.fail(error: err)
|
||||
}
|
||||
}
|
||||
|
||||
private func doUnbufferWrites(ctx: ChannelHandlerContext) {
|
||||
// Early exit if there are no buffered writes.
|
||||
if bufferedWrites.count == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var originalError: OpenSSLError? = nil
|
||||
var newBuffer: [BufferedWrite] = []
|
||||
|
||||
for bufferedWrite in bufferedWrites {
|
||||
let promise = bufferedWrite.promise
|
||||
var data = bufferedWrite.data
|
||||
if let err = originalError {
|
||||
promise?.fail(error: err)
|
||||
continue
|
||||
} else if newBuffer.count > 0 {
|
||||
newBuffer.append((data: data, promise: promise))
|
||||
continue
|
||||
}
|
||||
|
||||
let result = connection!.writeDataToNetwork(&data)
|
||||
|
||||
switch result {
|
||||
case .complete:
|
||||
writeDataToNetwork(ctx: ctx, promise: promise)
|
||||
case .incomplete:
|
||||
// We need to start a new buffer. At this point, all further writes
|
||||
// must be buffered.
|
||||
newBuffer.append((data: data, promise: promise))
|
||||
case .failed(let err):
|
||||
// Once a write fails, all subsequent writes must fail.
|
||||
channelClose(ctx: ctx)
|
||||
promise?.fail(error: err)
|
||||
originalError = err
|
||||
}
|
||||
}
|
||||
|
||||
bufferedWrites = newBuffer
|
||||
}
|
||||
|
||||
private func discardBufferedWrites(reason: Error) {
|
||||
for (_, promise) in bufferedWrites {
|
||||
promise?.fail(error:reason)
|
||||
}
|
||||
|
||||
bufferedWrites = []
|
||||
}
|
||||
|
||||
private func writeDataToNetwork(ctx: ChannelHandlerContext, promise: Promise<Void>?) {
|
||||
// There may be no data to write, in which case we can just exit early.
|
||||
guard let dataToWrite = connection!.getDataForNetwork(allocator: ctx.channel!.allocator) else {
|
||||
assert(promise == nil, "Promise present for nonexistent write.")
|
||||
promise?.succeed(result: ())
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -325,3 +259,116 @@ public final class OpenSSLHandler : ChannelInboundHandler, ChannelOutboundHandle
|
|||
ctx.close(promise: closePromise)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// MARK: Code that handles buffering/unbuffering writes.
|
||||
extension OpenSSLHandler {
|
||||
private enum BufferedEvent {
|
||||
case write(BufferedWrite)
|
||||
case flush(Promise<Void>?)
|
||||
}
|
||||
private typealias BufferedWrite = (data: ByteBuffer, promise: Promise<Void>?)
|
||||
|
||||
private func bufferWrite(data: ByteBuffer, promise: Promise<Void>?) {
|
||||
bufferedWrites.append(.write((data: data, promise: promise)))
|
||||
}
|
||||
|
||||
private func bufferFlush(promise: Promise<Void>?) {
|
||||
bufferedWrites.append(.flush(promise))
|
||||
bufferedWrites.mark()
|
||||
}
|
||||
|
||||
private func discardBufferedWrites(reason: Error) {
|
||||
while bufferedWrites.count > 0 {
|
||||
let promise: Promise<Void>?
|
||||
switch bufferedWrites.removeFirst() {
|
||||
case .write(_, let p):
|
||||
promise = p
|
||||
case .flush(let p):
|
||||
promise = p
|
||||
}
|
||||
|
||||
promise?.fail(error: reason)
|
||||
}
|
||||
}
|
||||
|
||||
private func doUnbufferWrites(ctx: ChannelHandlerContext) {
|
||||
// Return early if the user hasn't called flush.
|
||||
guard bufferedWrites.hasMark() else {
|
||||
return
|
||||
}
|
||||
|
||||
// These are some annoying variables we use to persist state across invocations of
|
||||
// our closures. A better version of this code might be able to simplify this somewhat.
|
||||
var writeCount = 0
|
||||
var promises: [Promise<Void>] = []
|
||||
|
||||
/// Given a byte buffer to encode, passes it to OpenSSL and handles the result.
|
||||
func encodeWrite(buf: inout ByteBuffer, promise: Promise<Void>?) throws -> Bool {
|
||||
let result = connection!.writeDataToNetwork(&buf)
|
||||
|
||||
switch result {
|
||||
case .complete:
|
||||
if let promise = promise { promises.append(promise) }
|
||||
writeCount += 1
|
||||
return true
|
||||
case .incomplete:
|
||||
// Ok, we can't write. Let's stop.
|
||||
// We believe this can only ever happen on the first attempt to write.
|
||||
precondition(writeCount == 0, "Unexpected change in OpenSSL state during write unbuffering: write count \(writeCount)")
|
||||
return false
|
||||
case .failed(let err):
|
||||
// Once a write fails, all writes must fail. This includes prior writes
|
||||
// that successfully made it through OpenSSL.
|
||||
throw err
|
||||
}
|
||||
}
|
||||
|
||||
/// Given a flush request, grabs the data from OpenSSL and flushes it to the network.
|
||||
func flushData(userFlushPromise: Promise<Void>?) throws -> Bool {
|
||||
// This is a flush. We can go ahead and flush now.
|
||||
if let promise = userFlushPromise { promises.append(promise) }
|
||||
let ourPromise: Promise<Void> = ctx.eventLoop.newPromise()
|
||||
promises.forEach { ourPromise.futureResult.cascade(promise: $0) }
|
||||
writeDataToNetwork(ctx: ctx, promise: ourPromise)
|
||||
return true
|
||||
}
|
||||
|
||||
do {
|
||||
try bufferedWrites.forEachElementUntilMark { element in
|
||||
switch element {
|
||||
case .write(var d, let p):
|
||||
return try encodeWrite(buf: &d, promise: p)
|
||||
case .flush(let p):
|
||||
return try flushData(userFlushPromise: p)
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// We encountered an error, it's cleanup time. Close ourselves down.
|
||||
channelClose(ctx: ctx)
|
||||
// Fail any writes we've previously encoded but not flushed.
|
||||
promises.forEach { $0.fail(error: error) }
|
||||
// Fail everything else.
|
||||
bufferedWrites.forEachRemoving {
|
||||
switch $0 {
|
||||
case .write(_, let p), .flush(let p):
|
||||
p?.fail(error: error)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fileprivate extension MarkedCircularBuffer {
|
||||
fileprivate mutating func forEachElementUntilMark(callback: (E) throws -> Bool) rethrows {
|
||||
while try self.hasMark() && callback(self.first!) {
|
||||
_ = self.removeFirst()
|
||||
}
|
||||
}
|
||||
|
||||
fileprivate mutating func forEachRemoving(callback: (E) -> Void) {
|
||||
while self.count > 0 {
|
||||
callback(self.removeFirst())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,6 +30,8 @@ extension OpenSSLIntegrationTest {
|
|||
("testHandshakeEventSequencing", testHandshakeEventSequencing),
|
||||
("testShutdownEventSequencing", testShutdownEventSequencing),
|
||||
("testMultipleClose", testMultipleClose),
|
||||
("testCoalescedWrites", testCoalescedWrites),
|
||||
("testCoalescedWritesWithFutures", testCoalescedWritesWithFutures),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,10 +25,12 @@ private final class SimpleEchoServer: ChannelInboundHandler {
|
|||
|
||||
public func channelRead(ctx: ChannelHandlerContext, data: IOData) {
|
||||
ctx.write(data: data, promise: nil)
|
||||
ctx.fireChannelRead(data: data)
|
||||
}
|
||||
|
||||
public func channelReadComplete(ctx: ChannelHandlerContext) {
|
||||
ctx.flush(promise: nil)
|
||||
ctx.fireChannelReadComplete()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -46,10 +48,24 @@ private final class PromiseOnReadHandler: ChannelInboundHandler {
|
|||
|
||||
public func channelRead(ctx: ChannelHandlerContext, data: IOData) {
|
||||
self.data = data
|
||||
ctx.fireChannelRead(data: data)
|
||||
}
|
||||
|
||||
public func channelReadComplete(ctx: ChannelHandlerContext) {
|
||||
promise.succeed(result: unwrapInboundIn(data!))
|
||||
ctx.fireChannelReadComplete()
|
||||
}
|
||||
}
|
||||
|
||||
private final class WriteCountingHandler: ChannelOutboundHandler {
|
||||
public typealias OutboundIn = Any
|
||||
public typealias OutboundOut = Any
|
||||
|
||||
public var writeCount = 0
|
||||
|
||||
public func write(ctx: ChannelHandlerContext, data: IOData, promise: Promise<Void>?) {
|
||||
writeCount += 1
|
||||
ctx.write(data: data, promise: promise)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -150,13 +166,23 @@ private func serverTLSChannel(withContext: NIOOpenSSL.SSLContext, andHandlers: [
|
|||
}
|
||||
|
||||
private func clientTLSChannel(withContext: NIOOpenSSL.SSLContext,
|
||||
andHandler: ChannelHandler,
|
||||
preHandlers: [ChannelHandler],
|
||||
postHandlers: [ChannelHandler],
|
||||
onGroup: EventLoopGroup,
|
||||
connectingTo: SocketAddress) throws -> Channel {
|
||||
return try ClientBootstrap(group: onGroup)
|
||||
.handler(handler: ChannelInitializer(initChannel: { channel in
|
||||
let results = preHandlers.map { channel.pipeline.add(handler: $0) }
|
||||
|
||||
return (results.last ?? channel.eventLoop.newSucceedFuture(result: ())).then(callback: { v2 in
|
||||
return channel.pipeline.add(handler: OpenSSLHandler(context: withContext)).then(callback: { v2 in
|
||||
return channel.pipeline.add(handler: andHandler)
|
||||
let results = postHandlers.map { channel.pipeline.add(handler: $0) }
|
||||
|
||||
// NB: This assumes that the futures will always fire in order. This is not necessarily guaranteed
|
||||
// but in the absence of a way to gather a complete set of Future results, there is no other
|
||||
// option.
|
||||
return results.last ?? channel.eventLoop.newSucceedFuture(result: ())
|
||||
})
|
||||
})
|
||||
})).connect(to: connectingTo).wait()
|
||||
}
|
||||
|
@ -196,7 +222,8 @@ class OpenSSLIntegrationTest: XCTestCase {
|
|||
}
|
||||
|
||||
let clientChannel = try clientTLSChannel(withContext: ctx,
|
||||
andHandler: PromiseOnReadHandler(promise: completionPromise),
|
||||
preHandlers: [],
|
||||
postHandlers: [PromiseOnReadHandler(promise: completionPromise)],
|
||||
onGroup: group,
|
||||
connectingTo: serverChannel.localAddress!)
|
||||
defer {
|
||||
|
@ -229,7 +256,8 @@ class OpenSSLIntegrationTest: XCTestCase {
|
|||
}
|
||||
|
||||
let clientChannel = try clientTLSChannel(withContext: ctx,
|
||||
andHandler: SimpleEchoServer(),
|
||||
preHandlers: [],
|
||||
postHandlers: [SimpleEchoServer()],
|
||||
onGroup: group,
|
||||
connectingTo: serverChannel.localAddress!)
|
||||
defer {
|
||||
|
@ -272,7 +300,8 @@ class OpenSSLIntegrationTest: XCTestCase {
|
|||
onGroup: group)
|
||||
|
||||
let clientChannel = try clientTLSChannel(withContext: ctx,
|
||||
andHandler: SimpleEchoServer(),
|
||||
preHandlers: [],
|
||||
postHandlers: [SimpleEchoServer()],
|
||||
onGroup: group,
|
||||
connectingTo: serverChannel.localAddress!)
|
||||
|
||||
|
@ -320,7 +349,8 @@ class OpenSSLIntegrationTest: XCTestCase {
|
|||
}
|
||||
|
||||
let clientChannel = try clientTLSChannel(withContext: ctx,
|
||||
andHandler: PromiseOnReadHandler(promise: completionPromise),
|
||||
preHandlers: [],
|
||||
postHandlers: [PromiseOnReadHandler(promise: completionPromise)],
|
||||
onGroup: group,
|
||||
connectingTo: serverChannel.localAddress!)
|
||||
defer {
|
||||
|
@ -355,4 +385,106 @@ class OpenSSLIntegrationTest: XCTestCase {
|
|||
XCTAssert(promise.futureResult.fulfilled)
|
||||
}
|
||||
}
|
||||
|
||||
func testCoalescedWrites() throws {
|
||||
let ctx = try configuredSSLContext()
|
||||
|
||||
let group = try MultiThreadedEventLoopGroup(numThreads: 1)
|
||||
defer {
|
||||
try! group.syncShutdownGracefully()
|
||||
}
|
||||
|
||||
let recorderHandler = EventRecorderHandler<TLSUserEvent>()
|
||||
let serverChannel = try serverTLSChannel(withContext: ctx, andHandlers: [SimpleEchoServer()], onGroup: group)
|
||||
defer {
|
||||
_ = try! serverChannel.close().wait()
|
||||
}
|
||||
|
||||
let writeCounter = WriteCountingHandler()
|
||||
let readPromise: Promise<ByteBuffer> = group.next().newPromise()
|
||||
let clientChannel = try clientTLSChannel(withContext: ctx,
|
||||
preHandlers: [writeCounter],
|
||||
postHandlers: [PromiseOnReadHandler(promise: readPromise)],
|
||||
onGroup: group,
|
||||
connectingTo: serverChannel.localAddress!)
|
||||
defer {
|
||||
_ = try! clientChannel.close().wait()
|
||||
}
|
||||
|
||||
// We're going to issue a number of small writes. Each of these should be coalesced together
|
||||
// such that the underlying layer sees only one write for them. The total number of
|
||||
// writes should be (after we flush) 3: one for Client Hello, one for Finished, and one
|
||||
// for the coalesced writes.
|
||||
var originalBuffer = clientChannel.allocator.buffer(capacity: 1)
|
||||
originalBuffer.write(string: "A")
|
||||
for _ in 0..<5 {
|
||||
clientChannel.write(data: IOData(originalBuffer), promise: nil)
|
||||
}
|
||||
|
||||
try clientChannel.flush().wait()
|
||||
let writeCount = try readPromise.futureResult.then { _ in
|
||||
// Here we're in the I/O loop, so we know that no further channel action will happen
|
||||
// while we dispatch this callback. This is the perfect time to check how many writes
|
||||
// happened.
|
||||
return writeCounter.writeCount
|
||||
}.wait()
|
||||
XCTAssertEqual(writeCount, 3)
|
||||
}
|
||||
|
||||
func testCoalescedWritesWithFutures() throws {
|
||||
let ctx = try configuredSSLContext()
|
||||
|
||||
let group = try MultiThreadedEventLoopGroup(numThreads: 1)
|
||||
defer {
|
||||
try! group.syncShutdownGracefully()
|
||||
}
|
||||
|
||||
let recorderHandler = EventRecorderHandler<TLSUserEvent>()
|
||||
let serverChannel = try serverTLSChannel(withContext: ctx, andHandlers: [SimpleEchoServer()], onGroup: group)
|
||||
defer {
|
||||
_ = try! serverChannel.close().wait()
|
||||
}
|
||||
|
||||
let clientChannel = try clientTLSChannel(withContext: ctx,
|
||||
preHandlers: [],
|
||||
postHandlers: [],
|
||||
onGroup: group,
|
||||
connectingTo: serverChannel.localAddress!)
|
||||
defer {
|
||||
_ = try! clientChannel.close().wait()
|
||||
}
|
||||
|
||||
// We're going to issue a number of small writes. Each of these should be coalesced together
|
||||
// and all their futures (along with the one for the flush) should fire, in order, with nothing
|
||||
// missed.
|
||||
var firedFutures: Array<Int> = []
|
||||
var originalBuffer = clientChannel.allocator.buffer(capacity: 1)
|
||||
originalBuffer.write(string: "A")
|
||||
for index in 0..<5 {
|
||||
let promise: Promise<Void> = group.next().newPromise()
|
||||
promise.futureResult.whenComplete { result in
|
||||
switch result {
|
||||
case .success:
|
||||
XCTAssertEqual(firedFutures.count, index)
|
||||
firedFutures.append(index)
|
||||
case .failure:
|
||||
XCTFail("Write promise failed: \(result)")
|
||||
}
|
||||
}
|
||||
clientChannel.write(data: IOData(originalBuffer), promise: promise)
|
||||
}
|
||||
|
||||
let flushPromise: Promise<Void> = group.next().newPromise()
|
||||
flushPromise.futureResult.whenComplete { result in
|
||||
switch result {
|
||||
case .success:
|
||||
XCTAssertEqual(firedFutures, [0, 1, 2, 3, 4])
|
||||
case .failure:
|
||||
XCTFail("Write promised failed: \(result)")
|
||||
}
|
||||
}
|
||||
clientChannel.flush(promise: flushPromise)
|
||||
|
||||
try flushPromise.futureResult.wait()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue