Don't explode on zero-length writes
This commit is contained in:
parent
e4ef4c19f7
commit
9b9caa1d20
|
@ -357,7 +357,6 @@ extension OpenSSLHandler {
|
|||
|
||||
// These are some annoying variables we use to persist state across invocations of
|
||||
// our closures. A better version of this code might be able to simplify this somewhat.
|
||||
var writeCount = 0
|
||||
var promises: [EventLoopPromise<Void>] = []
|
||||
|
||||
/// Given a byte buffer to encode, passes it to OpenSSL and handles the result.
|
||||
|
@ -367,12 +366,9 @@ extension OpenSSLHandler {
|
|||
switch result {
|
||||
case .complete:
|
||||
if let promise = promise { promises.append(promise) }
|
||||
writeCount += 1
|
||||
return true
|
||||
case .incomplete:
|
||||
// Ok, we can't write. Let's stop.
|
||||
// We believe this can only ever happen on the first attempt to write.
|
||||
precondition(writeCount == 0, "Unexpected change in OpenSSL state during write unbuffering: write count \(writeCount)")
|
||||
return false
|
||||
case .failed(let err):
|
||||
// Once a write fails, all writes must fail. This includes prior writes
|
||||
|
@ -385,10 +381,15 @@ extension OpenSSLHandler {
|
|||
func flushData(userFlushPromise: EventLoopPromise<Void>?) throws -> Bool {
|
||||
// This is a flush. We can go ahead and flush now.
|
||||
if let promise = userFlushPromise { promises.append(promise) }
|
||||
flushWithPromises()
|
||||
return true
|
||||
}
|
||||
|
||||
func flushWithPromises() {
|
||||
let ourPromise: EventLoopPromise<Void> = ctx.eventLoop.newPromise()
|
||||
promises.forEach { ourPromise.futureResult.cascade(promise: $0) }
|
||||
writeDataToNetwork(ctx: ctx, promise: ourPromise)
|
||||
return true
|
||||
promises = []
|
||||
}
|
||||
|
||||
do {
|
||||
|
@ -400,6 +401,13 @@ extension OpenSSLHandler {
|
|||
return try flushData(userFlushPromise: p)
|
||||
}
|
||||
}
|
||||
|
||||
// If we got this far, but we have promises, it means that we weren't able to
|
||||
// write everything up to our mark: the SSL object started returning WANT_{READ,WRITE}
|
||||
// before we got there. That's ok: we'll shove the app data out to the network anyway.
|
||||
if promises.count > 0 {
|
||||
flushWithPromises()
|
||||
}
|
||||
} catch {
|
||||
// We encountered an error, it's cleanup time. Close ourselves down.
|
||||
channelClose(ctx: ctx)
|
||||
|
|
|
@ -239,6 +239,12 @@ internal final class SSLConnection {
|
|||
/// This call will only write the data into OpenSSL's internal buffers. It needs to be obtained
|
||||
/// by calling `getDataForNetwork` after this call completes.
|
||||
func writeDataToNetwork(_ data: inout ByteBuffer) -> AsyncOperationResult<Int32> {
|
||||
// OpenSSL does not allow calling SSL_write with zero-length buffers. Zero-length
|
||||
// writes always succeed.
|
||||
guard data.readableBytes > 0 else {
|
||||
return .complete(0)
|
||||
}
|
||||
|
||||
let writtenBytes = data.withUnsafeReadableBytes { (pointer) -> Int32 in
|
||||
return SSL_write(ssl, pointer.baseAddress, Int32(pointer.count))
|
||||
}
|
||||
|
|
|
@ -40,6 +40,7 @@ extension OpenSSLIntegrationTest {
|
|||
("testTrustStoreOnDisk", testTrustStoreOnDisk),
|
||||
("testChecksTrustStoreOnDisk", testChecksTrustStoreOnDisk),
|
||||
("testReadAfterCloseNotifyDoesntKillProcess", testReadAfterCloseNotifyDoesntKillProcess),
|
||||
("testZeroLengthWrite", testZeroLengthWrite),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
|
@ -845,4 +845,44 @@ class OpenSSLIntegrationTest: XCTestCase {
|
|||
XCTFail("Encountered unexpected error: \(error)")
|
||||
}
|
||||
}
|
||||
|
||||
func testZeroLengthWrite() throws {
|
||||
let ctx = try configuredSSLContext()
|
||||
|
||||
let group = MultiThreadedEventLoopGroup(numThreads: 1)
|
||||
defer {
|
||||
try? group.syncShutdownGracefully()
|
||||
}
|
||||
|
||||
let completionPromise: EventLoopPromise<ByteBuffer> = group.next().newPromise()
|
||||
|
||||
let serverChannel = try serverTLSChannel(withContext: ctx,
|
||||
andHandlers: [PromiseOnReadHandler(promise: completionPromise)],
|
||||
onGroup: group)
|
||||
defer {
|
||||
_ = try? serverChannel.close().wait()
|
||||
}
|
||||
|
||||
let clientChannel = try clientTLSChannel(withContext: ctx,
|
||||
preHandlers: [],
|
||||
postHandlers: [],
|
||||
onGroup: group,
|
||||
connectingTo: serverChannel.localAddress!)
|
||||
defer {
|
||||
_ = try? clientChannel.close().wait()
|
||||
}
|
||||
|
||||
// Write several zero-length buffers *and* one with some actual data. Only one should
|
||||
// be written.
|
||||
var originalBuffer = clientChannel.allocator.buffer(capacity: 5)
|
||||
let promises = (0...5).map { _ in clientChannel.write(data: NIOAny(originalBuffer)) }
|
||||
originalBuffer.write(staticString: "hello")
|
||||
_ = try clientChannel.writeAndFlush(data: NIOAny(originalBuffer)).wait()
|
||||
|
||||
// At this time all the writes should have succeeded.
|
||||
XCTAssertTrue(promises.map { $0.fulfilled }.reduce(true, { $0 && $1 }))
|
||||
|
||||
let newBuffer = try completionPromise.futureResult.wait()
|
||||
XCTAssertEqual(newBuffer, originalBuffer)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue