Don't explode on zero-length writes

This commit is contained in:
Cory Benfield 2017-12-12 16:04:21 +00:00
parent e4ef4c19f7
commit 9b9caa1d20
4 changed files with 60 additions and 5 deletions

View File

@ -357,7 +357,6 @@ extension OpenSSLHandler {
// These are some annoying variables we use to persist state across invocations of
// our closures. A better version of this code might be able to simplify this somewhat.
var writeCount = 0
var promises: [EventLoopPromise<Void>] = []
/// Given a byte buffer to encode, passes it to OpenSSL and handles the result.
@ -367,12 +366,9 @@ extension OpenSSLHandler {
switch result {
case .complete:
if let promise = promise { promises.append(promise) }
writeCount += 1
return true
case .incomplete:
// Ok, we can't write. Let's stop.
// We believe this can only ever happen on the first attempt to write.
precondition(writeCount == 0, "Unexpected change in OpenSSL state during write unbuffering: write count \(writeCount)")
return false
case .failed(let err):
// Once a write fails, all writes must fail. This includes prior writes
@ -385,10 +381,15 @@ extension OpenSSLHandler {
func flushData(userFlushPromise: EventLoopPromise<Void>?) throws -> Bool {
// This is a flush. We can go ahead and flush now.
if let promise = userFlushPromise { promises.append(promise) }
flushWithPromises()
return true
}
func flushWithPromises() {
let ourPromise: EventLoopPromise<Void> = ctx.eventLoop.newPromise()
promises.forEach { ourPromise.futureResult.cascade(promise: $0) }
writeDataToNetwork(ctx: ctx, promise: ourPromise)
return true
promises = []
}
do {
@ -400,6 +401,13 @@ extension OpenSSLHandler {
return try flushData(userFlushPromise: p)
}
}
// If we got this far, but we have promises, it means that we weren't able to
// write everything up to our mark: the SSL object started returning WANT_{READ,WRITE}
// before we got there. That's ok: we'll shove the app data out to the network anyway.
if promises.count > 0 {
flushWithPromises()
}
} catch {
// We encountered an error, it's cleanup time. Close ourselves down.
channelClose(ctx: ctx)

View File

@ -239,6 +239,12 @@ internal final class SSLConnection {
/// This call will only write the data into OpenSSL's internal buffers. It needs to be obtained
/// by calling `getDataForNetwork` after this call completes.
func writeDataToNetwork(_ data: inout ByteBuffer) -> AsyncOperationResult<Int32> {
// OpenSSL does not allow calling SSL_write with zero-length buffers. Zero-length
// writes always succeed.
guard data.readableBytes > 0 else {
return .complete(0)
}
let writtenBytes = data.withUnsafeReadableBytes { (pointer) -> Int32 in
return SSL_write(ssl, pointer.baseAddress, Int32(pointer.count))
}

View File

@ -40,6 +40,7 @@ extension OpenSSLIntegrationTest {
("testTrustStoreOnDisk", testTrustStoreOnDisk),
("testChecksTrustStoreOnDisk", testChecksTrustStoreOnDisk),
("testReadAfterCloseNotifyDoesntKillProcess", testReadAfterCloseNotifyDoesntKillProcess),
("testZeroLengthWrite", testZeroLengthWrite),
]
}
}

View File

@ -845,4 +845,44 @@ class OpenSSLIntegrationTest: XCTestCase {
XCTFail("Encountered unexpected error: \(error)")
}
}
func testZeroLengthWrite() throws {
let ctx = try configuredSSLContext()
let group = MultiThreadedEventLoopGroup(numThreads: 1)
defer {
try? group.syncShutdownGracefully()
}
let completionPromise: EventLoopPromise<ByteBuffer> = group.next().newPromise()
let serverChannel = try serverTLSChannel(withContext: ctx,
andHandlers: [PromiseOnReadHandler(promise: completionPromise)],
onGroup: group)
defer {
_ = try? serverChannel.close().wait()
}
let clientChannel = try clientTLSChannel(withContext: ctx,
preHandlers: [],
postHandlers: [],
onGroup: group,
connectingTo: serverChannel.localAddress!)
defer {
_ = try? clientChannel.close().wait()
}
// Write several zero-length buffers *and* one with some actual data. Only one should
// be written.
var originalBuffer = clientChannel.allocator.buffer(capacity: 5)
let promises = (0...5).map { _ in clientChannel.write(data: NIOAny(originalBuffer)) }
originalBuffer.write(staticString: "hello")
_ = try clientChannel.writeAndFlush(data: NIOAny(originalBuffer)).wait()
// At this time all the writes should have succeeded.
XCTAssertTrue(promises.map { $0.fulfilled }.reduce(true, { $0 && $1 }))
let newBuffer = try completionPromise.futureResult.wait()
XCTAssertEqual(newBuffer, originalBuffer)
}
}