NIOThrowingAsyncSequenceProducer throws when cancelled (#2415)
* NIOThrowingAsyncSequenceProducer throws when cancelled * PR review
This commit is contained in:
parent
5f8b0647e4
commit
d1690f8541
|
@ -511,6 +511,21 @@ extension NIOThrowingAsyncSequenceProducer {
|
|||
return nil
|
||||
}
|
||||
|
||||
case .returnCancellationError:
|
||||
self._lock.unlock()
|
||||
// We have deprecated the generic Failure type in the public API and Failure should
|
||||
// now be `Swift.Error`. However, if users have not migrated to the new API they could
|
||||
// still use a custom generic Error type and this cast might fail.
|
||||
// In addition, we use `NIOThrowingAsyncSequenceProducer` in the implementation of the
|
||||
// non-throwing variant `NIOAsyncSequenceProducer` where `Failure` will be `Never` and
|
||||
// this cast will fail as well.
|
||||
// Everything is marked @inlinable and the Failure type is known at compile time,
|
||||
// therefore this cast should be optimised away in release build.
|
||||
if let error = CancellationError() as? Failure {
|
||||
throw error
|
||||
}
|
||||
return nil
|
||||
|
||||
case .returnNil:
|
||||
self._lock.unlock()
|
||||
return nil
|
||||
|
@ -603,6 +618,9 @@ extension NIOThrowingAsyncSequenceProducer {
|
|||
failure: Failure?
|
||||
)
|
||||
|
||||
/// The state once a call to next has been cancelled. Cancel the source when entering this state.
|
||||
case cancelled(iteratorInitialized: Bool)
|
||||
|
||||
/// The state once there can be no outstanding demand. This can happen if:
|
||||
/// 1. The ``NIOThrowingAsyncSequenceProducer/AsyncIterator`` was deinited
|
||||
/// 2. The underlying source finished and all buffered elements have been consumed
|
||||
|
@ -644,7 +662,8 @@ extension NIOThrowingAsyncSequenceProducer {
|
|||
switch self._state {
|
||||
case .initial(_, iteratorInitialized: false),
|
||||
.streaming(_, _, _, _, iteratorInitialized: false),
|
||||
.sourceFinished(_, iteratorInitialized: false, _):
|
||||
.sourceFinished(_, iteratorInitialized: false, _),
|
||||
.cancelled(iteratorInitialized: false):
|
||||
// No iterator was created so we can transition to finished right away.
|
||||
self._state = .finished(iteratorInitialized: false)
|
||||
|
||||
|
@ -652,7 +671,8 @@ extension NIOThrowingAsyncSequenceProducer {
|
|||
|
||||
case .initial(_, iteratorInitialized: true),
|
||||
.streaming(_, _, _, _, iteratorInitialized: true),
|
||||
.sourceFinished(_, iteratorInitialized: true, _):
|
||||
.sourceFinished(_, iteratorInitialized: true, _),
|
||||
.cancelled(iteratorInitialized: true):
|
||||
// An iterator was created and we deinited the sequence.
|
||||
// This is an expected pattern and we just continue on normal.
|
||||
return .none
|
||||
|
@ -673,6 +693,7 @@ extension NIOThrowingAsyncSequenceProducer {
|
|||
case .initial(_, iteratorInitialized: true),
|
||||
.streaming(_, _, _, _, iteratorInitialized: true),
|
||||
.sourceFinished(_, iteratorInitialized: true, _),
|
||||
.cancelled(iteratorInitialized: true),
|
||||
.finished(iteratorInitialized: true):
|
||||
// Our sequence is a unicast sequence and does not support multiple AsyncIterator's
|
||||
fatalError("NIOThrowingAsyncSequenceProducer allows only a single AsyncIterator to be created")
|
||||
|
@ -694,6 +715,10 @@ extension NIOThrowingAsyncSequenceProducer {
|
|||
iteratorInitialized: true
|
||||
)
|
||||
|
||||
case .cancelled(iteratorInitialized: false):
|
||||
// An iterator needs to be initialized before we can be cancelled.
|
||||
preconditionFailure("Internal inconsistency")
|
||||
|
||||
case .sourceFinished(let buffer, false, let failure):
|
||||
// The first and only iterator was initialized.
|
||||
self._state = .sourceFinished(
|
||||
|
@ -727,13 +752,15 @@ extension NIOThrowingAsyncSequenceProducer {
|
|||
switch self._state {
|
||||
case .initial(_, iteratorInitialized: false),
|
||||
.streaming(_, _, _, _, iteratorInitialized: false),
|
||||
.sourceFinished(_, iteratorInitialized: false, _):
|
||||
.sourceFinished(_, iteratorInitialized: false, _),
|
||||
.cancelled(iteratorInitialized: false):
|
||||
// An iterator needs to be initialized before it can be deinitialized.
|
||||
preconditionFailure("Internal inconsistency")
|
||||
|
||||
case .initial(_, iteratorInitialized: true),
|
||||
.streaming(_, _, _, _, iteratorInitialized: true),
|
||||
.sourceFinished(_, iteratorInitialized: true, _):
|
||||
.sourceFinished(_, iteratorInitialized: true, _),
|
||||
.cancelled(iteratorInitialized: true):
|
||||
// An iterator was created and deinited. Since we only support
|
||||
// a single iterator we can now transition to finish and inform the delegate.
|
||||
self._state = .finished(iteratorInitialized: true)
|
||||
|
@ -861,7 +888,7 @@ extension NIOThrowingAsyncSequenceProducer {
|
|||
|
||||
return .init(shouldProduceMore: shouldProduceMore)
|
||||
|
||||
case .sourceFinished, .finished:
|
||||
case .cancelled, .sourceFinished, .finished:
|
||||
// If the source has finished we are dropping the elements.
|
||||
return .returnDropped
|
||||
|
||||
|
@ -913,7 +940,7 @@ extension NIOThrowingAsyncSequenceProducer {
|
|||
|
||||
return .none
|
||||
|
||||
case .sourceFinished, .finished:
|
||||
case .cancelled, .sourceFinished, .finished:
|
||||
// If the source has finished, finishing again has no effect.
|
||||
return .none
|
||||
|
||||
|
@ -968,11 +995,14 @@ extension NIOThrowingAsyncSequenceProducer {
|
|||
return .resumeContinuationWithCancellationErrorAndCallDidTerminate(continuation)
|
||||
|
||||
case .streaming(_, _, continuation: .none, _, let iteratorInitialized):
|
||||
self._state = .finished(iteratorInitialized: iteratorInitialized)
|
||||
// We may have elements in the buffer, which is why we have no continuation
|
||||
// waiting. We must store the cancellation error to hand it out on the next
|
||||
// next() call.
|
||||
self._state = .cancelled(iteratorInitialized: iteratorInitialized)
|
||||
|
||||
return .callDidTerminate
|
||||
|
||||
case .sourceFinished, .finished:
|
||||
case .cancelled, .sourceFinished, .finished:
|
||||
// If the source has finished, finishing again has no effect.
|
||||
return .none
|
||||
|
||||
|
@ -992,6 +1022,8 @@ extension NIOThrowingAsyncSequenceProducer {
|
|||
/// Indicates that the `Failure` should be returned to the caller and
|
||||
/// that ``NIOAsyncSequenceProducerDelegate/didTerminate()`` should be called.
|
||||
case returnFailureAndCallDidTerminate(Failure?)
|
||||
/// Indicates that the next call to AsyncSequence got cancelled
|
||||
case returnCancellationError
|
||||
/// Indicates that the `nil` should be returned to the caller.
|
||||
case returnNil
|
||||
/// Indicates that the `Task` of the caller should be suspended.
|
||||
|
@ -1075,6 +1107,10 @@ extension NIOThrowingAsyncSequenceProducer {
|
|||
return .returnFailureAndCallDidTerminate(failure)
|
||||
}
|
||||
|
||||
case .cancelled(let iteratorInitialized):
|
||||
self._state = .finished(iteratorInitialized: iteratorInitialized)
|
||||
return .returnCancellationError
|
||||
|
||||
case .finished:
|
||||
return .returnNil
|
||||
|
||||
|
@ -1119,7 +1155,7 @@ extension NIOThrowingAsyncSequenceProducer {
|
|||
return .none
|
||||
}
|
||||
|
||||
case .streaming(_, _, .some(_), _, _), .sourceFinished, .finished:
|
||||
case .streaming(_, _, .some(_), _, _), .sourceFinished, .finished, .cancelled:
|
||||
preconditionFailure("This should have already been handled by `next()`")
|
||||
|
||||
case .modifying:
|
||||
|
|
|
@ -743,6 +743,36 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase {
|
|||
|
||||
XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.didTerminate])
|
||||
}
|
||||
|
||||
func testIteratorThrows_whenCancelled() async {
|
||||
_ = self.source.yield(contentsOf: Array(0..<100))
|
||||
await withThrowingTaskGroup(of: Void.self) { group in
|
||||
group.addTask {
|
||||
var counter = 0
|
||||
guard let sequence = self.sequence else {
|
||||
return XCTFail("Expected to have an AsyncSequence")
|
||||
}
|
||||
|
||||
do {
|
||||
for try await next in sequence {
|
||||
XCTAssertEqual(next, counter)
|
||||
counter += 1
|
||||
}
|
||||
XCTFail("Expected that this throws")
|
||||
} catch is CancellationError {
|
||||
// expected
|
||||
} catch {
|
||||
XCTFail("Unexpected error: \(error)")
|
||||
}
|
||||
|
||||
XCTAssertLessThan(counter, 100)
|
||||
}
|
||||
|
||||
group.cancelAll()
|
||||
}
|
||||
|
||||
XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.didTerminate])
|
||||
}
|
||||
}
|
||||
|
||||
// This is needed until async let is supported to be used in autoclosures
|
||||
|
|
Loading…
Reference in New Issue