Make ChannelCore writes use NIOAny

This commit is contained in:
Cory Benfield 2018-01-25 12:12:02 +00:00
parent 4a2425f5b7
commit 1fffa95ad1
8 changed files with 57 additions and 4 deletions

View File

@ -21,7 +21,7 @@ public protocol ChannelCore : class {
func register0(promise: EventLoopPromise<Void>?)
func bind0(to: SocketAddress, promise: EventLoopPromise<Void>?)
func connect0(to: SocketAddress, promise: EventLoopPromise<Void>?)
func write0(data: IOData, promise: EventLoopPromise<Void>?)
func write0(data: NIOAny, promise: EventLoopPromise<Void>?)
func flush0(promise: EventLoopPromise<Void>?)
func read0(promise: EventLoopPromise<Void>?)
func close0(error: Error, mode: CloseMode, promise: EventLoopPromise<Void>?)
@ -178,6 +178,10 @@ public enum ChannelError: Error {
/// A read operation reached end-of-file. This usually means the remote peer closed the socket but it's still
/// open locally.
case eof
/// A `Channel` `write` was made with a data type not supported by the channel type: e.g. an `AddressedEnvelope`
/// for a stream channel.
case writeDataUnsupported
}
extension ChannelError: Equatable {

View File

@ -708,7 +708,7 @@ private final class HeadChannelHandler : _ChannelOutboundHandler {
func write(ctx: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
if let channel = ctx.channel {
channel._unsafe.write0(data: data.forceAsIOData(), promise: promise)
channel._unsafe.write0(data: data, promise: promise)
} else {
promise?.fail(error: ChannelError.ioOnClosedChannel)
}

View File

@ -195,7 +195,12 @@ class EmbeddedChannelCore : ChannelCore {
promise?.succeed(result: ())
}
func write0(data: IOData, promise: EventLoopPromise<Void>?) {
func write0(data: NIOAny, promise: EventLoopPromise<Void>?) {
guard let data = data.tryAsIOData() else {
promise?.fail(error: ChannelError.writeDataUnsupported)
return
}
addToBuffer(buffer: &outboundBuffer, data: data)
promise?.succeed(result: ())
}

View File

@ -265,7 +265,7 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
}
}
public final func write0(data: IOData, promise: EventLoopPromise<Void>?) {
public final func write0(data: NIOAny, promise: EventLoopPromise<Void>?) {
assert(eventLoop.inEventLoop)
if closed {
@ -273,6 +273,12 @@ class BaseSocketChannel<T : BaseSocket> : SelectableChannel, ChannelCore {
promise?.fail(error: ChannelError.ioOnClosedChannel)
return
}
guard let data = data.tryAsIOData() else {
promise?.fail(error: ChannelError.writeDataUnsupported)
return
}
bufferPendingWrite(data: data, promise: promise)
}

View File

@ -50,6 +50,7 @@ extension ChannelTests {
("testCloseOutput", testCloseOutput),
("testCloseInput", testCloseInput),
("testHalfClosure", testHalfClosure),
("testRejectsInvalidData", testRejectsInvalidData),
]
}
}

View File

@ -1221,4 +1221,27 @@ public class ChannelTests: XCTestCase {
XCTAssertEqual(outputShutdown, outputShutdownEventReceived)
}
}
func testRejectsInvalidData() throws {
let group = MultiThreadedEventLoopGroup(numThreads: 1)
defer {
XCTAssertNoThrow(try group.syncShutdownGracefully())
}
let serverChannel = try ServerBootstrap(group: group)
.serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.bind(to: "127.0.0.1", on: 0).wait()
let clientChannel = try ClientBootstrap(group: group)
.connect(to: serverChannel.localAddress!).wait()
do {
try clientChannel.writeAndFlush(data: NIOAny(5)).wait()
XCTFail("Did not throw")
} catch ChannelError.writeDataUnsupported {
// All good
} catch {
XCTFail("Got \(error)")
}
}
}

View File

@ -34,6 +34,7 @@ extension EmbeddedChannelTest {
("testCloseOnInactiveIsOk", testCloseOnInactiveIsOk),
("testEmbeddedLifecycle", testEmbeddedLifecycle),
("testEmbeddedChannelAndPipelineAndChannelCoreShareTheEventLoop", testEmbeddedChannelAndPipelineAndChannelCoreShareTheEventLoop),
("testSendingIncorrectDataOnEmbeddedChannel", testSendingIncorrectDataOnEmbeddedChannel),
]
}
}

View File

@ -139,4 +139,17 @@ class EmbeddedChannelTest: XCTestCase {
XCTAssert(pipelineEventLoop === channel.eventLoop)
XCTAssert(pipelineEventLoop === (channel._unsafe as! EmbeddedChannelCore).eventLoop)
}
func testSendingIncorrectDataOnEmbeddedChannel() {
let channel = EmbeddedChannel()
do {
try channel.write(data: NIOAny(5)).wait()
XCTFail("Did not throw")
} catch ChannelError.writeDataUnsupported {
// All good
} catch {
XCTFail("Got \(error)")
}
}
}