Add initial support for connected datagram sockets (#2084)

* socket: Make destinationPtr param optional in sendmsg(...)

Signed-off-by: Si Beaumont <beaumont@apple.com>

* pdwm: Fixup documentation: scalar writes use sendmsg, not sendto

Signed-off-by: Si Beaumont <beaumont@apple.com>

* pdwm: Make sockaddr pointer param optional in scalarWriteOperation

Signed-off-by: Si Beaumont <beaumont@apple.com>

* pdwm: Add isConnected property to PendingDatagramWritesState

Signed-off-by: Si Beaumont <beaumont@apple.com>

* pdwm: If socket is connected use NULL msg_name in sendmsg(2)

Signed-off-by: Si Beaumont <beaumont@apple.com>

* BaseSocketChannel: Support connect after bind

Signed-off-by: Si Beaumont <beaumont@apple.com>

* DatagramChannel: Implement connectSocket(to:)

Signed-off-by: Si Beaumont <beaumont@apple.com>

* bootstrap: Rename bind0(makeChannel:registerAndBind:) to withNewChannel(makeChannel:bringup:)

Signed-off-by: Si Beaumont <beaumont@apple.com>

* bootstrap: Add set of DatagramBootstrap.connect(...) APIs

Signed-off-by: Si Beaumont <beaumont@apple.com>

* test: Remove DatagramChannelTests.testConnectionFails

Signed-off-by: Si Beaumont <beaumont@apple.com>

* test: Add ConnectedDatagramChannelTests, inheriting from DatagramChannelTests

Signed-off-by: Si Beaumont <beaumont@apple.com>

* NIOUDPEchoClient: Use connected-mode UDP

Signed-off-by: Si Beaumont <beaumont@apple.com>

* soundness: Update copyright notice

Signed-off-by: Si Beaumont <beaumont@apple.com>

* fixup: cleanup bootstrap APIs

Signed-off-by: Si Beaumont <beaumont@apple.com>

* pdwm: Check address of pending write if connected and add test

Signed-off-by: Si Beaumont <beaumont@apple.com>

* Revert "pdwm: Check address of pending write if connected and add test"

This reverts commit a4ee0756d5.

* channel: Fail buffered writes on connect and validate writes when connected

Signed-off-by: Si Beaumont <beaumont@apple.com>

* Run soundness.sh to get linux tests generated

Signed-off-by: Si Beaumont <beaumont@apple.com>

* NIOUDPEchoClient: Connect socket to remote only if --connect is used

Signed-off-by: Si Beaumont <beaumont@apple.com>

* socket: Support ByteBuffer (without AddressedEnvelope) for DatagramChannel

Signed-off-by: Si Beaumont <beaumont@apple.com>

* test: Simplify some test code

Signed-off-by: Si Beaumont <beaumont@apple.com>

* pdwm: Factor out common, private add(_ pendingWrite:)

Signed-off-by: Si Beaumont <beaumont@apple.com>

* channel: Support AddressedEnvelope on connected socket for control messages

Signed-off-by: Si Beaumont <beaumont@apple.com>

* channel: Defer to common unwrapData for error handling

Signed-off-by: Si Beaumont <beaumont@apple.com>

* channel: Throw more specific (new) errors, instead of IOError

Signed-off-by: Si Beaumont <beaumont@apple.com>

* SocketChannelLifecycleManager: Add supportsReconnect boolean property, used in DatagramChannel

Signed-off-by: Si Beaumont <beaumont@apple.com>
This commit is contained in:
Si Beaumont 2022-05-31 13:21:41 +01:00 committed by GitHub
parent 083eba3652
commit 9bf5075241
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 543 additions and 65 deletions

View File

@ -378,6 +378,25 @@ extension ChannelError: Equatable { }
/// The removal of a `ChannelHandler` using `ChannelPipeline.removeHandler` has been attempted more than once.
public struct NIOAttemptedToRemoveHandlerMultipleTimesError: Error {}
public enum DatagramChannelError {
public struct WriteOnUnconnectedSocketWithoutAddress: Error {
public init() {}
}
public struct WriteOnConnectedSocketWithInvalidAddress: Error {
let envelopeRemoteAddress: SocketAddress
let connectedRemoteAddress: SocketAddress
public init(
envelopeRemoteAddress: SocketAddress,
connectedRemoteAddress: SocketAddress
) {
self.envelopeRemoteAddress = envelopeRemoteAddress
self.connectedRemoteAddress = connectedRemoteAddress
}
}
}
/// An `Channel` related event that is passed through the `ChannelPipeline` to notify the user.
public enum ChannelEvent: Equatable, NIOSendable {
/// `ChannelOptions.allowRemoteHalfClosure` is `true` and input portion of the `Channel` was closed.

View File

@ -42,6 +42,9 @@ private struct SocketChannelLifecycleManager {
// note: this can be `false` on a deactivated channel, we might just have torn it down.
var hasSeenEOFNotification: Bool = false
// Should we support transition from `active` to `active`, used by datagram sockets.
let supportsReconnect: Bool
private var currentState: State = .fresh {
didSet {
self.eventLoop.assertInEventLoop()
@ -58,9 +61,14 @@ private struct SocketChannelLifecycleManager {
// MARK: API
// isActiveAtomic needs to be injected as it's accessed from arbitrary threads and `SocketChannelLifecycleManager` is usually held mutable
internal init(eventLoop: EventLoop, isActiveAtomic: NIOAtomic<Bool>) {
internal init(
eventLoop: EventLoop,
isActiveAtomic: NIOAtomic<Bool>,
supportReconnect: Bool
) {
self.eventLoop = eventLoop
self.isActiveAtomic = isActiveAtomic
self.supportsReconnect = supportReconnect
}
// this is called from Channel's deinit, so don't assert we're on the EventLoop!
@ -140,6 +148,12 @@ private struct SocketChannelLifecycleManager {
pipeline.syncOperations.fireChannelUnregistered()
}
// origin: .activated
case (.activated, .activate) where self.supportsReconnect:
return { promise, pipeline in
promise?.succeed(())
}
// bad transitions
case (.fresh, .activate), // should go through .registered first
(.preRegistered, .activate), // need to first be fully registered
@ -439,7 +453,13 @@ class BaseSocketChannel<SocketType: BaseSocketProtocol>: SelectableChannel, Chan
}
// MARK: Common base socket logic.
init(socket: SocketType, parent: Channel?, eventLoop: SelectableEventLoop, recvAllocator: RecvByteBufferAllocator) throws {
init(
socket: SocketType,
parent: Channel?,
eventLoop: SelectableEventLoop,
recvAllocator: RecvByteBufferAllocator,
supportReconnect: Bool
) throws {
self._bufferAllocatorCache = self.bufferAllocator
self.socket = socket
self.selectableEventLoop = eventLoop
@ -448,7 +468,11 @@ class BaseSocketChannel<SocketType: BaseSocketProtocol>: SelectableChannel, Chan
self.recvAllocator = recvAllocator
// As the socket may already be connected we should ensure we start with the correct addresses cached.
self._addressCache = .init(local: try? socket.localAddress(), remote: try? socket.remoteAddress())
self.lifecycleManager = SocketChannelLifecycleManager(eventLoop: eventLoop, isActiveAtomic: self.isActiveAtomic)
self.lifecycleManager = SocketChannelLifecycleManager(
eventLoop: eventLoop,
isActiveAtomic: self.isActiveAtomic,
supportReconnect: supportReconnect
)
self.socketDescription = socket.description
self.pendingConnect = nil
self._pipeline = ChannelPipeline(channel: self)

View File

@ -20,13 +20,21 @@ class BaseStreamSocketChannel<Socket: SocketProtocol>: BaseSocketChannel<Socket>
private var outputShutdown: Bool = false
private let pendingWrites: PendingStreamWritesManager
override init(socket: Socket,
init(
socket: Socket,
parent: Channel?,
eventLoop: SelectableEventLoop,
recvAllocator: RecvByteBufferAllocator) throws {
recvAllocator: RecvByteBufferAllocator
) throws {
self.pendingWrites = PendingStreamWritesManager(iovecs: eventLoop.iovecs, storageRefs: eventLoop.storageRefs)
self.connectTimeoutScheduled = nil
try super.init(socket: socket, parent: parent, eventLoop: eventLoop, recvAllocator: recvAllocator)
try super.init(
socket: socket,
parent: parent,
eventLoop: eventLoop,
recvAllocator: recvAllocator,
supportReconnect: false
)
}
deinit {

View File

@ -843,7 +843,7 @@ public final class DatagramBootstrap {
func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel {
return try DatagramChannel(eventLoop: eventLoop, socket: socket)
}
return bind0(makeChannel: makeChannel) { (eventLoop, channel) in
return withNewChannel(makeChannel: makeChannel) { (eventLoop, channel) in
let promise = eventLoop.makePromise(of: Void.self)
channel.registerAlreadyConfigured0(promise: promise)
return promise.futureResult
@ -907,14 +907,61 @@ public final class DatagramBootstrap {
return try DatagramChannel(eventLoop: eventLoop,
protocolFamily: address.protocol)
}
return bind0(makeChannel: makeChannel) { (eventLoop, channel) in
return withNewChannel(makeChannel: makeChannel) { (eventLoop, channel) in
channel.register().flatMap {
channel.bind(to: address)
}
}
}
private func bind0(makeChannel: (_ eventLoop: SelectableEventLoop) throws -> DatagramChannel, _ registerAndBind: @escaping (EventLoop, DatagramChannel) -> EventLoopFuture<Void>) -> EventLoopFuture<Channel> {
/// Connect the `DatagramChannel` to `host` and `port`.
///
/// - parameters:
/// - host: The host to connect to.
/// - port: The port to connect to.
public func connect(host: String, port: Int) -> EventLoopFuture<Channel> {
return connect0 {
return try SocketAddress.makeAddressResolvingHost(host, port: port)
}
}
/// Connect the `DatagramChannel` to `address`.
///
/// - parameters:
/// - address: The `SocketAddress` to connect to.
public func connect(to address: SocketAddress) -> EventLoopFuture<Channel> {
return connect0 { address }
}
/// Connect the `DatagramChannel` to a UNIX Domain Socket.
///
/// - parameters:
/// - unixDomainSocketPath: The path of the UNIX Domain Socket to connect to. `path` must not exist, it will be created by the system.
public func connect(unixDomainSocketPath: String) -> EventLoopFuture<Channel> {
return connect0 {
return try SocketAddress(unixDomainSocketPath: unixDomainSocketPath)
}
}
private func connect0(_ makeSocketAddress: () throws -> SocketAddress) -> EventLoopFuture<Channel> {
let address: SocketAddress
do {
address = try makeSocketAddress()
} catch {
return group.next().makeFailedFuture(error)
}
func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel {
return try DatagramChannel(eventLoop: eventLoop,
protocolFamily: address.protocol)
}
return withNewChannel(makeChannel: makeChannel) { (eventLoop, channel) in
channel.register().flatMap {
channel.connect(to: address)
}
}
}
private func withNewChannel(makeChannel: (_ eventLoop: SelectableEventLoop) throws -> DatagramChannel, _ bringup: @escaping (EventLoop, DatagramChannel) -> EventLoopFuture<Void>) -> EventLoopFuture<Channel> {
let eventLoop = self.group.next()
let channelInitializer = self.channelInitializer ?? { _ in eventLoop.makeSucceededFuture(()) }
let channelOptions = self._channelOptions
@ -932,7 +979,7 @@ public final class DatagramBootstrap {
channelInitializer(channel)
}.flatMap {
eventLoop.assertInEventLoop()
return registerAndBind(eventLoop, channel)
return bringup(eventLoop, channel)
}.map {
channel
}.flatMapError { error in

View File

@ -17,7 +17,7 @@ import NIOConcurrencyHelpers
private struct PendingDatagramWrite {
var data: ByteBuffer
var promise: Optional<EventLoopPromise<Void>>
let address: SocketAddress
let address: SocketAddress?
var metadata: AddressedEnvelope<ByteBuffer>.Metadata?
/// A helper function that copies the underlying sockaddr structure into temporary storage,
@ -31,7 +31,9 @@ private struct PendingDatagramWrite {
func copySocketAddress(_ target: UnsafeMutablePointer<sockaddr_storage>) -> socklen_t {
let erased = UnsafeMutableRawPointer(target)
switch address {
switch self.address {
case .none:
preconditionFailure("copySocketAddress called on write that has no address")
case .v4(let innerAddress):
erased.storeBytes(of: innerAddress.address, as: sockaddr_in.self)
return socklen_t(MemoryLayout.size(ofValue: innerAddress.address))
@ -99,14 +101,38 @@ private func doPendingDatagramWriteVectorOperation(pending: PendingDatagramWrite
p.data.withUnsafeReadableBytesWithStorageManagement { ptr, storageRef in
storageRefs[c] = storageRef.retain()
let addressLen = p.copySocketAddress(addresses.baseAddress! + c)
/// From man page of `sendmsg(2)`:
///
/// > The `msg_name` field is used on an unconnected socket to specify
/// > the target address for a datagram. It points to a buffer
/// > containing the address; the `msg_namelen` field should be set to
/// > the size of the address. For a connected socket, these fields
/// > should be specified as `NULL` and 0, respectively.
let address: UnsafeMutablePointer<sockaddr_storage>?
let addressLen: socklen_t
let protocolFamily: NIOBSDSocket.ProtocolFamily
if let envelopeAddress = p.address {
precondition(pending.remoteAddress == nil, "Pending write with address on connected socket.")
address = addresses.baseAddress! + c
addressLen = p.copySocketAddress(address!)
protocolFamily = envelopeAddress.protocol
} else {
guard let connectedRemoteAddress = pending.remoteAddress else {
preconditionFailure("Pending write without address on unconnected socket.")
}
address = nil
addressLen = 0
protocolFamily = connectedRemoteAddress.protocol
}
iovecs[c] = iovec(iov_base: UnsafeMutableRawPointer(mutating: ptr.baseAddress!), iov_len: numericCast(toWriteForThisBuffer))
var controlBytes = UnsafeOutboundControlBytes(controlBytes: controlMessageStorage[c])
controlBytes.appendExplicitCongestionState(metadata: p.metadata, protocolFamily: p.address.protocol)
controlBytes.appendExplicitCongestionState(metadata: p.metadata, protocolFamily: protocolFamily)
let controlMessageBytePointer = controlBytes.validControlBytes
let msg = msghdr(msg_name: addresses.baseAddress! + c,
let msg = msghdr(msg_name: address,
msg_namelen: addressLen,
msg_iov: iovecs.baseAddress! + c,
msg_iovlen: 1,
@ -140,6 +166,7 @@ private struct PendingDatagramWritesState {
private var pendingWrites = MarkedCircularBuffer<PendingDatagramWrite>(initialCapacity: 16)
private var chunks: Int = 0
public private(set) var bytes: Int64 = 0
private(set) var remoteAddress: SocketAddress? = nil
public var nextWrite: PendingDatagramWrite? {
return self.pendingWrites.first
@ -194,6 +221,10 @@ private struct PendingDatagramWritesState {
self.pendingWrites.mark()
}
mutating func markConnected(to remoteAddress: SocketAddress) {
self.remoteAddress = remoteAddress
}
/// Indicate that a write has happened, this may be a write of multiple outstanding writes (using for example `sendmmsg`).
///
/// - warning: The closure will simply fulfill all the promises in order. If one of those promises does for example close the `Channel` we might see subsequent writes fail out of order. Example: Imagine the user issues three writes: `A`, `B` and `C`. Imagine that `A` and `B` both get successfully written in one write operation but the user closes the `Channel` in `A`'s callback. Then overall the promises will be fulfilled in this order: 1) `A`: success 2) `C`: error 3) `B`: success. Note how `B` and `C` get fulfilled out of order.
@ -402,6 +433,11 @@ final class PendingDatagramWritesManager: PendingWritesManager {
self.state.markFlushCheckpoint()
}
/// Mark that the socket is connected.
func markConnected(to remoteAddress: SocketAddress) {
self.state.markConnected(to: remoteAddress)
}
/// Is there a flush pending?
var isFlushPending: Bool {
return self.state.isFlushPending
@ -412,18 +448,9 @@ final class PendingDatagramWritesManager: PendingWritesManager {
return self.state.isEmpty
}
/// Add a pending write.
///
/// - parameters:
/// - envelope: The `AddressedEnvelope<IOData>` to write.
/// - promise: Optionally an `EventLoopPromise` that will get the write operation's result
/// - result: If the `Channel` is still writable after adding the write of `data`.
func add(envelope: AddressedEnvelope<ByteBuffer>, promise: EventLoopPromise<Void>?) -> Bool {
private func add(_ pendingWrite: PendingDatagramWrite) -> Bool {
assert(self.isOpen)
self.state.append(.init(data: envelope.data,
promise: promise,
address: envelope.remoteAddress,
metadata: envelope.metadata))
self.state.append(pendingWrite)
if self.state.bytes > waterMark.high && channelWritabilityFlag.compareAndExchange(expected: true, desired: false) {
// Returns false to signal the Channel became non-writable and we need to notify the user.
@ -433,6 +460,48 @@ final class PendingDatagramWritesManager: PendingWritesManager {
return true
}
/// Add a pending write, with an `AddressedEnvelope`, usually on an unconnected socket.
///
/// - parameters:
/// - envelope: The `AddressedEnvelope<ByteBuffer>` to write.
/// - promise: Optionally an `EventLoopPromise` that will get the write operation's result
/// - returns: If the `Channel` is still writable after adding the write of `data`.
///
/// - warning: If the socket is connected, then the `envelope.remoteAddress` _must_ match the
/// address of the connected peer, otherwise this function will throw a fatal error.
func add(envelope: AddressedEnvelope<ByteBuffer>, promise: EventLoopPromise<Void>?) -> Bool {
if let remoteAddress = self.state.remoteAddress {
precondition(envelope.remoteAddress == remoteAddress, """
Remote address of AddressedEnvelope does not match remote address of connected socket.
""")
return self.add(PendingDatagramWrite(
data: envelope.data,
promise: promise,
address: nil,
metadata: envelope.metadata))
} else {
return self.add(PendingDatagramWrite(
data: envelope.data,
promise: promise,
address: envelope.remoteAddress,
metadata: envelope.metadata))
}
}
/// Add a pending write, without an `AddressedEnvelope`, on a connected socket.
///
/// - parameters:
/// - data: The `ByteBuffer` to write.
/// - promise: Optionally an `EventLoopPromise` that will get the write operation's result
/// - returns: If the `Channel` is still writable after adding the write of `data`.
func add(data: ByteBuffer, promise: EventLoopPromise<Void>?) -> Bool {
return self.add(PendingDatagramWrite(
data: data,
promise: promise,
address: nil,
metadata: nil))
}
/// Returns the best mechanism to write pending data at the current point in time.
var currentBestWriteMechanism: WriteMechanism {
return self.state.currentBestWriteMechanism
@ -442,10 +511,10 @@ final class PendingDatagramWritesManager: PendingWritesManager {
/// On platforms that do not support a gathering write operation,
///
/// - parameters:
/// - scalarWriteOperation: An operation that writes a single, contiguous array of bytes (usually `sendto`).
/// - scalarWriteOperation: An operation that writes a single, contiguous array of bytes (usually `sendmsg`).
/// - vectorWriteOperation: An operation that writes multiple contiguous arrays of bytes (usually `sendmmsg`).
/// - returns: The `WriteResult` and whether the `Channel` is now writable.
func triggerAppropriateWriteOperations(scalarWriteOperation: (UnsafeRawBufferPointer, UnsafePointer<sockaddr>, socklen_t, AddressedEnvelope<ByteBuffer>.Metadata?) throws -> IOResult<Int>,
func triggerAppropriateWriteOperations(scalarWriteOperation: (UnsafeRawBufferPointer, UnsafePointer<sockaddr>?, socklen_t, AddressedEnvelope<ByteBuffer>.Metadata?) throws -> IOResult<Int>,
vectorWriteOperation: (UnsafeMutableBufferPointer<MMsgHdr>) throws -> IOResult<Int>) throws -> OverallWriteResult {
return try self.triggerWriteOperations { writeMechanism in
switch writeMechanism {
@ -515,16 +584,33 @@ final class PendingDatagramWritesManager: PendingWritesManager {
///
/// - parameters:
/// - scalarWriteOperation: An operation that writes a single, contiguous array of bytes (usually `sendto`).
private func triggerScalarBufferWrite(scalarWriteOperation: (UnsafeRawBufferPointer, UnsafePointer<sockaddr>, socklen_t, AddressedEnvelope<ByteBuffer>.Metadata?) throws -> IOResult<Int>) rethrows -> OneWriteOperationResult {
private func triggerScalarBufferWrite(scalarWriteOperation: (UnsafeRawBufferPointer, UnsafePointer<sockaddr>?, socklen_t, AddressedEnvelope<ByteBuffer>.Metadata?) throws -> IOResult<Int>) rethrows -> OneWriteOperationResult {
assert(self.state.isFlushPending && self.isOpen && !self.state.isEmpty,
"illegal state for scalar datagram write operation: flushPending: \(self.state.isFlushPending), isOpen: \(self.isOpen), empty: \(self.state.isEmpty)")
let pending = self.state.nextWrite!
do {
let writeResult = try pending.address.withSockAddr { (addrPtr, addrSize) in
let writeResult: IOResult<Int>
if let address = pending.address {
assert(self.state.remoteAddress == nil, "Pending write with address on connected socket.")
writeResult = try address.withSockAddr { (addrPtr, addrSize) in
try pending.data.withUnsafeReadableBytes {
try scalarWriteOperation($0, addrPtr, socklen_t(addrSize), pending.metadata)
}
}
} else {
/// From man page of `sendmsg(2)`:
///
/// > The `msg_name` field is used on an unconnected socket to specify
/// > the target address for a datagram. It points to a buffer
/// > containing the address; the `msg_namelen` field should be set to
/// > the size of the address. For a connected socket, these fields
/// > should be specified as `NULL` and 0, respectively.
assert(self.state.remoteAddress != nil, "Pending write without address on unconnected socket.")
writeResult = try pending.data.withUnsafeReadableBytes {
try scalarWriteOperation($0, nil, 0, pending.metadata)
}
}
return self.didWrite(writeResult, messages: nil)
} catch {
return try self.handleError(error)

View File

@ -98,7 +98,7 @@ final class PipePair: SocketProtocol {
}
func sendmsg(pointer: UnsafeRawBufferPointer,
destinationPtr: UnsafePointer<sockaddr>,
destinationPtr: UnsafePointer<sockaddr>?,
destinationSize: socklen_t,
controlBytes: UnsafeMutableRawBufferPointer) throws -> IOResult<Int> {
throw ChannelError.operationUnsupported

View File

@ -151,7 +151,7 @@ typealias IOVector = iovec
/// (because the socket is in non-blocking mode).
/// - throws: An `IOError` if the operation failed.
func sendmsg(pointer: UnsafeRawBufferPointer,
destinationPtr: UnsafePointer<sockaddr>,
destinationPtr: UnsafePointer<sockaddr>?,
destinationSize: socklen_t,
controlBytes: UnsafeMutableRawBufferPointer) throws -> IOResult<Int> {
// Dubious const casts - it should be OK as there is no reason why this should get mutated

View File

@ -155,10 +155,13 @@ final class ServerSocketChannel: BaseSocketChannel<ServerSocket> {
init(serverSocket: ServerSocket, eventLoop: SelectableEventLoop, group: EventLoopGroup) throws {
self.group = group
try super.init(socket: serverSocket,
try super.init(
socket: serverSocket,
parent: nil,
eventLoop: eventLoop,
recvAllocator: AdaptiveRecvByteBufferAllocator())
recvAllocator: AdaptiveRecvByteBufferAllocator(),
supportReconnect: false
)
}
convenience init(socket: NIOBSDSocket.Handle, eventLoop: SelectableEventLoop, group: EventLoopGroup) throws {
@ -398,10 +401,13 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
storageRefs: eventLoop.storageRefs,
controlMessageStorage: eventLoop.controlMessageStorage)
try super.init(socket: socket,
try super.init(
socket: socket,
parent: nil,
eventLoop: eventLoop,
recvAllocator: FixedSizeRecvByteBufferAllocator(capacity: 2048))
recvAllocator: FixedSizeRecvByteBufferAllocator(capacity: 2048),
supportReconnect: true
)
}
init(socket: Socket, parent: Channel? = nil, eventLoop: SelectableEventLoop) throws {
@ -412,7 +418,13 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
addresses: eventLoop.addresses,
storageRefs: eventLoop.storageRefs,
controlMessageStorage: eventLoop.controlMessageStorage)
try super.init(socket: socket, parent: parent, eventLoop: eventLoop, recvAllocator: FixedSizeRecvByteBufferAllocator(capacity: 2048))
try super.init(
socket: socket,
parent: parent,
eventLoop: eventLoop,
recvAllocator: FixedSizeRecvByteBufferAllocator(capacity: 2048),
supportReconnect: true
)
}
// MARK: Datagram Channel overrides required by BaseSocketChannel
@ -526,12 +538,24 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
}
override func connectSocket(to address: SocketAddress) throws -> Bool {
// For now we don't support operating in connected mode for datagram channels.
throw ChannelError.operationUnsupported
// TODO: this could be a channel option to do other things instead here, e.g. fail the connect
if !self.pendingWrites.isEmpty {
self.pendingWrites.failAll(
error: IOError(
errnoCode: EISCONN,
reason: "Socket was connected before flushing pending write."),
close: false)
}
if try self.socket.connect(to: address) {
self.pendingWrites.markConnected(to: address)
return true
} else {
preconditionFailure("Connect of datagram socket did not complete synchronously.")
}
}
override func finishConnectSocket() throws {
// For now we don't support operating in connected mode for datagram channels.
// This is not required for connected datagram channels connect is a synchronous operation.
throw ChannelError.operationUnsupported
}
@ -668,11 +692,52 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
return true
}
}
/// Buffer a write in preparation for a flush.
override func bufferPendingWrite(data: NIOAny, promise: EventLoopPromise<Void>?) {
let data = self.unwrapData(data, as: AddressedEnvelope<ByteBuffer>.self)
if !self.pendingWrites.add(envelope: data, promise: promise) {
/// Buffer a write in preparation for a flush.
///
/// When the channel is unconnected, `data` _must_ be of type `AddressedEnvelope<ByteBuffer>`.
///
/// When the channel is connected, `data` _should_ be of type `ByteBuffer`, but _may_ be of type
/// `AddressedEnvelope<ByteBuffer>` to allow users to provide protocol control messages via
/// `AddressedEnvelope.metadata`. In this case, `AddressedEnvelope.remoteAddress` _must_ match
/// the address of the connected peer.
override func bufferPendingWrite(data: NIOAny, promise: EventLoopPromise<Void>?) {
if let envelope = self.tryUnwrapData(data, as: AddressedEnvelope<ByteBuffer>.self) {
return bufferPendingAddressedWrite(envelope: envelope, promise: promise)
}
// If it's not an `AddressedEnvelope` then it must be a `ByteBuffer` so we let the common
// `unwrapData(_:as:)` throw the fatal error if it's some other type.
let data = self.unwrapData(data, as: ByteBuffer.self)
return bufferPendingUnaddressedWrite(data: data, promise: promise)
}
/// Buffer a write in preparation for a flush.
private func bufferPendingUnaddressedWrite(data: ByteBuffer, promise: EventLoopPromise<Void>?) {
// It is only appropriate to not use an AddressedEnvelope if the socket is connected.
guard self.remoteAddress != nil else {
promise?.fail(DatagramChannelError.WriteOnUnconnectedSocketWithoutAddress())
return
}
if !self.pendingWrites.add(data: data, promise: promise) {
assert(self.isActive)
self.pipeline.syncOperations.fireChannelWritabilityChanged()
}
}
/// Buffer a write in preparation for a flush.
private func bufferPendingAddressedWrite(envelope: AddressedEnvelope<ByteBuffer>, promise: EventLoopPromise<Void>?) {
// If the socket is connected, check the remote provided matches the connected address.
if let connectedRemoteAddress = self.remoteAddress {
guard envelope.remoteAddress == connectedRemoteAddress else {
promise?.fail(DatagramChannelError.WriteOnConnectedSocketWithInvalidAddress(
envelopeRemoteAddress: envelope.remoteAddress,
connectedRemoteAddress: connectedRemoteAddress))
return
}
}
if !self.pendingWrites.add(envelope: envelope, promise: promise) {
assert(self.isActive)
self.pipeline.syncOperations.fireChannelWritabilityChanged()
}

View File

@ -54,7 +54,7 @@ protocol SocketProtocol: BaseSocketProtocol {
controlBytes: inout UnsafeReceivedControlBytes) throws -> IOResult<Int>
func sendmsg(pointer: UnsafeRawBufferPointer,
destinationPtr: UnsafePointer<sockaddr>,
destinationPtr: UnsafePointer<sockaddr>?,
destinationSize: socklen_t,
controlBytes: UnsafeMutableRawBufferPointer) throws -> IOResult<Int>

View File

@ -73,7 +73,16 @@ private final class EchoHandler: ChannelInboundHandler {
}
// First argument is the program path
let arguments = CommandLine.arguments
var arguments = CommandLine.arguments
// Support for `--connect` if it appears as the first argument.
let connectedMode: Bool
if let connectedModeFlagIndex = arguments.firstIndex(where: { $0 == "--connect" }) {
connectedMode = true
arguments.remove(at: connectedModeFlagIndex)
} else {
connectedMode = false
}
// Now process the positional arguments.
let arg1 = arguments.dropFirst().first
let arg2 = arguments.dropFirst(2).first
let arg3 = arguments.dropFirst(3).first
@ -133,7 +142,13 @@ let channel = try { () -> Channel in
case .unixDomainSocket(_, let listeningPath):
return try bootstrap.bind(unixDomainSocketPath: listeningPath).wait()
}
}()
}()
if connectedMode {
let remoteAddress = try remoteAddress()
print("Connecting to remote: \(remoteAddress)")
try channel.connect(to: remoteAddress).wait()
}
// Will be closed after we echo-ed back to the server.
try channel.closeFuture.wait()

View File

@ -2,7 +2,7 @@
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2018-2021 Apple Inc. and the SwiftNIO project authors
// Copyright (c) 2018-2022 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
@ -29,7 +29,6 @@ extension DatagramChannelTests {
return [
("testBasicChannelCommunication", testBasicChannelCommunication),
("testManyWrites", testManyWrites),
("testConnectionFails", testConnectionFails),
("testDatagramChannelHasWatermark", testDatagramChannelHasWatermark),
("testWriteFuturesFailWhenChannelClosed", testWriteFuturesFailWhenChannelClosed),
("testManyManyDatagramWrites", testManyManyDatagramWrites),
@ -67,6 +66,14 @@ extension DatagramChannelTests {
("testReceiveEcnAndPacketInfoIPV6VectorRead", testReceiveEcnAndPacketInfoIPV6VectorRead),
("testReceiveEcnAndPacketInfoIPV4VectorReadVectorWrite", testReceiveEcnAndPacketInfoIPV4VectorReadVectorWrite),
("testReceiveEcnAndPacketInfoIPV6VectorReadVectorWrite", testReceiveEcnAndPacketInfoIPV6VectorReadVectorWrite),
("testSendingAddressedEnvelopeOnUnconnectedSocketSucceeds", testSendingAddressedEnvelopeOnUnconnectedSocketSucceeds),
("testSendingByteBufferOnUnconnectedSocketFails", testSendingByteBufferOnUnconnectedSocketFails),
("testSendingByteBufferOnConnectedSocketSucceeds", testSendingByteBufferOnConnectedSocketSucceeds),
("testSendingAddressedEnvelopeOnConnectedSocketSucceeds", testSendingAddressedEnvelopeOnConnectedSocketSucceeds),
("testSendingAddressedEnvelopeOnConnectedSocketWithDifferentAddressFails", testSendingAddressedEnvelopeOnConnectedSocketWithDifferentAddressFails),
("testConnectingSocketAfterFlushingExistingMessages", testConnectingSocketAfterFlushingExistingMessages),
("testConnectingSocketFailsBufferedWrites", testConnectingSocketFailsBufferedWrites),
("testReconnectingSocketFailsBufferedWrites", testReconnectingSocketFailsBufferedWrites),
]
}
}

View File

@ -102,10 +102,11 @@ private class DatagramReadRecorder<DataType>: ChannelInboundHandler {
}
}
final class DatagramChannelTests: XCTestCase {
class DatagramChannelTests: XCTestCase {
private var group: MultiThreadedEventLoopGroup! = nil
private var firstChannel: Channel! = nil
private var secondChannel: Channel! = nil
private var thirdChannel: Channel! = nil
private func buildChannel(group: EventLoopGroup, host: String = "127.0.0.1") throws -> Channel {
return try DatagramBootstrap(group: group)
@ -128,9 +129,11 @@ final class DatagramChannelTests: XCTestCase {
override func setUp() {
super.setUp()
self.continueAfterFailure = false
self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
self.firstChannel = try! buildChannel(group: group)
self.secondChannel = try! buildChannel(group: group)
self.thirdChannel = try! buildChannel(group: group)
}
override func tearDown() {
@ -173,12 +176,6 @@ final class DatagramChannelTests: XCTestCase {
}
}
func testConnectionFails() throws {
XCTAssertThrowsError(try self.firstChannel.connect(to: self.secondChannel.localAddress!).wait()) { error in
XCTAssertEqual(.operationUnsupported, error as? ChannelError)
}
}
func testDatagramChannelHasWatermark() throws {
_ = try self.firstChannel.setOption(ChannelOptions.writeBufferWaterMark, value: ChannelOptions.Types.WriteBufferWaterMark(low: 1, high: 1024)).wait()
@ -916,4 +913,214 @@ final class DatagramChannelTests: XCTestCase {
}
testEcnAndPacketInfoReceive(address: "::1", vectorRead: true, vectorSend: true, receivePacketInfo: true)
}
func assertSending(
data: ByteBuffer,
from sourceChannel: Channel,
to destinationChannel: Channel,
wrappingInAddressedEnvelope shouldWrapInAddressedEnvelope: Bool,
resultsIn expectedResult: Result<Void, Error>,
file: StaticString = #file,
line: UInt = #line
) throws {
// Wrap data in AddressedEnvelope if required.
let writePayload: NIOAny
if shouldWrapInAddressedEnvelope {
let envelope = AddressedEnvelope(remoteAddress: destinationChannel.localAddress!, data: data)
writePayload = NIOAny(envelope)
} else {
writePayload = NIOAny(data)
}
// Write and flush.
let writeResult = sourceChannel.writeAndFlush(writePayload)
// Check the expected result.
switch expectedResult {
case .success:
// Check the write succeeded.
XCTAssertNoThrow(try writeResult.wait())
// Check the destination received the sent payload.
let reads = try destinationChannel.waitForDatagrams(count: 1)
XCTAssertEqual(reads.count, 1)
let read = reads.first!
XCTAssertEqual(read.data, data)
XCTAssertEqual(read.remoteAddress, sourceChannel.localAddress!)
case .failure(let expectedError):
// Check the error is of the expected type.
XCTAssertThrowsError(try writeResult.wait()) { error in
guard type(of: error) == type(of: expectedError) else {
XCTFail("expected error of type \(type(of: expectedError)), but caught other error of type (\(type(of: error)): \(error)")
return
}
}
}
}
func assertSendingHelloWorld(
from sourceChannel: Channel,
to destinationChannel: Channel,
wrappingInAddressedEnvelope shouldWrapInAddressedEnvelope: Bool,
resultsIn expectedResult: Result<Void, Error>,
file: StaticString = #file,
line: UInt = #line
) throws {
try self.assertSending(
data: sourceChannel.allocator.buffer(staticString: "hello, world!"),
from: sourceChannel,
to: destinationChannel,
wrappingInAddressedEnvelope: shouldWrapInAddressedEnvelope,
resultsIn: expectedResult,
file: file,
line: line
)
}
func bufferWrite(
of data: ByteBuffer,
from sourceChannel: Channel,
to destinationChannel: Channel,
wrappingInAddressedEnvelope shouldWrapInAddressedEnvelope: Bool
) -> EventLoopFuture<Void> {
if shouldWrapInAddressedEnvelope {
let envelope = AddressedEnvelope(remoteAddress: destinationChannel.localAddress!, data: data)
return sourceChannel.write(envelope)
} else {
return sourceChannel.write(data)
}
}
func bufferWriteOfHelloWorld(
from sourceChannel: Channel,
to destinationChannel: Channel,
wrappingInAddressedEnvelope shouldWrapInAddressedEnvelope: Bool
) -> EventLoopFuture<Void> {
self.bufferWrite(
of: sourceChannel.allocator.buffer(staticString: "hello, world!"),
from: sourceChannel,
to: destinationChannel,
wrappingInAddressedEnvelope: shouldWrapInAddressedEnvelope
)
}
func testSendingAddressedEnvelopeOnUnconnectedSocketSucceeds() throws {
try self.assertSendingHelloWorld(
from: self.firstChannel,
to: self.secondChannel,
wrappingInAddressedEnvelope: true,
resultsIn: .success(())
)
}
func testSendingByteBufferOnUnconnectedSocketFails() throws {
try self.assertSendingHelloWorld(
from: self.firstChannel,
to: self.secondChannel,
wrappingInAddressedEnvelope: false,
resultsIn: .failure(DatagramChannelError.WriteOnUnconnectedSocketWithoutAddress())
)
}
func testSendingByteBufferOnConnectedSocketSucceeds() throws {
XCTAssertNoThrow(try self.firstChannel.connect(to: self.secondChannel.localAddress!).wait())
try self.assertSendingHelloWorld(
from: self.firstChannel,
to: self.secondChannel,
wrappingInAddressedEnvelope: false,
resultsIn: .success(())
)
}
func testSendingAddressedEnvelopeOnConnectedSocketSucceeds() throws {
XCTAssertNoThrow(try self.firstChannel.connect(to: self.secondChannel.localAddress!).wait())
try self.assertSendingHelloWorld(
from: self.firstChannel,
to: self.secondChannel,
wrappingInAddressedEnvelope: true,
resultsIn: .success(())
)
}
func testSendingAddressedEnvelopeOnConnectedSocketWithDifferentAddressFails() throws {
XCTAssertNoThrow(try self.firstChannel.connect(to: self.secondChannel.localAddress!).wait())
try self.assertSendingHelloWorld(
from: self.firstChannel,
to: self.thirdChannel,
wrappingInAddressedEnvelope: true,
resultsIn: .failure(DatagramChannelError.WriteOnConnectedSocketWithInvalidAddress(
envelopeRemoteAddress: self.thirdChannel.localAddress!,
connectedRemoteAddress: self.secondChannel.localAddress!))
)
}
func testConnectingSocketAfterFlushingExistingMessages() throws {
// Send message from firstChannel to secondChannel.
try self.assertSendingHelloWorld(
from: self.firstChannel,
to: self.secondChannel,
wrappingInAddressedEnvelope: true,
resultsIn: .success(())
)
// Connect firstChannel to thirdChannel.
XCTAssertNoThrow(try self.firstChannel.connect(to: self.thirdChannel.localAddress!).wait())
// Send message from firstChannel to thirdChannel.
try self.assertSendingHelloWorld(
from: self.firstChannel,
to: self.thirdChannel,
wrappingInAddressedEnvelope: false,
resultsIn: .success(())
)
}
func testConnectingSocketFailsBufferedWrites() throws {
// Buffer message from firstChannel to secondChannel.
let bufferedWrite = bufferWriteOfHelloWorld(from: self.firstChannel, to: self.secondChannel, wrappingInAddressedEnvelope: true)
// Connect firstChannel to thirdChannel.
XCTAssertNoThrow(try self.firstChannel.connect(to: self.thirdChannel.localAddress!).wait())
// Check that the buffered write was failed.
XCTAssertThrowsError(try bufferedWrite.wait()) { error in
XCTAssertEqual((error as? IOError)?.errnoCode, EISCONN, "expected EISCONN, but caught other error: \(error)")
}
// Send message from firstChannel to thirdChannel.
try self.assertSendingHelloWorld(
from: self.firstChannel,
to: self.thirdChannel,
wrappingInAddressedEnvelope: false,
resultsIn: .success(())
)
}
func testReconnectingSocketFailsBufferedWrites() throws {
// Connect firstChannel to secondChannel.
XCTAssertNoThrow(try self.firstChannel.connect(to: self.secondChannel.localAddress!).wait())
// Buffer message from firstChannel to secondChannel.
let bufferedWrite = bufferWriteOfHelloWorld(from: self.firstChannel, to: self.secondChannel, wrappingInAddressedEnvelope: false)
// Connect firstChannel to thirdChannel.
XCTAssertNoThrow(try self.firstChannel.connect(to: self.thirdChannel.localAddress!).wait())
// Check that the buffered write was failed.
XCTAssertThrowsError(try bufferedWrite.wait()) { error in
XCTAssertEqual((error as? IOError)?.errnoCode, EISCONN, "expected EISCONN, but caught other error: \(error)")
}
// Send message from firstChannel to thirdChannel.
try self.assertSendingHelloWorld(
from: self.firstChannel,
to: self.thirdChannel,
wrappingInAddressedEnvelope: false,
resultsIn: .success(())
)
}
}

View File

@ -126,7 +126,7 @@ class PendingDatagramWritesManagerTests: XCTestCase {
if expected.count > singleState {
XCTAssertGreaterThan(returns.count, everythingState)
XCTAssertEqual(expected[singleState].0, buf.count, "in single write \(singleState) (overall \(everythingState)), \(expected[singleState].0) bytes expected but \(buf.count) actual", file: (file), line: line)
XCTAssertEqual(expected[singleState].1, SocketAddress(addr), "in single write \(singleState) (overall \(everythingState)), \(expected[singleState].1) address expected but \(SocketAddress(addr)) received", file: (file), line: line)
XCTAssertEqual(expected[singleState].1, addr.map(SocketAddress.init), "in single write \(singleState) (overall \(everythingState)), \(expected[singleState].1) address expected but \(String(describing: addr.map(SocketAddress.init))) received", file: (file), line: line)
XCTAssertEqual(expected[singleState].1.expectedSize, len, "in single write \(singleState) (overall \(everythingState)), \(expected[singleState].1.expectedSize) socklen expected but \(len) received", file: (file), line: line)
switch returns[everythingState] {